MTL IBP

MTLIBPModelWrapper

Bases: CTRAINWrapper

Wrapper class for training models using the MTL-IBP method. For details, see De Palma et al. (2024), Expressive Losses for Verified Robustness via Convex Combinations: https://arxiv.org/pdf/2305.13991
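A minimal usage sketch (not taken from the CTRAIN examples): the import path is inferred from the source location shown below, and the data loaders are placeholders, since train_model expects CTRAIN data loaders that expose normalised and std. The constructor, train_model, eval and evaluate calls follow the signatures documented on this page.

```python
import torch
import torch.nn as nn

# Import path inferred from the source location; adjust if CTRAIN re-exports the class elsewhere.
from CTRAIN.model_wrappers.mtl_ibp_model_wrapper import MTLIBPModelWrapper

# Small CNN for illustration; any torch.nn.Module classifier works.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 16 * 16, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

wrapper = MTLIBPModelWrapper(
    model=model,
    input_shape=(3, 32, 32),   # CIFAR-10-shaped inputs (hypothetical dataset)
    eps=2 / 255,               # target certification radius
    num_epochs=100,
    warm_up_epochs=1,
    ramp_up_epochs=70,
    mtl_ibp_alpha=0.5,         # trade-off between certified and adversarial loss
    device=torch.device('cuda'),
)

# train_loader / test_loader: CTRAIN data loaders prepared elsewhere (placeholders here);
# train_model reads loader.normalised and loader.std to rescale epsilon.
bounded_model = wrapper.train_model(train_loader=train_loader)

wrapper.eval()
std_acc, cert_acc, adv_acc = wrapper.evaluate(test_loader=test_loader, test_samples=1000)
```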

Source code in CTRAIN/model_wrappers/mtl_ibp_model_wrapper.py
class MTLIBPModelWrapper(CTRAINWrapper):
    """
    Wrapper class for training models using MTL-IBP method. For details, see De Palma et al. (2024) Expressive Losses for Verified Robustness via Convex Combinations. https://arxiv.org/pdf/2305.13991
    """

    def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
                 lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
                 shi_reg_weight=.5, shi_reg_decay=True, pgd_steps=1, 
                 pgd_alpha=10, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=.1,
                 pgd_decay_milestones=(), pgd_eps_factor=1, mtl_ibp_alpha=0.5, checkpoint_save_path=None, checkpoint_save_interval=10,
                 bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
        """
        Initializes the MTLIBPModelWrapper.

        Args:
            model (torch.nn.Module): The model to be trained.
            input_shape (tuple): Shape of the input data.
            eps (float): Epsilon value describing the perturbation the network should be certifiably robust against.
            num_epochs (int): Number of epochs for training.
            train_eps_factor (float): Factor for training epsilon.
            optimizer_func (torch.optim.Optimizer): Optimizer function.
            lr (float): Learning rate.
            warm_up_epochs (int): Number of warm-up epochs, i.e. epochs where the model is trained on clean loss.
            ramp_up_epochs (int): Number of ramp-up epochs, i.e. epochs where the epsilon is gradually increased to the target train epsilon.
            lr_decay_factor (float): Learning rate decay factor.
            lr_decay_milestones (tuple): Milestones for learning rate decay.
            gradient_clip (float): Gradient clipping value.
            l1_reg_weight (float): L1 regularization weight.
            shi_reg_weight (float): Shi regularization weight.
            shi_reg_decay (bool): Whether to decay Shi regularization during the ramp up phase.
            pgd_steps (int): Number of PGD steps for adversarial loss computation.
            pgd_alpha (float): PGD step size for adversarial loss calculation.
            pgd_restarts (int): Number of PGD restarts for adversarial loss calculation.
            pgd_early_stopping (bool): Whether to use early stopping in PGD during adversarial loss calculation.
            pgd_alpha_decay_factor (float): PGD alpha decay factor.
            pgd_decay_milestones (tuple): Milestones for PGD alpha decay.
            pgd_eps_factor (float): Factor for PGD epsilon.
            mtl_ibp_alpha (float): Alpha value for MTL-IBP, i.e. the trade-off between certified and adversarial loss.
            checkpoint_save_path (str): Path to save checkpoints.
            checkpoint_save_interval (int): Interval for saving checkpoints.
            bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
            device (torch.device): Device to run the training on.
        """
        super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
        self.cert_train_method = 'mtl_ibp'
        self.num_epochs = num_epochs
        self.lr = lr
        self.warm_up_epochs = warm_up_epochs
        self.ramp_up_epochs = ramp_up_epochs
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_milestones = lr_decay_milestones
        self.gradient_clip = gradient_clip
        self.l1_reg_weight = l1_reg_weight
        self.shi_reg_weight = shi_reg_weight
        self.shi_reg_decay = shi_reg_decay
        self.optimizer_func = optimizer_func
        self.pgd_steps = pgd_steps
        self.pgd_alpha = pgd_alpha
        self.pgd_restarts = pgd_restarts
        self.pgd_early_stopping = pgd_early_stopping
        self.pgd_alpha_decay_factor = pgd_alpha_decay_factor
        self.pgd_decay_milestones = pgd_decay_milestones
        self.pgd_eps_factor = pgd_eps_factor
        self.mtl_ibp_alpha = mtl_ibp_alpha


    def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=None):
        """
        Trains the model using the MTL-IBP method.

        Args:
            train_loader (torch.utils.data.DataLoader): DataLoader for training data.
            val_loader (torch.utils.data.DataLoader, optional): DataLoader for validation data.
            start_epoch (int, optional): Epoch to start training from. Initialises learning rate and epsilon schedulers accordingly. Defaults to 0.
            end_epoch (int, optional): Epoch to prematurely end training at. Defaults to None.

        Returns:
            (auto_LiRPA.BoundedModule): Trained model.
        """
        eps_std = self.train_eps / train_loader.std if train_loader.normalised else torch.tensor(self.train_eps)
        eps_std = torch.reshape(eps_std, (*eps_std.shape, 1, 1))

        trained_model = mtl_ibp_train_model(
            original_model=self.original_model,
            hardened_model=self.bounded_model,
            train_loader=train_loader,
            val_loader=val_loader,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            num_epochs=self.num_epochs,
            eps=self.train_eps,
            eps_std=eps_std,
            eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
            optimizer=self.optimizer,
            lr_decay_schedule=self.lr_decay_milestones,
            lr_decay_factor=self.lr_decay_factor,
            n_classes=self.n_classes,
            gradient_clip=self.gradient_clip,
            l1_regularisation_weight=self.l1_reg_weight,
            shi_regularisation_weight=self.shi_reg_weight,
            shi_reg_decay=self.shi_reg_decay,
            alpha=self.mtl_ibp_alpha,
            pgd_n_steps=self.pgd_steps,
            pgd_step_size=self.pgd_alpha,
            pgd_restarts=self.pgd_restarts,
            pgd_eps_factor=self.pgd_eps_factor,
            pgd_early_stopping=self.pgd_early_stopping,
            pgd_decay_factor=self.pgd_alpha_decay_factor,
            pgd_decay_checkpoints=self.pgd_decay_milestones,
            results_path=self.checkpoint_path,
            checkpoint_save_interval=self.checkpoint_save_interval,
            device=self.device
        )

        return trained_model

    def _hpo_runner(self, config, seed, epochs, train_loader, val_loader, output_dir, cert_eval_samples=1000, include_nat_loss=True, include_adv_loss=True, include_cert_loss=True):
        """
        Function called during hyperparameter optimization (HPO) using SMAC3, returns the loss.

        Args:
            config (dict): Configuration of hyperparameters.
            seed (int): Seed used.
            epochs (int): Number of epochs for training.
            train_loader (torch.utils.data.DataLoader): DataLoader for training data.
            val_loader (torch.utils.data.DataLoader): DataLoader for validation data.
            output_dir (str): Directory to save output.
            cert_eval_samples (int, optional): Number of samples for certification evaluation.
            include_nat_loss (bool, optional): Whether to include natural loss into HPO loss.
            include_adv_loss (bool, optional): Whether to include adversarial loss into HPO loss.
            include_cert_loss (bool, optional): Whether to include certification loss into HPO loss.

        Returns:
            tuple: Loss and dictionary of accuracies that is saved as information to the run by SMAC3.
        """
        config_hash = get_config_hash(config, 32)
        seed_ctrain(seed)

        if config['optimizer_func'] == 'adam':
            optimizer_func = torch.optim.Adam
        elif config['optimizer_func'] == 'radam':
            optimizer_func = torch.optim.RAdam
        elif config['optimizer_func'] == 'adamw':
            optimizer_func = torch.optim.AdamW

        lr_decay_milestones = [
            config['warm_up_epochs'] + config['ramp_up_epochs'] + config['lr_decay_epoch_1'],
            config['warm_up_epochs'] + config['ramp_up_epochs'] + config['lr_decay_epoch_1'] + config['lr_decay_epoch_2']
        ]

        model_wrapper = MTLIBPModelWrapper(
            model=copy.deepcopy(self.original_model), 
            input_shape=self.input_shape,
            eps=self.eps,
            num_epochs=epochs, 
            bound_opts=self.bound_opts,
            checkpoint_save_path=None,
            device=self.device,
            train_eps_factor=config['train_eps_factor'],
            optimizer_func=optimizer_func,
            lr=config['learning_rate'],
            warm_up_epochs=config['warm_up_epochs'],
            ramp_up_epochs=config['ramp_up_epochs'],
            gradient_clip=10,
            lr_decay_factor=config['lr_decay_factor'],
            lr_decay_milestones=[epoch for epoch in lr_decay_milestones if epoch <= epochs],
            l1_reg_weight=config['l1_reg_weight'],
            shi_reg_weight=config['shi_reg_weight'],
            shi_reg_decay=config['shi_reg_decay'],
            mtl_ibp_alpha=config['mtl_ibp:mtl_ibp_alpha'],
            pgd_alpha=config['mtl_ibp:pgd_alpha'],
            pgd_early_stopping=False,
            pgd_restarts=config['mtl_ibp:pgd_restarts'],
            pgd_steps=config['mtl_ibp:pgd_steps'],
            pgd_eps_factor=config['mtl_ibp:mtl_ibp_eps_factor'],
            pgd_decay_milestones=()
        )

        model_wrapper.train_model(train_loader=train_loader)
        torch.save(model_wrapper.state_dict(), f'{output_dir}/nets/{config_hash}.pt')
        model_wrapper.eval()
        std_acc, cert_acc, adv_acc = model_wrapper.evaluate(test_loader=val_loader, test_samples=cert_eval_samples)

        loss = 0
        if include_nat_loss:
            loss -= std_acc
        if include_adv_loss:
            loss -= adv_acc
        if include_cert_loss:
            loss -= cert_acc

        return loss, {'nat_acc': std_acc, 'adv_acc': adv_acc, 'cert_acc': cert_acc}

__init__(model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70, lr_decay_factor=0.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=1e-06, shi_reg_weight=0.5, shi_reg_decay=True, pgd_steps=1, pgd_alpha=10, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=0.1, pgd_decay_milestones=(), pgd_eps_factor=1, mtl_ibp_alpha=0.5, checkpoint_save_path=None, checkpoint_save_interval=10, bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda'))

Initializes the MTLIBPModelWrapper.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model to be trained. | required |
| input_shape | tuple | Shape of the input data. | required |
| eps | float | Epsilon value describing the perturbation the network should be certifiably robust against. | required |
| num_epochs | int | Number of epochs for training. | required |
| train_eps_factor | float | Factor for the training epsilon. | 1 |
| optimizer_func | Optimizer | Optimizer function. | Adam |
| lr | float | Learning rate. | 0.0005 |
| warm_up_epochs | int | Number of warm-up epochs, i.e. epochs where the model is trained on the clean loss. | 1 |
| ramp_up_epochs | int | Number of ramp-up epochs, i.e. epochs where epsilon is gradually increased to the target training epsilon. | 70 |
| lr_decay_factor | float | Learning rate decay factor. | 0.2 |
| lr_decay_milestones | tuple | Milestones for learning rate decay. | (80, 90) |
| gradient_clip | float | Gradient clipping value. | 10 |
| l1_reg_weight | float | L1 regularization weight. | 1e-06 |
| shi_reg_weight | float | Shi regularization weight. | 0.5 |
| shi_reg_decay | bool | Whether to decay the Shi regularization during the ramp-up phase. | True |
| pgd_steps | int | Number of PGD steps for adversarial loss computation. | 1 |
| pgd_alpha | float | PGD step size for adversarial loss calculation. | 10 |
| pgd_restarts | int | Number of PGD restarts for adversarial loss calculation. | 1 |
| pgd_early_stopping | bool | Whether to use early stopping in PGD during adversarial loss calculation. | False |
| pgd_alpha_decay_factor | float | PGD alpha decay factor. | 0.1 |
| pgd_decay_milestones | tuple | Milestones for PGD alpha decay. | () |
| pgd_eps_factor | float | Factor for the PGD epsilon. | 1 |
| mtl_ibp_alpha | float | Alpha value for MTL-IBP, i.e. the trade-off between the certified and the adversarial loss. | 0.5 |
| checkpoint_save_path | str | Path to save checkpoints. | None |
| checkpoint_save_interval | int | Interval for saving checkpoints. | 10 |
| bound_opts | dict | Options for bounding according to the auto_LiRPA documentation. | dict(conv_mode='patches', relu='adaptive') |
| device | device | Device to run the training on. | device('cuda') |
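The role of mtl_ibp_alpha can be summarised as a convex combination of the two loss terms. The sketch below is illustrative only; the actual computation happens inside mtl_ibp_train_model, and placing alpha on the certified (IBP) term follows De Palma et al. (2024) rather than being read off CTRAIN's internals.

```python
import torch

def mtl_ibp_combined_loss(ibp_loss: torch.Tensor, adv_loss: torch.Tensor, alpha: float) -> torch.Tensor:
    """Convex combination of the certified (IBP) and adversarial (PGD) losses.

    Illustrative stand-in for the weighting controlled by mtl_ibp_alpha;
    alpha = 1 trains purely on the IBP bound, alpha = 0 purely adversarially.
    """
    return alpha * ibp_loss + (1.0 - alpha) * adv_loss

# Example: equal weighting of two hypothetical per-batch loss values.
loss = mtl_ibp_combined_loss(torch.tensor(2.3), torch.tensor(0.7), alpha=0.5)
```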

_hpo_runner(config, seed, epochs, train_loader, val_loader, output_dir, cert_eval_samples=1000, include_nat_loss=True, include_adv_loss=True, include_cert_loss=True)

Function called during hyperparameter optimization (HPO) using SMAC3, returns the loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict | Configuration of hyperparameters. | required |
| seed | int | Seed used. | required |
| epochs | int | Number of epochs for training. | required |
| train_loader | DataLoader | DataLoader for training data. | required |
| val_loader | DataLoader | DataLoader for validation data. | required |
| output_dir | str | Directory to save output. | required |
| cert_eval_samples | int | Number of samples for certification evaluation. | 1000 |
| include_nat_loss | bool | Whether to include the natural loss in the HPO loss. | True |
| include_adv_loss | bool | Whether to include the adversarial loss in the HPO loss. | True |
| include_cert_loss | bool | Whether to include the certification loss in the HPO loss. | True |

Returns:

| Type | Description |
| --- | --- |
| tuple | Loss and a dictionary of accuracies that is saved as run information by SMAC3. |
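For orientation, the configuration keys that _hpo_runner reads (see the class source above) suggest the shape of a SMAC3 configuration. The dictionary below is hypothetical; the keys match the lookups in the method, but the values are illustrative, not tuned defaults.

```python
# Hypothetical SMAC3 configuration; keys mirror the lookups in _hpo_runner,
# values are purely illustrative.
config = {
    'optimizer_func': 'adam',        # 'adam', 'radam' or 'adamw'
    'learning_rate': 5e-4,
    'train_eps_factor': 1.0,
    'warm_up_epochs': 1,
    'ramp_up_epochs': 70,
    'lr_decay_factor': 0.2,
    'lr_decay_epoch_1': 10,          # offsets added on top of warm-up + ramp-up epochs
    'lr_decay_epoch_2': 10,
    'l1_reg_weight': 1e-6,
    'shi_reg_weight': 0.5,
    'shi_reg_decay': True,
    'mtl_ibp:mtl_ibp_alpha': 0.5,
    'mtl_ibp:pgd_alpha': 10,
    'mtl_ibp:pgd_steps': 1,
    'mtl_ibp:pgd_restarts': 1,
    'mtl_ibp:mtl_ibp_eps_factor': 1.0,
}
```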

train_model(train_loader, val_loader=None, start_epoch=0, end_epoch=None)

Trains the model using the MTL-IBP method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_loader | DataLoader | DataLoader for training data. | required |
| val_loader | DataLoader | DataLoader for validation data. | None |
| start_epoch | int | Epoch to start training from. Initialises learning rate and epsilon schedulers accordingly. | 0 |
| end_epoch | int | Epoch to prematurely end training at. | None |

Returns:

| Type | Description |
| --- | --- |
| BoundedModule | Trained model. |
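The eps_schedule tuple passed to mtl_ibp_train_model is built from warm_up_epochs and ramp_up_epochs. The sketch below only illustrates the documented semantics (clean training during warm-up, epsilon growing to the training epsilon during ramp-up); the linear ramp is an assumption, as the actual schedule is implemented inside CTRAIN's training loop and may be smoothed.

```python
def current_epsilon(epoch: int, warm_up_epochs: int, ramp_up_epochs: int, target_eps: float) -> float:
    """Illustrative epsilon schedule matching the documented warm-up / ramp-up semantics.

    Hypothetical helper: warm-up epochs use the clean loss (epsilon = 0),
    ramp-up epochs grow epsilon towards the training epsilon. The linear
    ramp is an assumption about the exact shape of the schedule.
    """
    if epoch < warm_up_epochs:
        return 0.0
    progress = min((epoch - warm_up_epochs + 1) / ramp_up_epochs, 1.0)
    return progress * target_eps

# With the defaults (warm_up_epochs=1, ramp_up_epochs=70), epsilon is 0 in the
# first epoch and reaches the full training epsilon by the end of the ramp-up phase.
eps_mid_ramp = current_epsilon(35, warm_up_epochs=1, ramp_up_epochs=70, target_eps=2 / 255)
```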
