Certified Training with CTRAIN
In this example, we train the standard CNN7 architecture proposed by Shi et al. on the MNIST dataset using CTRAIN. The goal is certifiable robustness against perturbations within the $\ell_\infty$-norm ball of radius $\epsilon = 0.1$ around each input. To this end, we use IBP (interval bound propagation) training with the improvements by Shi et al.
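To make the idea behind IBP concrete, the following dependency-free sketch propagates an $\ell_\infty$ box through a tiny two-layer ReLU network and checks whether the prediction is certified. It is not CTRAIN code: the helper names (`affine_bounds`, `relu_bounds`, `ibp_bounds`, `is_certified`) and the toy weights are hypothetical and chosen only for illustration, and the certification check is a simple conservative one.

```python
def affine_bounds(lower, upper, weight, bias):
    """Propagate the box [lower, upper] through y = W x + b."""
    out_lower, out_upper = [], []
    for row, b in zip(weight, bias):
        lo = hi = b
        for w, l, u in zip(row, lower, upper):
            # A non-negative weight keeps the bound order; a negative one swaps it.
            if w >= 0:
                lo += w * l
                hi += w * u
            else:
                lo += w * u
                hi += w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def relu_bounds(lower, upper):
    """ReLU is monotone, so it can be applied to both bounds directly."""
    return [max(0.0, l) for l in lower], [max(0.0, u) for u in upper]


def ibp_bounds(x, eps, layers):
    """Bound the logits for all inputs within eps of x in the l_inf norm."""
    lower = [v - eps for v in x]
    upper = [v + eps for v in x]
    last = len(layers) - 1
    for i, (weight, bias) in enumerate(layers):
        lower, upper = affine_bounds(lower, upper, weight, bias)
        if i < last:  # no activation after the output layer
            lower, upper = relu_bounds(lower, upper)
    return lower, upper


def is_certified(lower, upper, true_class):
    """Sufficient check: the true logit's lower bound beats every other upper bound."""
    return all(lower[true_class] > u
               for i, u in enumerate(upper) if i != true_class)


# Hypothetical toy network with two inputs, two hidden units and two classes.
toy_layers = [
    ([[1.0, -1.0], [0.5, 2.0]], [0.0, -0.5]),
    ([[2.0, -1.0], [-1.0, 1.0]], [1.0, 0.0]),
]
logit_lower, logit_upper = ibp_bounds([0.5, 0.2], eps=0.1, layers=toy_layers)
print(logit_lower, logit_upper, is_certified(logit_lower, logit_upper, 0))
```

IBP training minimises a loss computed from such worst-case logit bounds rather than from the logits of the clean input alone, which is why the bounds must be cheap to compute and differentiable.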
First, we import the torch library as well as the required classes and functions from the CTRAIN library.
import torch
from CTRAIN.model_definitions import CNN7_Shi
from CTRAIN.model_wrappers import ShiIBPModelWrapper
from CTRAIN.data_loaders import load_mnist
Now, we load the MNIST dataset using CTRAIN and define the model.
in_shape = [1, 28, 28]
train_loader, test_loader = load_mnist(batch_size=128, val_split=False, data_root="../../data")
model = CNN7_Shi(in_shape=in_shape, n_classes=10)
MNIST dataset - Min value: -0.4242129623889923, Max value: 2.821486711502075
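The printed value range is consistent with raw pixel values in $[0, 1]$ that have been standardised with the commonly used MNIST statistics (mean 0.1307, standard deviation 0.3081). That the loader applies exactly this normalisation is an assumption; the constants below are not taken from the CTRAIN source, and `normalise` is a hypothetical helper.

```python
# Commonly used MNIST normalisation constants (assumed, not read from CTRAIN).
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalise(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Standardise a pixel value from the [0, 1] range."""
    return (pixel - mean) / std


min_value = normalise(0.0)  # darkest possible pixel
max_value = normalise(1.0)  # brightest possible pixel
print(f"Min value: {min_value}, Max value: {max_value}")
```

Up to float32 rounding, both values agree with the output printed by the loader.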
To train the network certifiably, we have to wrap it in a CTRAIN model wrapper. To use a different certified training method, import the respective wrapper from the CTRAIN.model_wrappers package. We initialise the wrapper with the arguments that configure the training process, such as the number of warm-up epochs, i.e. the epochs in which the model is trained on the natural loss only, and the number of ramp-up epochs, i.e. the epochs over which the epsilon value is gradually increased to the final training epsilon. Please consult the documentation for the remaining hyperparameters.
wrapped_model = ShiIBPModelWrapper(
    model,
    input_shape=in_shape,
    eps=0.1,
    num_epochs=70,
    warm_up_epochs=0,
    ramp_up_epochs=40,
    lr_decay_milestones=(50, 60),
)
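The interplay of warm-up and ramp-up epochs can be illustrated with a small schedule function. This is a minimal sketch under the assumption of a linear, per-epoch ramp; `eps_at_epoch` is a hypothetical helper, and the schedule CTRAIN actually uses may differ, for example by updating epsilon per batch or by smoothing the ramp.

```python
def eps_at_epoch(epoch, eps_target, warm_up_epochs, ramp_up_epochs):
    """Training epsilon at the start of a 0-indexed epoch (linear ramp assumed)."""
    if epoch < warm_up_epochs:
        return 0.0  # warm-up: natural training, no perturbation
    if ramp_up_epochs == 0:
        return eps_target  # no ramp requested: jump straight to the target
    progress = (epoch - warm_up_epochs) / ramp_up_epochs
    return eps_target * min(1.0, progress)


# Values matching the wrapper configuration above.
for epoch in (0, 10, 20, 40, 69):
    eps = eps_at_epoch(epoch, eps_target=0.1, warm_up_epochs=0, ramp_up_epochs=40)
    print(f"epoch {epoch:2d}: eps = {eps:.4f}")
```

With 70 epochs in total and a 40-epoch ramp, this schedule leaves 30 epochs of training at the full $\epsilon = 0.1$, and the learning-rate decay milestones at epochs 50 and 60 fall within that final phase.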
We start the training process by calling the train_model method of the wrapped model.
wrapped_model.train_model(train_loader)
Finally, we save the resulting model weights.
torch.save(wrapped_model.state_dict(), '../../mnist_0.1_model.pt')