taps
get_taps_loss(original_model, hardened_model, bounded_blocks, criterion, data, target, n_classes, ptb, device='cuda', pgd_steps=8, pgd_restarts=1, pgd_step_size=None, pgd_decay_factor=0.2, pgd_decay_checkpoints=(5, 7), gradient_link_thresh=0.5, gradient_link_tolerance=1e-05, gradient_expansion_alpha=5, propagation='IBP', sabr_args=None, return_bounds=False, return_stats=False)
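Compute the TAPS loss: the input perturbation is propagated through the feature extractor with a bound-propagation method (IBP or SABR, selected via propagation), and PGD is then run inside the resulting latent box to estimate the worst-case outputs of the classifier.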
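To illustrate the idea, here is a minimal, self-contained sketch of a TAPS-style loss in pure Python. It is not CTRAIN's implementation. The function names (ibp_linear_relu, pgd_worst_margin, taps_style_loss) are hypothetical, as are the toy one-layer feature extractor, the linear classifier, and the fixed PGD step of a quarter of the box width. The gradient connector that TAPS uses (configured on this page via gradient_link_thresh, gradient_link_tolerance and gradient_expansion_alpha) is omitted entirely.

```python
import math


def ibp_linear_relu(lower, upper, weight, bias):
    """Interval bound propagation through h = relu(W x + b) (hypothetical toy extractor)."""
    out_lower, out_upper = [], []
    for row, b in zip(weight, bias):
        lo = hi = b
        for w, l, u in zip(row, lower, upper):
            # A non-negative weight maps lower->lower and upper->upper; a negative one swaps them.
            lo += w * l if w >= 0 else w * u
            hi += w * u if w >= 0 else w * l
        out_lower.append(max(lo, 0.0))
        out_upper.append(max(hi, 0.0))
    return out_lower, out_upper


def linear(h, weight, bias):
    """Plain affine layer z = W h + b."""
    return [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weight, bias)]


def sign(v):
    return (v > 0) - (v < 0)


def pgd_worst_margin(lower, upper, weight, bias, target, other, steps=8):
    """Sign-gradient PGD inside the latent box, maximising z_other - z_target.

    Assumption: a linear classifier, so the gradient of the margin is constant.
    """
    grad = [wo - wt for wo, wt in zip(weight[other], weight[target])]
    h = [(l + u) / 2 for l, u in zip(lower, upper)]  # start at the box centre
    for _ in range(steps):
        # Step a quarter of the box width along the gradient sign, then project (clip) into the box.
        h = [
            min(max(x + 0.25 * (u - l) * sign(g), l), u)
            for x, g, l, u in zip(h, grad, lower, upper)
        ]
    z = linear(h, weight, bias)
    return z[other] - z[target]


def taps_style_loss(x, eps, target, w1, b1, w2, b2):
    """Hypothetical TAPS-style loss for an L-infinity ball of radius eps around x."""
    lower = [v - eps for v in x]
    upper = [v + eps for v in x]
    # Feature extractor: sound IBP box over the latent features.
    h_lower, h_upper = ibp_linear_relu(lower, upper, w1, b1)
    # Classifier: PGD estimates of the worst-case margins z_k - z_target.
    margins = [
        0.0 if k == target else pgd_worst_margin(h_lower, h_upper, w2, b2, target, k)
        for k in range(len(w2))
    ]
    # Cross-entropy on the worst-case margin vector (the target entry is fixed at 0).
    log_sum_exp = math.log(sum(math.exp(m) for m in margins))
    return log_sum_exp - margins[target], margins
```

Because the toy classifier is linear, sign-gradient PGD reaches the exact worst-case corner of the latent box. With a nonlinear classifier, as in real TAPS, the PGD result can underestimate the true worst case, so it is only an estimate used for training, not a certificate.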
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hardened_model | BoundedModule | The bounded model to be trained. | required |
original_model | Module | The original model. | required |
bounded_blocks | list | List of the LiRPA blocks needed for TAPS (feature extractor and classifier). | required |
criterion | callable | Loss function to be used. | required |
data | Tensor | Input data. | required |
target | Tensor | Target labels. | required |
n_classes | int | Number of classes in the classification task. | required |
ptb | PerturbationLpNorm | The perturbation applied to the input data. | required |
device | str | Device to run the computation on. | 'cuda' |
pgd_steps | int | Number of PGD steps. | 8 |
pgd_restarts | int | Number of PGD restarts. | 1 |
pgd_step_size | float | Step size for PGD. | None |
pgd_decay_factor | float | Factor by which the PGD step size is decayed. | 0.2 |
pgd_decay_checkpoints | tuple | PGD steps at which the step-size decay is applied. | (5, 7) |
gradient_link_thresh | float | Threshold for gradient linking. | 0.5 |
gradient_link_tolerance | float | Tolerance for gradient linking. | 1e-05 |
gradient_expansion_alpha | float | Alpha value for gradient expansion. | 5 |
propagation | str | Propagation method to be used ('SABR' or 'IBP'). | 'IBP' |
sabr_args | dict | Additional arguments for SABR. | None |
return_bounds | bool | Whether to also return the bounds. | False |
return_stats | bool | Whether to also return robust error statistics. | False |
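The pgd_steps, pgd_decay_factor and pgd_decay_checkpoints parameters describe a step-decay schedule for the PGD step size. The sketch below shows one plausible reading. It assumes, without confirmation from this page, that the checkpoints are 0-based step indices at which the current step size is multiplied by pgd_decay_factor. pgd_step_sizes is a hypothetical helper, and how CTRAIN chooses the base step size when pgd_step_size is None is not shown here.

```python
def pgd_step_sizes(base_step_size, pgd_steps=8, pgd_decay_factor=0.2, pgd_decay_checkpoints=(5, 7)):
    """Return the step size used at each PGD iteration (hypothetical helper).

    Assumption: on reaching each checkpoint (a 0-based iteration index), the
    current step size is multiplied by pgd_decay_factor and stays decayed.
    """
    sizes = []
    step_size = base_step_size
    for i in range(pgd_steps):
        if i in pgd_decay_checkpoints:
            step_size *= pgd_decay_factor  # decays compound across checkpoints
        sizes.append(step_size)
    return sizes
```

With the defaults and a base step size of 1.0, steps 0 to 4 would use 1.0, steps 5 and 6 would use 0.2, and step 7 would use about 0.04, so PGD takes large exploratory steps first and finer steps near the end.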
Returns:
Type | Description |
---|---|
tuple | A tuple containing the loss and, optionally, the bounds (if return_bounds is True) and the robust error statistics (if return_stats is True). |
Source code in CTRAIN/train/certified/losses/taps.py