taps

get_taps_loss(original_model, hardened_model, bounded_blocks, criterion, data, target, n_classes, ptb, device='cuda', pgd_steps=8, pgd_restarts=1, pgd_step_size=None, pgd_decay_factor=0.2, pgd_decay_checkpoints=(5, 7), gradient_link_thresh=0.5, gradient_link_tolerance=1e-05, gradient_expansion_alpha=5, propagation='IBP', sabr_args=None, return_bounds=False, return_stats=False)

Compute the TAPS loss, which combines bound propagation (IBP or SABR) through the network's feature extractor with PGD-based bound estimation through its classifier.

Parameters:

    original_model (torch.nn.Module): The original model. Required.
    hardened_model (auto_LiRPA.BoundedModule): The bounded model to be trained. Required.
    bounded_blocks (list): List of the LiRPA blocks needed for TAPS (feature extractor and classifier). Required.
    criterion (callable): Loss function to be used. Required.
    data (torch.Tensor): Input data. Required.
    target (torch.Tensor): Target labels. Required.
    n_classes (int): Number of classes in the classification task. Required.
    ptb (auto_LiRPA.PerturbationLpNorm): The perturbation applied to the input data. Required.
    device (str, optional): Device to run the computation on. Default: 'cuda'.
    pgd_steps (int, optional): Number of PGD steps. Default: 8.
    pgd_restarts (int, optional): Number of PGD restarts. Default: 1.
    pgd_step_size (float, optional): Step size for PGD. Default: None.
    pgd_decay_factor (float, optional): Decay factor for the PGD step size. Default: 0.2.
    pgd_decay_checkpoints (tuple, optional): PGD steps at which the step size is decayed. Default: (5, 7).
    gradient_link_thresh (float, optional): Threshold for gradient linking. Default: 0.5.
    gradient_link_tolerance (float, optional): Tolerance for gradient linking. Default: 1e-05.
    gradient_expansion_alpha (float, optional): Alpha value for gradient expansion. Default: 5.
    propagation (str, optional): Propagation method to be used ('SABR' or 'IBP'). Default: 'IBP'.
    sabr_args (dict, optional): Additional arguments for SABR. Default: None.
    return_bounds (bool, optional): Whether to return the computed bounds. Default: False.
    return_stats (bool, optional): Whether to return robust error statistics. Default: False.

Returns:

    tuple: A tuple containing the loss, and optionally the bounds and robust error statistics.
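
A minimal usage sketch (illustrative, not taken from the CTRAIN documentation): the toy architecture, batch shapes, epsilon, and the per-sample cross-entropy criterion are assumptions, and constructing the bounded model and blocks directly with auto_LiRPA's BoundedModule stands in for whatever model-splitting helper CTRAIN itself provides.

    import torch
    import torch.nn as nn
    from auto_LiRPA import BoundedModule, PerturbationLpNorm

    from CTRAIN.train.certified.losses.taps import get_taps_loss

    # Hypothetical two-block network: feature extractor followed by a classifier head.
    feature_extractor = nn.Sequential(
        nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.Flatten()
    )
    classifier = nn.Linear(16 * 28 * 28, 10)
    model = nn.Sequential(feature_extractor, classifier)

    data = torch.rand(8, 1, 28, 28, device='cuda')
    target = torch.randint(0, 10, (8,), device='cuda')

    # TAPS expects exactly two bounded blocks (see the assert in the source below).
    hardened_model = BoundedModule(model, torch.empty_like(data), device='cuda')
    bounded_blocks = [
        BoundedModule(feature_extractor, torch.empty_like(data), device='cuda'),
        BoundedModule(classifier, torch.empty(8, 16 * 28 * 28, device='cuda'), device='cuda'),
    ]

    # L-inf perturbation of radius 2/255 around the inputs.
    ptb = PerturbationLpNorm(norm=float('inf'), eps=2 / 255)

    criterion = nn.CrossEntropyLoss(reduction='none')  # assumed per-sample loss
    (loss,) = get_taps_loss(
        model, hardened_model, bounded_blocks, criterion,
        data, target, n_classes=10, ptb=ptb,
    )
    loss.mean().backward()  # reduce in case the criterion returns per-sample losses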

Source code in CTRAIN/train/certified/losses/taps.py
def get_taps_loss(original_model, hardened_model, bounded_blocks, criterion, data, target, n_classes, ptb, device='cuda', pgd_steps=8, pgd_restarts=1, 
                  pgd_step_size=None, pgd_decay_factor=.2, pgd_decay_checkpoints=(5,7), gradient_link_thresh=.5,
                  gradient_link_tolerance=1e-05, gradient_expansion_alpha=5, propagation="IBP", sabr_args=None, return_bounds=False, return_stats=False):

    """
    Compute the TAPS loss, which combines bound propagation (IBP or SABR) through the network's feature extractor with PGD-based bound estimation through its classifier.

    Parameters:
        original_model (torch.nn.Module): The original model.
        hardened_model (auto_LiRPA.BoundedModule): The bounded model to be trained.
        bounded_blocks (list): List of the LiRPA blocks needed for TAPS (feature extractor and classifier).
        criterion (callable): Loss function to be used.
        data (torch.Tensor): Input data.
        target (torch.Tensor): Target labels.
        n_classes (int): Number of classes in the classification task.
        ptb (auto_LiRPA.PerturbationLpNorm): The perturbation applied to the input data.
        device (str, optional): Device to run the computation on. Default is 'cuda'.
        pgd_steps (int, optional): Number of PGD steps. Default is 8.
        pgd_restarts (int, optional): Number of PGD restarts. Default is 1.
        pgd_step_size (float, optional): Step size for PGD. Default is None.
        pgd_decay_factor (float, optional): Decay factor for PGD step size. Default is 0.2.
        pgd_decay_checkpoints (tuple, optional): Checkpoints for PGD decay. Default is (5, 7).
        gradient_link_thresh (float, optional): Threshold for gradient linking. Default is 0.5.
        gradient_link_tolerance (float, optional): Tolerance for gradient linking. Default is 1e-05.
        gradient_expansion_alpha (float, optional): Alpha value for gradient expansion. Default is 5.
        propagation (str, optional): Propagation method to be used (SABR or IBP). Default is "IBP".
        sabr_args (dict, optional): Additional arguments for SABR. Default is None.
        return_bounds (bool, optional): Whether to return bounds. Default is False.
        return_stats (bool, optional): Whether to return robust error statistics. Default is False.

    Returns:
        (tuple): A tuple containing the loss, and optionally the bounds and robust error statistics.
    """
    assert len(bounded_blocks) == 2, "Split not supported!"

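    # bound_taps returns per-sample lower bounds on the logit margins:
    # taps_bound from the TAPS procedure (box propagation through the feature
    # extractor, PGD-based bound estimation through the classifier) and
    # ibp_bound from pure IBP propagation.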
    taps_bound, ibp_bound = bound_taps(
        original_model=original_model,
        hardened_model=hardened_model,
        bounded_blocks=bounded_blocks,
        data=data,
        target=target,
        n_classes=n_classes,
        ptb=ptb,
        device=device,
        pgd_steps=pgd_steps,
        pgd_restarts=pgd_restarts,
        pgd_step_size=pgd_step_size, 
        pgd_decay_factor=pgd_decay_factor,
        pgd_decay_checkpoints=pgd_decay_checkpoints,
        gradient_link_thresh=gradient_link_thresh,
        gradient_link_tolerance=gradient_link_tolerance,
        propagation=propagation,
        sabr_args=sabr_args
    )

    taps_loss = get_loss_from_bounds(taps_bound, criterion)
    ibp_loss = get_loss_from_bounds(ibp_bound, criterion)

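    # Combine the two losses multiplicatively; GradExpander leaves the TAPS loss
    # unchanged in the forward pass but scales its gradient by
    # gradient_expansion_alpha in the backward pass.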
    loss = GradExpander.apply(taps_loss, gradient_expansion_alpha) * ibp_loss

    return_tuple = (loss,)

    if return_bounds:
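        # Only lower bounds are computed here; the upper-bound slot is None.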
        return_tuple = return_tuple + (taps_bound, None)
    if return_stats:
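        # A sample counts as a robust error if any of its margin lower bounds is negative.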
        robust_err = torch.sum((taps_bound < 0).any(dim=1)).item() / data.size(0)
        return_tuple = return_tuple + (robust_err,)

    return return_tuple