Certified Training with CTRAIN
In this example, we train the standard CNN7 architecture proposed by Shi et al. on the MNIST dataset using CTRAIN. The goal is certifiable robustness against perturbations within the $l_\infty$ norm ball of radius $\epsilon=0.1$ around each input. To this end, we use IBP (interval bound propagation) training with the improvements by Shi et al.
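To make the idea behind IBP more concrete, here is a minimal sketch in plain Python. It is independent of CTRAIN and does not reflect how the library implements bound propagation; the helper names `affine_bounds` and `relu_bounds`, the toy weights and the toy input are all made up for illustration. The sketch propagates the $l_\infty$ box around an input through one affine layer and a ReLU, which yields lower and upper bounds that hold for every point inside the box.

```python
def affine_bounds(lower, upper, weight, bias):
    """Propagate an axis-aligned box through y = W x + b."""
    out_lower, out_upper = [], []
    for row, b in zip(weight, bias):
        # Image of the box centre under the affine map.
        centre = sum(w * (l + u) / 2 for w, l, u in zip(row, lower, upper)) + b
        # The radius grows with the absolute values of the weights.
        radius = sum(abs(w) * (u - l) / 2 for w, l, u in zip(row, lower, upper))
        out_lower.append(centre - radius)
        out_upper.append(centre + radius)
    return out_lower, out_upper


def relu_bounds(lower, upper):
    """ReLU is monotone, so it can be applied to both bounds directly."""
    return [max(0.0, l) for l in lower], [max(0.0, u) for u in upper]


# Toy input and layer (illustrative values, not taken from CNN7 or MNIST).
x = [0.5, 0.2]
eps = 0.1
W = [[1.0, -2.0], [0.5, 1.0]]
b = [0.0, -1.0]

# The l_inf ball of radius eps around x is the box [x - eps, x + eps].
lower = [v - eps for v in x]
upper = [v + eps for v in x]

lower, upper = relu_bounds(*affine_bounds(lower, upper, W, b))
print("lower bounds:", lower)
print("upper bounds:", upper)
```

Certified training methods based on IBP compute bounds of this kind for the whole network and penalise the worst case they admit, so that the bounds themselves become tight enough to certify the prediction.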
First, we load the torch library as well as the required functions from the CTRAIN library.
import torch
from CTRAIN.model_definitions import CNN7_Shi
from CTRAIN.model_wrappers import ShiIBPModelWrapper
from CTRAIN.data_loaders import load_mnist
Now, we load the MNIST dataset using CTRAIN and define the model.
in_shape = [1, 28, 28]
train_loader, test_loader = load_mnist(batch_size=128, val_split=False, data_root="../../data")
model = CNN7_Shi(in_shape=in_shape, n_classes=10)
MNIST dataset - Min value: -0.4242129623889923, Max value: 2.821486711502075
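The minimum and maximum reported above are consistent with pixel values in $[0, 1]$ that have been standardised with the commonly used MNIST constants, mean 0.1307 and standard deviation 0.3081. That the loader uses exactly these constants is an inference from the printed values, not something stated in this example. The arithmetic can be checked as follows:

```python
# Commonly used MNIST normalisation constants (assumed, see the note above).
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalise(pixel):
    """Standardise a raw pixel intensity in [0, 1]."""
    return (pixel - MNIST_MEAN) / MNIST_STD


min_value = normalise(0.0)  # darkest possible pixel
max_value = normalise(1.0)  # brightest possible pixel
print(f"Min value: {min_value:.4f}, Max value: {max_value:.4f}")
```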
To train the network certifiably, we wrap it in a CTRAIN model wrapper. To use a different certified training method, import the respective wrapper from the CTRAIN.model_wrappers package instead. We initialise the wrapper with the arguments of the training process, such as the number of warm-up epochs (epochs in which the model is trained on the natural loss only) and the number of ramp-up epochs (epochs over which the epsilon value is gradually increased to the final training epsilon). Please consult the documentation for the remaining hyperparameters.
wrapped_model = ShiIBPModelWrapper(
    model,
    input_shape=in_shape,
    eps=0.1,
    num_epochs=70,
    warm_up_epochs=0,
    ramp_up_epochs=40,
    lr_decay_milestones=(50, 60),
)
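The warm-up and ramp-up arguments describe a schedule for the training epsilon. The sketch below illustrates one possible shape of such a schedule: a linear ramp that is updated once per epoch. This is an assumption made for illustration; the function `eps_for_epoch` is hypothetical, and CTRAIN may use a different (for example smoothed or per-batch) schedule internally.

```python
def eps_for_epoch(epoch, eps_target=0.1, warm_up_epochs=0, ramp_up_epochs=40):
    """Training epsilon for a zero-based epoch index (illustrative linear ramp)."""
    if epoch < warm_up_epochs:
        # Warm-up: train on the natural loss only, i.e. with no perturbation.
        return 0.0
    if ramp_up_epochs == 0:
        # No ramp-up requested: use the final epsilon straight away.
        return eps_target
    # Fraction of the ramp-up phase completed, capped at 1.
    progress = (epoch - warm_up_epochs) / ramp_up_epochs
    return eps_target * min(1.0, progress)


# Values matching the wrapper arguments above: 70 epochs, no warm-up, 40 ramp-up epochs.
schedule = [eps_for_epoch(epoch) for epoch in range(70)]
for epoch in (0, 10, 20, 40, 69):
    print(f"epoch {epoch:2d}: eps = {schedule[epoch]:.3f}")
```

Under this assumed schedule, the last 30 of the 70 epochs train at the full $\epsilon=0.1$, and the learning-rate decay milestones at epochs 50 and 60 fall inside that final phase.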
We start the training process by calling the train_model method of the wrapped model.
wrapped_model.train_model(train_loader)
Finally, we save the resulting model weights.
torch.save(wrapped_model.state_dict(), '../../mnist_0.1_model.pt')