shi
`get_shi_regulariser(model, ptb, data, target, eps_scheduler, n_classes, device, tolerance=0.5, verbose=False, included_regularisers=['relu', 'tightness'], regularisation_decay=True, loss_fusion=False)`

Compute the Shi regularisation loss for a given model. See Shi et al. (2020) for more details.

Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| `model` | `BoundedModule` | The bounded model. **Important:** pass the hardened model, not the original model. | required |
| `ptb` | `PerturbationLpNorm` | The perturbation applied to the input data. | required |
| `data` | `Tensor` | Input data tensor. | required |
| `target` | `Tensor` | Target labels tensor. | required |
| `eps_scheduler` | `BaseScheduler` | Scheduler for epsilon values. | required |
| `n_classes` | `int` | Number of classes in the classification task. | required |
| `device` | `device` | Device to perform computations on (e.g., `'cpu'` or `'cuda'`). | required |
| `tolerance` | `float` | Tolerance value for regularisation. | `0.5` |
| `verbose` | `bool` | If `True`, prints detailed information during computation. | `False` |
| `included_regularisers` | `list of str` | Regularisers to include in the loss computation. | `['relu', 'tightness']` |
| `regularisation_decay` | `bool` | If `True`, applies decay to the regularisation loss. | `True` |
| `loss_fusion` | `bool` | If `True`, uses loss fusion. | `False` |

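To make `tolerance` and `included_regularisers` more concrete, here is a minimal sketch assuming the hinge form used by Shi et al., where each term is `ReLU(tol - s) / tol` for some statistic `s` of the interval bounds, so a term vanishes once `s` reaches the tolerance. It works on plain Python lists of pre-activation lower/upper bounds per ReLU layer; the helper names (`hinge`, `tightness_term`, `relu_term`, `shi_regulariser_sketch`) and the exact statistics are illustrative assumptions, not CTRAIN's implementation, which reads the bounds from the `BoundedModule`.

```python
from statistics import fmean

EPS = 1e-12  # guards against division by zero


def hinge(stat: float, tol: float) -> float:
    """ReLU(tol - stat) / tol: zero once `stat` reaches the tolerance."""
    return max(tol - stat, 0.0) / tol


def tightness_term(input_radius: float, lower: list[float], upper: list[float], tol: float) -> float:
    """Penalise a layer whose mean bound radius exceeds input_radius / tol."""
    radius = fmean((u - l) / 2 for l, u in zip(lower, upper))
    return hinge(input_radius / max(radius, EPS), tol)


def relu_term(lower: list[float], upper: list[float], tol: float) -> float:
    """Penalise imbalance between active (l > 0) and inactive (u < 0) neurons."""
    centres = [(l + u) / 2 for l, u in zip(lower, upper)]
    mu = fmean(centres)
    act = [c for c, l in zip(centres, lower) if l > 0]
    inact = [c for c, u in zip(centres, upper) if u < 0]
    if not act or not inact:
        return 0.0  # ratios are undefined, so this layer contributes nothing
    mean_ratio = sum(act) / max(-sum(inact), EPS)
    var_ratio = sum((c - mu) ** 2 for c in act) / max(sum((c - mu) ** 2 for c in inact), EPS)
    # Fold each ratio into [0, 1] so that 1 means perfectly balanced.
    mean_ratio = min(mean_ratio, 1 / max(mean_ratio, EPS))
    var_ratio = min(var_ratio, 1 / max(var_ratio, EPS))
    return hinge(mean_ratio, tol) + hinge(var_ratio, tol)


def shi_regulariser_sketch(input_lower, input_upper, layer_bounds, tol=0.5,
                           included=("relu", "tightness")):
    """Average each term over ReLU layers, then sum the included terms."""
    if not layer_bounds:
        return 0.0
    input_radius = fmean((u - l) / 2 for l, u in zip(input_lower, input_upper))
    terms = {"tightness": 0.0, "relu": 0.0}
    for lower, upper in layer_bounds:
        terms["tightness"] += tightness_term(input_radius, lower, upper, tol)
        terms["relu"] += relu_term(lower, upper, tol)
    return sum(terms[name] / len(layer_bounds) for name in included)
```

With this form, a larger `tolerance` keeps the penalty active over a wider range of bound statistics, while dropping a name from `included` simply removes that term from the sum.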
Returns:
| Type | Description | 
|---|---|
| `torch.Tensor` | The computed Shi regulariser loss. |

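As a rough illustration of `regularisation_decay`, the sketch below assumes the regulariser weight shrinks linearly as the current epsilon approaches its maximum, i.e. `weight * (1 - eps_current / eps_max)`, so the regulariser mainly acts early in the epsilon warm-up. The function name `decayed_weight` and its `base_weight` argument are hypothetical, and how CTRAIN reads the current and maximum epsilon from the `BaseScheduler` is not shown here.

```python
def decayed_weight(cur_eps: float, max_eps: float, base_weight: float = 1.0,
                   decay: bool = True) -> float:
    """Scale factor for the regulariser under an assumed linear epsilon decay."""
    if not decay or max_eps <= 0:
        return base_weight  # no decay requested, or no meaningful target epsilon
    # Clamp at zero so the weight never turns negative if cur_eps overshoots.
    return base_weight * max(0.0, 1.0 - cur_eps / max_eps)
```

Under this assumption, the returned loss would be multiplied by `decayed_weight(cur_eps, max_eps)`, reaching zero once the scheduler hits its target epsilon.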
Source code in CTRAIN/train/certified/regularisers/shi.py