Hyperparameter Optimisation using CTRAIN
CTRAIN offers seamless hyperparameter optimisation through SMAC3.
First, we import the necessary torch library and CTRAIN functions.
In [1]:
import torch
from CTRAIN.model_definitions import CNN7_Shi
from CTRAIN.model_wrappers import ShiIBPModelWrapper
from CTRAIN.data_loaders import load_mnist
Adding complete_verifier to sys.path
Thereafter, we load the MNIST dataset and define the neural network.
In [2]:
in_shape = [1, 28, 28]
train_loader, test_loader = load_mnist(batch_size=128, val_split=False, data_root="../../data")
model = CNN7_Shi(in_shape=in_shape, n_classes=10)
MNIST dataset - Min value: -0.4242129623889923, Max value: 2.821486711502075
To perform HPO, we wrap the network in one of CTRAIN's model wrappers. Here, we choose the Shi IBP wrapper.
In [4]:
wrapped_model = ShiIBPModelWrapper(
model,
input_shape=in_shape,
eps=0.1,
num_epochs=70
)
Thereafter, we perform the hyperparameter tuning, evaluating probed configurations on the test set. Furthermore, we provide sensible default values to guide the optimisation. To save resources, we do not execute the HPO in this notebook.
In [ ]:
wrapped_model.hpo(train_loader=train_loader, val_loader=test_loader, defaults={
'warm_up_epochs': 0,
'ramp_up_epochs': 50,
'lr_decay_factor': 0.2,
'lr_decay_epoch_1': 10, # added on top of the warm_up and ramp_up epochs
'lr_decay_epoch_2': 10, # added on top of the warm_up, ramp_up and lr_decay_epoch_1 epochs
'l1_reg_weight': 1e-06,
'shi_reg_weight': 1,
'shi_reg_decay': True,
'train_eps_factor': 1,
'optimizer_func': 'adam',
'learning_rate': 5e-04,
'start_kappa': 1,
'end_kappa': 0
}, output_dir='./smac/shi_mnist_0.1/')
Finally, we save the model trained on the optimal configuration and evaluate it.
In [ ]:
torch.save(wrapped_model.state_dict(), './shi_incumbent_mnist_0.1.pt')
wrapped_model.eval()
std_acc, cert_acc, adv_acc = wrapped_model.evaluate(test_loader=test_loader, test_samples=1_000)
print(f"Std Acc: {std_acc}, Cert. Acc: {cert_acc}, Adv. Acc: {adv_acc}")