shi
`get_shi_regulariser(model, ptb, data, target, eps_scheduler, n_classes, device, tolerance=0.5, verbose=False, included_regularisers=['relu', 'tightness'], regularisation_decay=True)`
Compute the Shi regularisation loss for a given model. See Shi et al. (2020) for more details.
Parameters:

Name | Type | Description | Default
---|---|---|---
`model` | `BoundedModule` | The bounded model. IMPORTANT: pass the hardened model, not the original model. | required
`ptb` | `PerturbationLpNorm` | The perturbation applied to the input data. | required
`data` | `Tensor` | Input data tensor. | required
`target` | `Tensor` | Target labels tensor. | required
`eps_scheduler` | `BaseScheduler` | Scheduler for epsilon values. | required
`n_classes` | `int` | Number of classes in the classification task. | required
`device` | `device` | Device to perform computations on (e.g., 'cpu' or 'cuda'). | required
`tolerance` | `float` | Tolerance value for the regularisation terms. | `0.5`
`verbose` | `bool` | If True, prints detailed information during computation. | `False`
`included_regularisers` | `list of str` | Regularisers to include in the loss computation. | `['relu', 'tightness']`
`regularisation_decay` | `bool` | If True, applies decay to the regularisation loss. | `True`
Returns:

Type | Description
---|---
`torch.Tensor` | The computed Shi regulariser loss.
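The `tolerance` parameter acts as a cutoff: a regularisation term contributes to the loss only while the corresponding bound statistic falls below the tolerance. A minimal pure-Python sketch of a tightness-style term in the spirit of Shi et al. (2020); the helper name, the center-to-width ratio, and the exact formula are illustrative assumptions, not the CTRAIN implementation:

```python
def tightness_term(lower, upper, tolerance=0.5):
    """Hypothetical per-layer tightness penalty (illustrative only).

    lower/upper are per-neuron pre-activation bounds for one layer.
    The layer is penalised only while its mean |center| / mean width
    ratio falls below the tolerance; tight layers contribute zero.
    """
    centers = [(l + u) / 2 for l, u in zip(lower, upper)]
    widths = [u - l for l, u in zip(lower, upper)]
    mean_center = sum(abs(c) for c in centers) / len(centers)
    mean_width = sum(widths) / len(widths)
    # clamp at zero, like ReLU(tolerance - ratio)
    return max(0.0, tolerance - mean_center / mean_width)

# Symmetric bounds around zero (loose): full penalty of tolerance.
print(tightness_term([-1.0, -1.0], [1.0, 1.0]))  # 0.5
# Bounds shifted away from zero (tight enough): no penalty.
print(tightness_term([0.0, 0.0], [2.0, 2.0]))    # 0.0
```

See `CTRAIN/train/certified/regularisers/shi.py` for the exact formulas used for the `'relu'` and `'tightness'` terms.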
Source code in CTRAIN/train/certified/regularisers/shi.py
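With `regularisation_decay=True`, the regulariser is annealed as the scheduled epsilon ramps up, so the regularisation dominates early in training and fades once the full perturbation budget is reached. One plausible reading of such a decay, as a hedged pure-Python sketch (the linear schedule and the helper name are assumptions; consult the source for the actual scheme):

```python
def decayed_weight(eps, eps_final, base_weight=1.0):
    """Hypothetical linear decay of the regulariser weight with eps.

    Returns base_weight when eps == 0 and decays to 0 as eps
    approaches eps_final (illustrative only, not the CTRAIN code).
    """
    if eps_final == 0:
        return base_weight
    return base_weight * max(0.0, 1.0 - eps / eps_final)

# Start of the eps ramp-up: regulariser at full strength.
print(decayed_weight(0.0, 8 / 255))       # 1.0
# End of the ramp-up: regulariser switched off.
print(decayed_weight(8 / 255, 8 / 255))   # 0.0
```

The decayed weight would then scale the combined `'relu'` and `'tightness'` terms before they are added to the certified training loss.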