Skip to content

SHI IBP

ShiIBPModelWrapper

Bases: CTRAINWrapper

Wrapper class for training models using the SHI-IBP method. For details, see Shi et al. (2021), "Fast certified robust training with short warmup". https://proceedings.neurips.cc/paper/2021/file/988f9153ac4fd966ea302dd9ab9bae15-Paper.pdf

Source code in CTRAIN/model_wrappers/shi_ibp_model_wrapper.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
class ShiIBPModelWrapper(CTRAINWrapper):
    """
    Wrapper class for training models using the SHI-IBP method. For details, see Shi et al. (2021), "Fast certified robust training with short warmup". https://proceedings.neurips.cc/paper/2021/file/988f9153ac4fd966ea302dd9ab9bae15-Paper.pdf
    """

    def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
                 lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
                 shi_reg_weight=.5, shi_reg_decay=True, start_kappa=1, end_kappa=0, checkpoint_save_path=None, checkpoint_save_interval=10,
                 bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
        """
        Initializes the ShiIBPModelWrapper.

        The base class is initialised first; the SHI-specific hyperparameters are
        then stored on the instance so that `train_model` can forward them.

        Args:
            model (torch.nn.Module): The model to be trained.
            input_shape (tuple): Shape of the input data.
            eps (float): Epsilon value describing the perturbation the network should be certifiably robust against.
            num_epochs (int): Number of epochs for training.
            train_eps_factor (float): Factor for training epsilon.
            optimizer_func (torch.optim.Optimizer): Optimizer function.
            lr (float): Learning rate.
            warm_up_epochs (int): Number of warm-up epochs, i.e. epochs where the model is trained on clean loss.
            ramp_up_epochs (int): Number of ramp-up epochs, i.e. epochs where the epsilon is gradually increased to the target train epsilon.
            lr_decay_factor (float): Learning rate decay factor.
            lr_decay_milestones (tuple): Milestones for learning rate decay.
            gradient_clip (float): Gradient clipping value.
            l1_reg_weight (float): L1 regularization weight.
            shi_reg_weight (float): SHI regularization weight.
            shi_reg_decay (bool): Whether to decay SHI regularization during the ramp up phase.
            start_kappa (float): Starting value of kappa that trades off IBP and clean loss.
            end_kappa (float): Ending value of kappa.
            checkpoint_save_path (str): Path to save checkpoints.
            checkpoint_save_interval (int): Interval for saving checkpoints.
            bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
            device (torch.device): Device to run the training on.
        """
        # NOTE: the `bound_opts` default dict is created once, when the function is
        # defined, and is therefore the same object for every call that relies on it.
        super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
        self.cert_train_method = 'shi'
        self.num_epochs = num_epochs
        self.lr = lr
        self.warm_up_epochs = warm_up_epochs
        self.ramp_up_epochs = ramp_up_epochs
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_milestones = lr_decay_milestones
        self.gradient_clip = gradient_clip
        self.l1_reg_weight = l1_reg_weight
        self.shi_reg_weight = shi_reg_weight
        self.shi_reg_decay = shi_reg_decay
        self.start_kappa = start_kappa
        self.end_kappa = end_kappa
        self.optimizer_func = optimizer_func

    def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=None):
        """
        Trains the model using the SHI-IBP method.

        The stored hyperparameters are forwarded to `shi_train_model` together
        with a rescaled epsilon (`eps_std`): when `train_loader.normalised` is
        truthy the training epsilon is divided by `train_loader.std`, otherwise
        it is used unchanged.

        Args:
            train_loader (torch.utils.data.DataLoader): DataLoader for training data. Must expose `normalised` and, if normalised, `std`.
            val_loader (torch.utils.data.DataLoader, optional): DataLoader for validation data.
            start_epoch (int, optional): Epoch to start training from. Initialises learning rate and epsilon schedulers accordingly. Defaults to 0.
            end_epoch (int, optional): Epoch to prematurely end training at. Defaults to None.

        Returns:
            (auto_LiRPA.BoundedModule): Trained model.
        """
        eps_std = self.train_eps / train_loader.std if train_loader.normalised else torch.tensor(self.train_eps)
        # Append two singleton dimensions; presumably so that a per-channel epsilon
        # broadcasts over the spatial axes of image batches - TODO confirm.
        eps_std = torch.reshape(eps_std, (*eps_std.shape, 1, 1))
        trained_model = shi_train_model(
            original_model=self.original_model,
            hardened_model=self.bounded_model,
            train_loader=train_loader,
            val_loader=val_loader,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            num_epochs=self.num_epochs,
            eps=self.train_eps,
            eps_std=eps_std,
            eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
            eps_scheduler_args={'start_kappa': self.start_kappa, 'end_kappa': self.end_kappa},
            optimizer=self.optimizer,
            lr_decay_schedule=self.lr_decay_milestones,
            lr_decay_factor=self.lr_decay_factor,
            n_classes=self.n_classes,
            gradient_clip=self.gradient_clip,
            l1_regularisation_weight=self.l1_reg_weight,
            shi_regularisation_weight=self.shi_reg_weight,
            shi_reg_decay=self.shi_reg_decay,
            results_path=self.checkpoint_path,
            checkpoint_save_interval=self.checkpoint_save_interval,
            device=self.device
        )

        return trained_model

    def _hpo_runner(self, config, seed, epochs, train_loader, val_loader, output_dir, cert_eval_samples=1000, include_nat_loss=True, include_adv_loss=True, include_cert_loss=True):
        """
        Function called during hyperparameter optimization (HPO) using SMAC3, returns the loss.

        A fresh wrapper is built around a deep copy of the original model, trained
        with the hyperparameters in `config`, saved to
        `{output_dir}/nets/{config_hash}.pt` and evaluated on `val_loader`.
        The `nets` directory is not created here and must already exist.

        Args:
            config (dict): Configuration of hyperparameters.
            seed (int): Seed used.
            epochs (int): Number of epochs for training.
            train_loader (torch.utils.data.DataLoader): DataLoader for training data.
            val_loader (torch.utils.data.DataLoader): DataLoader for validation data. Only used for the final evaluation.
            output_dir (str): Directory to save output.
            cert_eval_samples (int, optional): Number of samples for certification evaluation.
            include_nat_loss (bool, optional): Whether to include natural loss into HPO loss.
            include_adv_loss (bool, optional): Whether to include adversarial loss into HPO loss.
            include_cert_loss (bool, optional): Whether to include certification loss into HPO loss.

        Returns:
            tuple: Loss and dictionary of accuracies that is saved as information to the run by SMAC3.

        Raises:
            ValueError: If `config['optimizer_func']` is not one of 'adam', 'radam' or 'adamw'.
        """
        config_hash = get_config_hash(config, 32)
        seed_ctrain(seed)

        # Resolve the optimizer through a lookup table so that an unsupported name
        # fails here with a clear message instead of leaving the variable unbound.
        optimizer_choices = {
            'adam': torch.optim.Adam,
            'radam': torch.optim.RAdam,
            'adamw': torch.optim.AdamW,
        }
        optimizer_name = config['optimizer_func']
        if optimizer_name not in optimizer_choices:
            raise ValueError(
                f"Unsupported optimizer_func {optimizer_name!r}; expected one of {sorted(optimizer_choices)}."
            )
        optimizer_func = optimizer_choices[optimizer_name]

        # The decay epochs in the config are offsets: the first is counted from the
        # end of warm-up + ramp-up, the second from the first milestone.
        lr_decay_milestones = [
            config['warm_up_epochs'] + config['ramp_up_epochs'] + config['lr_decay_epoch_1'],
            config['warm_up_epochs'] + config['ramp_up_epochs'] + config['lr_decay_epoch_1'] + config['lr_decay_epoch_2']
        ]

        model_wrapper = ShiIBPModelWrapper(
            model=copy.deepcopy(self.original_model),
            input_shape=self.input_shape,
            eps=self.eps,
            num_epochs=epochs,
            bound_opts=self.bound_opts,
            checkpoint_save_path=None,
            device=self.device,
            train_eps_factor=config['train_eps_factor'],
            optimizer_func=optimizer_func,
            lr=config['learning_rate'],
            warm_up_epochs=config['warm_up_epochs'],
            ramp_up_epochs=config['ramp_up_epochs'],
            gradient_clip=10,
            lr_decay_factor=config['lr_decay_factor'],
            # Milestones beyond the training budget are dropped.
            lr_decay_milestones=[epoch for epoch in lr_decay_milestones if epoch <= epochs],
            l1_reg_weight=config['l1_reg_weight'],
            shi_reg_weight=config['shi_reg_weight'],
            shi_reg_decay=config['shi_reg_decay'],
            start_kappa=config['shi:start_kappa'],
            # The configured end kappa is applied as a multiple of the start kappa.
            end_kappa=config['shi:end_kappa'] * config['shi:start_kappa'],
        )

        model_wrapper.train_model(train_loader=train_loader)
        torch.save(model_wrapper.state_dict(), f'{output_dir}/nets/{config_hash}.pt')
        model_wrapper.eval()
        std_acc, cert_acc, adv_acc = model_wrapper.evaluate(test_loader=val_loader, test_samples=cert_eval_samples)

        # The loss is the negated sum of the selected accuracies, so higher
        # accuracy yields a lower loss.
        loss = 0
        if include_nat_loss:
            loss -= std_acc
        if include_adv_loss:
            loss -= adv_acc
        if include_cert_loss:
            loss -= cert_acc

        return loss, {'nat_acc': std_acc, 'adv_acc': adv_acc, 'cert_acc': cert_acc}

__init__(model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70, lr_decay_factor=0.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=1e-06, shi_reg_weight=0.5, shi_reg_decay=True, start_kappa=1, end_kappa=0, checkpoint_save_path=None, checkpoint_save_interval=10, bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda'))

Initializes the ShiIBPModelWrapper.

Parameters:

Name Type Description Default
model Module

The model to be trained.

required
input_shape tuple

Shape of the input data.

required
eps float

Epsilon value describing the perturbation the network should be certifiably robust against.

required
num_epochs int

Number of epochs for training.

required
train_eps_factor float

Factor for training epsilon.

1
optimizer_func Optimizer

Optimizer function.

Adam
lr float

Learning rate.

0.0005
warm_up_epochs int

Number of warm-up epochs, i.e. epochs where the model is trained on clean loss.

1
ramp_up_epochs int

Number of ramp-up epochs, i.e. epochs where the epsilon is gradually increased to the target train epsilon.

70
lr_decay_factor float

Learning rate decay factor.

0.2
lr_decay_milestones tuple

Milestones for learning rate decay.

(80, 90)
gradient_clip float

Gradient clipping value.

10
l1_reg_weight float

L1 regularization weight.

1e-06
shi_reg_weight float

SHI regularization weight.

0.5
shi_reg_decay bool

Whether to decay SHI regularization during the ramp up phase.

True
start_kappa float

Starting value of kappa that trades off IBP and clean loss.

1
end_kappa float

Ending value of kappa.

0
checkpoint_save_path str

Path to save checkpoints.

None
checkpoint_save_interval int

Interval for saving checkpoints.

10
bound_opts dict

Options for bounding according to the auto_LiRPA documentation.

dict(conv_mode='patches', relu='adaptive')
device device

Device to run the training on.

device('cuda')
Source code in CTRAIN/model_wrappers/shi_ibp_model_wrapper.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
             lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
             shi_reg_weight=.5, shi_reg_decay=True, start_kappa=1, end_kappa=0, checkpoint_save_path=None, checkpoint_save_interval=10,
             bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
    """
    Set up a wrapper that trains a model with the SHI-IBP method.

    The base class is initialised first; afterwards the SHI-specific
    hyperparameters are stored on the instance.

    Args:
        model (torch.nn.Module): Network to train.
        input_shape (tuple): Shape of the input data.
        eps (float): Perturbation radius the network should be certifiably robust against.
        num_epochs (int): Total number of training epochs.
        train_eps_factor (float): Factor for the training epsilon.
        optimizer_func (torch.optim.Optimizer): Optimizer function.
        lr (float): Learning rate.
        warm_up_epochs (int): Epochs trained on the clean loss only.
        ramp_up_epochs (int): Epochs over which epsilon is gradually raised to the target training epsilon.
        lr_decay_factor (float): Learning rate decay factor.
        lr_decay_milestones (tuple): Milestones for learning rate decay.
        gradient_clip (float): Gradient clipping value.
        l1_reg_weight (float): L1 regularization weight.
        shi_reg_weight (float): SHI regularization weight.
        shi_reg_decay (bool): Whether the SHI regularization is decayed during ramp-up.
        start_kappa (float): Initial kappa, trading off IBP and clean loss.
        end_kappa (float): Final kappa.
        checkpoint_save_path (str): Where checkpoints are written.
        checkpoint_save_interval (int): Interval for saving checkpoints.
        bound_opts (dict): Bounding options as described in the auto_LiRPA documentation.
        device (torch.device): Device to run the training on.
    """
    super().__init__(
        model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device,
        checkpoint_save_path=checkpoint_save_path,
        checkpoint_save_interval=checkpoint_save_interval,
    )
    self.cert_train_method = 'shi'

    # Optimisation settings.
    self.optimizer_func = optimizer_func
    self.lr = lr
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_milestones = lr_decay_milestones
    self.gradient_clip = gradient_clip

    # Epoch budget and epsilon schedule.
    self.num_epochs = num_epochs
    self.warm_up_epochs = warm_up_epochs
    self.ramp_up_epochs = ramp_up_epochs

    # Regularisation weights.
    self.l1_reg_weight = l1_reg_weight
    self.shi_reg_weight = shi_reg_weight
    self.shi_reg_decay = shi_reg_decay

    # Kappa schedule end points.
    self.start_kappa = start_kappa
    self.end_kappa = end_kappa

_hpo_runner(config, seed, epochs, train_loader, val_loader, output_dir, cert_eval_samples=1000, include_nat_loss=True, include_adv_loss=True, include_cert_loss=True)

Function called during hyperparameter optimization (HPO) using SMAC3, returns the loss.

Parameters:

Name Type Description Default
config dict

Configuration of hyperparameters.

required
seed int

Seed used.

required
epochs int

Number of epochs for training.

required
train_loader DataLoader

DataLoader for training data.

required
val_loader DataLoader

DataLoader for validation data.

required
output_dir str

Directory to save output.

required
cert_eval_samples int

Number of samples for certification evaluation.

1000
include_nat_loss bool

Whether to include natural loss into HPO loss.

True
include_adv_loss bool

Whether to include adversarial loss into HPO loss.

True
include_cert_loss bool

Whether to include certification loss into HPO loss.

True

Returns:

Name Type Description
tuple

Loss and dictionary of accuracies that is saved as information to the run by SMAC3.

Source code in CTRAIN/model_wrappers/shi_ibp_model_wrapper.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def _hpo_runner(self, config, seed, epochs, train_loader, val_loader, output_dir, cert_eval_samples=1000, include_nat_loss=True, include_adv_loss=True, include_cert_loss=True):
    """
    Function called during hyperparameter optimization (HPO) using SMAC3, returns the loss.

    A fresh wrapper is built around a deep copy of the original model, trained
    with the hyperparameters in `config`, saved to
    `{output_dir}/nets/{config_hash}.pt` and evaluated on `val_loader`.
    The `nets` directory is not created here and must already exist.

    Args:
        config (dict): Configuration of hyperparameters.
        seed (int): Seed used.
        epochs (int): Number of epochs for training.
        train_loader (torch.utils.data.DataLoader): DataLoader for training data.
        val_loader (torch.utils.data.DataLoader): DataLoader for validation data. Only used for the final evaluation.
        output_dir (str): Directory to save output.
        cert_eval_samples (int, optional): Number of samples for certification evaluation.
        include_nat_loss (bool, optional): Whether to include natural loss into HPO loss.
        include_adv_loss (bool, optional): Whether to include adversarial loss into HPO loss.
        include_cert_loss (bool, optional): Whether to include certification loss into HPO loss.

    Returns:
        tuple: Loss and dictionary of accuracies that is saved as information to the run by SMAC3.

    Raises:
        ValueError: If `config['optimizer_func']` is not one of 'adam', 'radam' or 'adamw'.
    """
    config_hash = get_config_hash(config, 32)
    seed_ctrain(seed)

    # Resolve the optimizer through a lookup table so that an unsupported name
    # fails here with a clear message instead of leaving the variable unbound.
    optimizer_choices = {
        'adam': torch.optim.Adam,
        'radam': torch.optim.RAdam,
        'adamw': torch.optim.AdamW,
    }
    optimizer_name = config['optimizer_func']
    if optimizer_name not in optimizer_choices:
        raise ValueError(
            f"Unsupported optimizer_func {optimizer_name!r}; expected one of {sorted(optimizer_choices)}."
        )
    optimizer_func = optimizer_choices[optimizer_name]

    # The decay epochs in the config are offsets: the first is counted from the
    # end of warm-up + ramp-up, the second from the first milestone.
    first_milestone = config['warm_up_epochs'] + config['ramp_up_epochs'] + config['lr_decay_epoch_1']
    second_milestone = first_milestone + config['lr_decay_epoch_2']
    lr_decay_milestones = [first_milestone, second_milestone]

    model_wrapper = ShiIBPModelWrapper(
        model=copy.deepcopy(self.original_model),
        input_shape=self.input_shape,
        eps=self.eps,
        num_epochs=epochs,
        bound_opts=self.bound_opts,
        checkpoint_save_path=None,
        device=self.device,
        train_eps_factor=config['train_eps_factor'],
        optimizer_func=optimizer_func,
        lr=config['learning_rate'],
        warm_up_epochs=config['warm_up_epochs'],
        ramp_up_epochs=config['ramp_up_epochs'],
        gradient_clip=10,
        lr_decay_factor=config['lr_decay_factor'],
        # Milestones beyond the training budget are dropped.
        lr_decay_milestones=[epoch for epoch in lr_decay_milestones if epoch <= epochs],
        l1_reg_weight=config['l1_reg_weight'],
        shi_reg_weight=config['shi_reg_weight'],
        shi_reg_decay=config['shi_reg_decay'],
        start_kappa=config['shi:start_kappa'],
        # The configured end kappa is applied as a multiple of the start kappa.
        end_kappa=config['shi:end_kappa'] * config['shi:start_kappa'],
    )

    model_wrapper.train_model(train_loader=train_loader)
    torch.save(model_wrapper.state_dict(), f'{output_dir}/nets/{config_hash}.pt')
    model_wrapper.eval()
    std_acc, cert_acc, adv_acc = model_wrapper.evaluate(test_loader=val_loader, test_samples=cert_eval_samples)

    # The loss is the negated sum of the selected accuracies, so higher
    # accuracy yields a lower loss.
    loss = 0
    if include_nat_loss:
        loss -= std_acc
    if include_adv_loss:
        loss -= adv_acc
    if include_cert_loss:
        loss -= cert_acc

    return loss, {'nat_acc': std_acc, 'adv_acc': adv_acc, 'cert_acc': cert_acc}

train_model(train_loader, val_loader=None, start_epoch=0, end_epoch=None)

Trains the model using the SHI-IBP method.

Parameters:

Name Type Description Default
train_loader DataLoader

DataLoader for training data.

required
val_loader DataLoader

DataLoader for validation data.

None
start_epoch int

Epoch to start training from. Initialises learning rate and epsilon schedulers accordingly. Defaults to 0.

0
end_epoch int

Epoch to prematurely end training at. Defaults to None.

None

Returns:

Type Description
BoundedModule

Trained model.

Source code in CTRAIN/model_wrappers/shi_ibp_model_wrapper.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=None):
    """
    Run SHI-IBP training with the hyperparameters stored on this wrapper.

    Args:
        train_loader (torch.utils.data.DataLoader): DataLoader for training data. Must expose `normalised` and, if normalised, `std`.
        val_loader (torch.utils.data.DataLoader, optional): DataLoader for validation data.
        start_epoch (int, optional): Epoch to start training from. Initialises learning rate and epsilon schedulers accordingly. Defaults to 0.
        end_epoch (int, optional): Epoch to prematurely end training at. Defaults to None.

    Returns:
        (auto_LiRPA.BoundedModule): Trained model.
    """
    # Express epsilon in the units of the (possibly normalised) inputs.
    if train_loader.normalised:
        scaled_eps = self.train_eps / train_loader.std
    else:
        scaled_eps = torch.tensor(self.train_eps)
    # Append two trailing singleton dimensions to the epsilon tensor.
    expanded_shape = (*scaled_eps.shape, 1, 1)
    scaled_eps = torch.reshape(scaled_eps, expanded_shape)

    warm_and_ramp = (self.warm_up_epochs, self.ramp_up_epochs)
    kappa_bounds = {'start_kappa': self.start_kappa, 'end_kappa': self.end_kappa}

    return shi_train_model(
        original_model=self.original_model,
        hardened_model=self.bounded_model,
        train_loader=train_loader,
        val_loader=val_loader,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        num_epochs=self.num_epochs,
        eps=self.train_eps,
        eps_std=scaled_eps,
        eps_schedule=warm_and_ramp,
        eps_scheduler_args=kappa_bounds,
        optimizer=self.optimizer,
        lr_decay_schedule=self.lr_decay_milestones,
        lr_decay_factor=self.lr_decay_factor,
        n_classes=self.n_classes,
        gradient_clip=self.gradient_clip,
        l1_regularisation_weight=self.l1_reg_weight,
        shi_regularisation_weight=self.shi_reg_weight,
        shi_reg_decay=self.shi_reg_decay,
        results_path=self.checkpoint_path,
        checkpoint_save_interval=self.checkpoint_save_interval,
        device=self.device,
    )