Hyperparameter Optimisation using CTRAIN
The standard hpo method of the CTRAIN model wrappers runs multi-objective hyperparameter optimisation (HPO) with Optuna. It returns the Pareto front of trials with respect to natural and certified validation accuracy. This example demonstrates the workflow on MNIST.
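A trial is Pareto-optimal when no other trial is at least as good on both objectives and strictly better on at least one. The minimal sketch below filters a list of (natural accuracy, certified accuracy) pairs down to that front. It illustrates the concept only and is not CTRAIN's or Optuna's implementation. The accuracy values are made up for illustration.

```python
from typing import List, Tuple


def dominates(a: Tuple[float, float], b: Tuple[float, float]) -> bool:
    """True if a is at least as good as b on both objectives and strictly better on one (both maximised)."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def non_dominated(points: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Keep only the points that no other point dominates, i.e. the Pareto front."""
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Illustrative (natural accuracy, certified accuracy) pairs, not real results.
trials = [(0.98, 0.85), (0.97, 0.90), (0.95, 0.88), (0.99, 0.80)]
print(non_dominated(trials))  # → [(0.98, 0.85), (0.97, 0.9), (0.99, 0.8)]
```

The pair (0.95, 0.88) is dropped because (0.97, 0.90) is better on both objectives. The remaining three points each trade natural against certified accuracy, so none of them can be discarded without a preference between the two objectives.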
First, we import PyTorch and the required CTRAIN components.
import torch
from CTRAIN.model_definitions import CNN7_Shi
from CTRAIN.model_wrappers import ShiIBPModelWrapper
from CTRAIN.data_loaders import load_mnist
Next, we load the MNIST dataset and instantiate the CNN7 network used by Shi et al.
in_shape = [1, 28, 28]  # channels, height, width
# val_split=False: only a training and a test loader are returned
train_loader, test_loader = load_mnist(batch_size=128, val_split=False, data_root="../../data")
model = CNN7_Shi(in_shape=in_shape, n_classes=10)
MNIST dataset - Min value: -0.4242129623889923, Max value: 2.821486711502075
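The printed range is consistent with the usual MNIST normalisation, (x - 0.1307) / 0.3081, applied to pixel values in [0, 1]. We infer this from the numbers alone. The constants below are the commonly used MNIST mean and standard deviation, not values read from load_mnist.

```python
# Assumption: the loader standardises pixels in [0, 1] with the common MNIST statistics.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalise(pixel: float) -> float:
    """Map a raw pixel value in [0, 1] to the standardised input space."""
    return (pixel - MNIST_MEAN) / MNIST_STD


min_value = normalise(0.0)  # black pixel
max_value = normalise(1.0)  # white pixel
print(f"Min value: {min_value:.6f}, Max value: {max_value:.6f}")
# → Min value: -0.424213, Max value: 2.821487
```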
To perform HPO, we wrap the network in one of the CTRAIN model wrappers. Here, we choose ShiIBPModelWrapper, which implements IBP-based certified training following Shi et al. We keep the number of epochs small so that the example runs quickly. Publication-scale runs should use the budgets documented under papers/rethinking_evaluation_paradigms.
wrapped_model = ShiIBPModelWrapper(
    model,
    input_shape=in_shape,
    eps=0.1,  # perturbation radius used for certified training
    num_epochs=14,  # kept small so the example runs quickly
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
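IBP-based certified training, which ShiIBPModelWrapper builds on, propagates an axis-aligned box around each input through the network to obtain guaranteed lower and upper bounds on the logits. The sketch below shows that bound computation for a tiny fully connected ReLU network with random weights, using NumPy. It is a generic illustration of interval bound propagation, not CTRAIN's code. It also ignores the initialisation and regularisers that Shi et al. add on top.

```python
import numpy as np


def ibp_affine(W, b, lower, upper):
    """Propagate the box [lower, upper] through x -> W @ x + b."""
    center = (upper + lower) / 2
    radius = (upper - lower) / 2
    new_center = W @ center + b
    new_radius = np.abs(W) @ radius  # worst-case spread of the box
    return new_center - new_radius, new_center + new_radius


def ibp_relu(lower, upper):
    """ReLU is monotone, so applying it to both bounds keeps them sound."""
    return np.maximum(lower, 0), np.maximum(upper, 0)


rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(8, 4)), rng.normal(size=8)
W2, b2 = rng.normal(size=(3, 8)), rng.normal(size=3)


def forward(x):
    """Concrete forward pass of the toy two-layer network."""
    return W2 @ np.maximum(W1 @ x + b1, 0) + b2


x = rng.uniform(size=4)
eps = 0.1  # L-infinity radius of the input box
lower, upper = ibp_affine(W1, b1, x - eps, x + eps)
lower, upper = ibp_relu(lower, upper)
lower, upper = ibp_affine(W2, b2, lower, upper)

# Simple sufficient check: the predicted class y is certified if its worst-case
# logit exceeds the best-case logit of every other class.
y = int(np.argmax(forward(x)))
certified = lower[y] > np.max(np.delete(upper, y))
print(lower, upper, certified)
```

Certified training minimises a loss computed from such bounds, which pushes the network towards boxes that are tight enough for this check to succeed.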
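The multi-objective interface optimises natural and certified validation accuracy jointly and stores all trials in an Optuna study. For a lightweight local run, we use Optuna's NSGA-II sampler and only three trials. For brevity, the test loader also serves as the validation loader here. A rigorous experiment should tune on a held-out validation split so that the final test accuracy is not biased by the HPO. For paper-style runs, use sampler="botorch", larger budgets, and repeat the optimisation for multiple seeds.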
pareto_front = wrapped_model.hpo(
    train_loader=train_loader,
    val_loader=test_loader,
    defaults={
        'warm_up_epochs': 0,
        'ramp_up_epochs': 10,
        'lr_decay_factor': 0.2,
        'lr_decay_epoch_1': 1,
        'lr_decay_epoch_2': 1,
        'l1_reg_weight': 1e-6,
        'shi_reg_weight': 1,
        'shi_reg_decay': True,
        'train_eps_factor': 1,
        'optimizer_func': 'adam',
        'lr': 5e-4,
        'start_kappa': 1,
        'end_kappa': 0,
    },
    output_dir="./optuna/shi_mnist_0.1/",  # where the study and trial outputs are written
    budget_trials=3,  # number of HPO trials
    eval_samples=1_000,  # validation samples evaluated per trial
    sampler="nsgaii",
)
pareto_front
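Several of the defaults describe schedules over the training epochs. With warm_up_epochs=0 and ramp_up_epochs=10, the training radius presumably grows from 0 towards eps during the first ten epochs. The pair start_kappa=1 and end_kappa=0 suggests a weight that shifts the loss from the natural term to the certified term, as in IBP training. The sketch below uses plain linear ramps to illustrate that reading. This is an assumption: CTRAIN's actual schedule for Shi-style training may use a different, non-linear ramp, and the loss mix is our interpretation of the kappa parameters.

```python
def linear_ramp(epoch: int, warm_up: int, ramp_up: int, start: float, end: float) -> float:
    """Hold `start` during warm-up, then move linearly to `end` over `ramp_up` epochs."""
    if epoch < warm_up:
        return start
    if ramp_up <= 0:
        return end
    progress = min(1.0, (epoch - warm_up) / ramp_up)
    return start + progress * (end - start)


EPS = 0.1
for epoch in range(14):
    eps_t = linear_ramp(epoch, warm_up=0, ramp_up=10, start=0.0, end=EPS)
    kappa = linear_ramp(epoch, warm_up=0, ramp_up=10, start=1.0, end=0.0)
    # Assumed loss mix: kappa * natural_loss + (1 - kappa) * certified_loss
    print(f"epoch {epoch:2d}: eps={eps_t:.3f}, kappa={kappa:.2f}")
```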
Finally, we select one Pareto-optimal trial, load its checkpoint into the wrapper, save the weights, and evaluate the model on 1,000 test samples.
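The code below simply takes the first entry of the front, which is not necessarily the trade-off you want. A small helper can instead pick, for example, the trial with the highest certified accuracy among those whose natural accuracy stays above a threshold. Apart from 'checkpoint_path', we do not know which keys CTRAIN's Pareto-front entries contain. The 'nat_acc' and 'cert_acc' keys below are therefore hypothetical: print pareto_front to see the actual field names and adapt the accessors accordingly.

```python
from typing import Callable, Dict, List, Optional


def select_trial(
    front: List[Dict],
    natural: Callable[[Dict], float],
    certified: Callable[[Dict], float],
    min_natural: float,
) -> Optional[Dict]:
    """Return the entry with the best certified accuracy among those meeting min_natural."""
    eligible = [t for t in front if natural(t) >= min_natural]
    return max(eligible, key=certified, default=None)


# Hypothetical entries with assumed key names, not real HPO results.
front = [
    {"nat_acc": 0.99, "cert_acc": 0.80, "checkpoint_path": "trial_0.pt"},
    {"nat_acc": 0.97, "cert_acc": 0.90, "checkpoint_path": "trial_1.pt"},
    {"nat_acc": 0.94, "cert_acc": 0.92, "checkpoint_path": "trial_2.pt"},
]
choice = select_trial(
    front,
    natural=lambda t: t["nat_acc"],
    certified=lambda t: t["cert_acc"],
    min_natural=0.95,
)
print(choice["checkpoint_path"])  # → trial_1.pt
```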
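# Take the first Pareto-optimal trial; pick another entry for a different trade-off.
selected_trial = pareto_front[0]
# Restore the weights stored for that trial and save them under a stable name.
state_dict = torch.load(selected_trial['checkpoint_path'], map_location=wrapped_model.device)
wrapped_model.load_state_dict(state_dict)
torch.save(wrapped_model.state_dict(), './shi_mnist_0.1_mo_hpo.pt')
wrapped_model.eval()
# Natural (standard), certified, and adversarial accuracy on 1,000 test samples.
std_acc, cert_acc, adv_acc = wrapped_model.evaluate(test_loader=test_loader, test_samples=1_000)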
print(f"Std Acc: {std_acc}, Cert. Acc: {cert_acc}, Adv. Acc: {adv_acc}")
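The evaluation above uses test_samples=1_000, so the reported accuracies are estimates with sampling noise. If each sample is treated as an independent Bernoulli trial, the standard error of an accuracy p measured on n samples is sqrt(p(1 - p) / n). The sketch below computes it together with a normal-approximation 95% interval, using an example accuracy of 0.90. This is a generic statistical aid, not part of CTRAIN. For accuracies close to 0 or 1, a Wilson interval is the more reliable choice.

```python
import math


def accuracy_interval(acc: float, n: int, z: float = 1.96):
    """Standard error and normal-approximation confidence interval for an accuracy on n samples."""
    se = math.sqrt(acc * (1.0 - acc) / n)
    return se, max(0.0, acc - z * se), min(1.0, acc + z * se)


se, low, high = accuracy_interval(0.90, 1_000)  # example accuracy, not a measured result
print(f"SE = {se:.4f}, 95% CI = [{low:.3f}, {high:.3f}]")
# → SE = 0.0095, 95% CI = [0.881, 0.919]
```

With 1,000 samples, differences of about one percentage point between trials or models are therefore within the noise.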