taps

GradExpander

Bases: Function

A custom autograd function that scales the gradient during the backward pass. This function allows you to define a custom forward and backward pass for a PyTorch operation. The forward pass simply returns the input tensor, while the backward pass scales the gradient by a specified factor alpha.

Methods:

    forward(ctx, x, alpha: float = 1)
    backward(ctx, grad_x)

Source code in CTRAIN/bound/taps.py
class GradExpander(torch.autograd.Function):
    """
    A custom autograd function that scales the gradient during the backward pass.
    This function allows you to define a custom forward and backward pass for a 
    PyTorch operation. The forward pass simply returns the input tensor, while 
    the backward pass scales the gradient by a specified factor `alpha`.
    Methods:
        forward(ctx, x, alpha: float = 1):
        backward(ctx, grad_x):

    """

    @staticmethod
    def forward(ctx, x, alpha:float=1):
        """
        Forward pass for the custom operation.

        Args:
            ctx: The context object that can be used to stash information
                for backward computation.
            x: The input tensor.
            alpha (float, optional): A scaling factor. Defaults to 1.

        Returns:
            (torch.Tensor): The input tensor `x`.
        """
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_x):
        """
        Performs the backward pass for the custom autograd function.

        Args:
            ctx: The context object that can be used to stash information for backward computation.
            grad_x: The gradient of the loss with respect to the output of the forward pass.

        Returns:
            (Tuple[Tensor, None]): A tuple containing the gradient of the loss with respect to the input of the forward pass and None (as there is no gradient with respect to the second input).
        """
        return ctx.alpha * grad_x, None
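
A minimal usage sketch (assuming GradExpander is importable from CTRAIN.bound.taps, where the source above lives): the forward pass is the identity, so only the gradient seen during backpropagation is scaled.

    import torch
    from CTRAIN.bound.taps import GradExpander

    x = torch.ones(3, requires_grad=True)
    y = GradExpander.apply(x, 2.0)   # forward returns x unchanged, alpha is stored on ctx
    y.sum().backward()
    print(x.grad)                    # tensor([2., 2., 2.]): the incoming gradient of ones, scaled by alpha=2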

backward(ctx, grad_x) staticmethod

Performs the backward pass for the custom autograd function.

Parameters:

ctx: The context object that can be used to stash information for backward computation. (required)
grad_x: The gradient of the loss with respect to the output of the forward pass. (required)

Returns:

Tuple[Tensor, None]: A tuple containing the gradient of the loss with respect to the input of the forward pass and None (as there is no gradient with respect to the second input).

Source code in CTRAIN/bound/taps.py
@staticmethod
def backward(ctx, grad_x):
    """
    Performs the backward pass for the custom autograd function.

    Args:
        ctx: The context object that can be used to stash information for backward computation.
        grad_x: The gradient of the loss with respect to the output of the forward pass.

    Returns:
        (Tuple[Tensor, None]): A tuple containing the gradient of the loss with respect to the input of the forward pass and None (as there is no gradient with respect to the second input).
    """
    return ctx.alpha * grad_x, None

forward(ctx, x, alpha=1) staticmethod

Forward pass for the custom operation.

Parameters:

ctx: The context object that can be used to stash information for backward computation. (required)
x: The input tensor. (required)
alpha (float): A scaling factor. (default: 1)

Returns:

Tensor: The input tensor x.

Source code in CTRAIN/bound/taps.py
@staticmethod
def forward(ctx, x, alpha:float=1):
    """
    Forward pass for the custom operation.

    Args:
        ctx: The context object that can be used to stash information
            for backward computation.
        x: The input tensor.
        alpha (float, optional): A scaling factor. Defaults to 1.

    Returns:
        (torch.Tensor): The input tensor `x`.
    """
    ctx.alpha = alpha
    return x

RectifiedLinearGradientLink

Bases: Function

RectifiedLinearGradientLink is a custom autograd function that establishes a rectified linear gradient link between the IBP bounds of the feature extractor (lb, ub) and the PGD bounds (x_adv) of the classifier. This function is not a valid gradient with respect to the forward function.

Attributes:

c (float): A constant used to determine the slope.
tol (float): A tolerance value to avoid division by zero.

Methods:

forward(ctx, lb, ub, x, c: float, tol: float)
backward(ctx, grad_x)

Source code in CTRAIN/bound/taps.py
class RectifiedLinearGradientLink(torch.autograd.Function):
    """
    RectifiedLinearGradientLink is a custom autograd function that establishes a rectified linear gradient link 
    between the IBP bounds of the feature extractor (lb, ub) and the 
    PGD bounds (x_adv) of the classifier. This function is not a valid gradient with respect 
    to the forward function.

    Attributes:
        c (float): A constant used to determine the slope.
        tol (float): A tolerance value to avoid division by zero.

    Methods:
        forward(ctx, lb, ub, x, c: float, tol: float)
        backward(ctx, grad_x):
    """
    @staticmethod
    def forward(ctx, lb, ub, x, c:float, tol:float):
        """
        Saves the input tensors and constants for backward computation.

        Args:
            ctx: Context object to save information for backward computation.
            lb: Lower bound tensor.
            ub: Upper bound tensor.
            x: Input tensor.
            c (float): A constant used to determine the slope.
            tol (float): A tolerance value to avoid division by zero.

        Returns:
            (Tensor): The input tensor x.
        """
        ctx.save_for_backward(lb, ub, x)
        ctx.c = c
        ctx.tol = tol
        return x
    @staticmethod
    def backward(ctx, grad_x):
        """
        Computes the gradient of the loss with respect to the input bounds (lb, ub).

        Args:
            ctx: Context object containing saved tensors and constants.
            grad_x: Gradient of the loss with respect to the output of the forward function.

        Returns:
            (Tuple[Tensor, Tensor, None, None, None]): Gradients with respect to lb, ub, and None for other inputs.
        """
        lb, ub, x = ctx.saved_tensors
        c, tol = ctx.c, ctx.tol
        slackness = c * (ub - lb)
        # handle grad w.r.t. ub
        thre = (ub - slackness)
        Rectifiedgrad_mask = (x >= thre)
        grad_ub = (Rectifiedgrad_mask * grad_x * (x - thre).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
        # handle grad w.r.t. lb
        thre = (lb + slackness)
        Rectifiedgrad_mask = (x <= thre)
        grad_lb = (Rectifiedgrad_mask * grad_x * (thre - x).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
        # we don't need grad w.r.t. x and param
        return grad_lb, grad_ub, None, None, None
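
Usage sketch mirroring the call in bound_taps below (lb, ub and pts are hypothetical tensors; pts is laid out as (num_pivotal, batch, features), as in that call): the forward pass returns pts unchanged, while the backward pass routes gradients onto the IBP bounds lb and ub instead.

    import torch
    from CTRAIN.bound.taps import RectifiedLinearGradientLink

    lb = torch.zeros(4, 16, requires_grad=True)   # IBP lower bounds of the classifier input (hypothetical shape)
    ub = torch.ones(4, 16, requires_grad=True)    # IBP upper bounds
    pts = torch.rand(3, 4, 16)                    # PGD points: (num_pivotal, batch, features)
    linked = RectifiedLinearGradientLink.apply(lb.unsqueeze(0), ub.unsqueeze(0), pts, 0.5, 1e-5)
    linked.sum().backward()
    print(lb.grad.shape, ub.grad.shape)           # torch.Size([4, 16]) twice; pts itself receives no gradient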

backward(ctx, grad_x) staticmethod

Computes the gradient of the loss with respect to the input bounds (lb, ub).

Parameters:

ctx: Context object containing saved tensors and constants. (required)
grad_x: Gradient of the loss with respect to the output of the forward function. (required)

Returns:

Tuple[Tensor, Tensor, None, None, None]: Gradients with respect to lb, ub, and None for the other inputs.

Source code in CTRAIN/bound/taps.py
@staticmethod
def backward(ctx, grad_x):
    """
    Computes the gradient of the loss with respect to the input bounds (lb, ub).

    Args:
        ctx: Context object containing saved tensors and constants.
        grad_x: Gradient of the loss with respect to the output of the forward function.

    Returns:
        (Tuple[Tensor, Tensor, None, None, None]): Gradients with respect to lb, ub, and None for other inputs.
    """
    lb, ub, x = ctx.saved_tensors
    c, tol = ctx.c, ctx.tol
    slackness = c * (ub - lb)
    # handle grad w.r.t. ub
    thre = (ub - slackness)
    Rectifiedgrad_mask = (x >= thre)
    grad_ub = (Rectifiedgrad_mask * grad_x * (x - thre).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
    # handle grad w.r.t. lb
    thre = (lb + slackness)
    Rectifiedgrad_mask = (x <= thre)
    grad_lb = (Rectifiedgrad_mask * grad_x * (thre - x).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
    # we don't need grad w.r.t. x and param
    return grad_lb, grad_ub, None, None, None

forward(ctx, lb, ub, x, c, tol) staticmethod

Saves the input tensors and constants for backward computation.

Parameters:

ctx: Context object to save information for backward computation. (required)
lb: Lower bound tensor. (required)
ub: Upper bound tensor. (required)
x: Input tensor. (required)
c (float): A constant used to determine the slope. (required)
tol (float): A tolerance value to avoid division by zero. (required)

Returns:

Tensor: The input tensor x.

Source code in CTRAIN/bound/taps.py
@staticmethod
def forward(ctx, lb, ub, x, c:float, tol:float):
    """
    Saves the input tensors and constants for backward computation.

    Args:
        ctx: Context object to save information for backward computation.
        lb: Lower bound tensor.
        ub: Upper bound tensor.
        x: Input tensor.
        c (float): A constant used to determine the slope.
        tol (float): A tolerance value to avoid division by zero.

    Returns:
        (Tensor): The input tensor x.
    """
    ctx.save_for_backward(lb, ub, x)
    ctx.c = c
    ctx.tol = tol
    return x

_get_bound_estimation_from_pts(block, pts, dim_to_estimate, C=None)

Estimate bounds for specified dimensions from given adversarial examples.

Parameters:

block (BoundedModule): The neural network block for which to estimate bounds. (required)
pts (Tensor): Tensor of adversarial examples of shape (batch_size, num_pivotal, *shape_in[1:]). (required)
dim_to_estimate (Tensor): Tensor indicating the dimensions to estimate, shape (batch_size, num_dims, dim_size). (required)
C (Tensor): Matrix specifying the correct class for bound margin calculation; if None, bounds are estimated for the dimensions given by dim_to_estimate. (default: None)

Returns:

estimated_bounds (Tensor): Estimated bounds tensor of shape (batch_size, num_pivotal) if C is None, otherwise shape (batch_size, n_class).

Source code in CTRAIN/bound/taps.py
def _get_bound_estimation_from_pts(block, pts, dim_to_estimate, C=None):
    """
    Estimate bounds for specified dimensions from given adversarial examples.

    Parameters:
        block (autoLiRPA.BoundedModule): The neural network block for which to estimate pivotal points.
        pts (torch.Tensor): Tensor of adversarial examples of shape (batch_size, num_pivotal, *shape_in[1:]).
        dim_to_estimate (torch.Tensor): Tensor indicating the dimensions to estimate, shape (batch_size, num_dims, dim_size).
        C (torch.Tensor): Matrix specifying the correct class for bound margin calculation. Must be provided.

    Returns:
        estimated_bounds(torch.Tensor): Estimated bounds tensor of shape (batch_size, num_pivotal) if C is None,
                    otherwise shape (batch_size, n_class).
    """

    if C is None:
        # pts shape (batch_size, num_pivotal, *shape_in[1:])
        out_pts = block(pts.reshape(-1, *pts.shape[2:]))
        out_pts = out_pts.reshape(*pts.shape[:2], -1)
        dim_to_estimate = dim_to_estimate.unsqueeze(1)
        dim_to_estimate = dim_to_estimate.expand(dim_to_estimate.shape[0], out_pts.shape[1], dim_to_estimate.shape[2])
        out_pts = torch.gather(out_pts, dim=2, index=dim_to_estimate) # shape: (batch_size, num_pivotal, num_pivotal)
        estimated_bounds = torch.diagonal(out_pts, dim1=1, dim2=2) # shape: (batch_size, num_pivotal)
    else:
        # # main idea: convert the 9 adv inputs into one batch to compute the bound at the same time; involve many reshaping
        batch_C = C.unsqueeze(1).expand(-1, pts.shape[1], -1, -1).reshape(-1, *(C.shape[1:])) # may need shape adjustment
        batch_pts = pts.reshape(-1, *(pts.shape[2:]))
        out_pts = block(batch_pts)
        out_pts = torch.bmm(batch_C, out_pts.unsqueeze(-1)).squeeze(-1)
        out_pts = out_pts.reshape(*(pts.shape[:2]), *(out_pts.shape[1:]))
        out_pts = - out_pts # the out is the lower bound of yt - yi, transform it to the upper bound of yi - yt
        # the out_pts should be in shape (batch_size, n_class - 1, n_class - 1)
        ub = torch.diagonal(out_pts, dim1=1, dim2=2) # shape: (batch_size, n_class - 1)
        estimated_bounds = torch.cat([torch.zeros(size=(ub.shape[0],1), dtype=ub.dtype, device=ub.device), ub], dim=1) # shape: (batch_size, n_class)

    return estimated_bounds
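
Shape sketch (hypothetical toy setup: a plain nn.Linear stands in for the classifier block, and C is a stand-in for the matrix produced by construct_c):

    import torch
    import torch.nn as nn
    from CTRAIN.bound.taps import _get_bound_estimation_from_pts

    n_class, batch, feat = 10, 4, 32
    block = nn.Linear(feat, n_class)               # stand-in classifier head
    pts = torch.rand(batch, n_class - 1, feat)     # one adversarial point per non-target class
    C = torch.randn(batch, n_class - 1, n_class)   # margin specification matrix (stand-in)
    bounds = _get_bound_estimation_from_pts(block, pts, None, C=C)
    print(bounds.shape)                            # torch.Size([4, 10]); column 0 (the target class) is zero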

_get_pivotal_points(block, input_lb, input_ub, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, n_classes, C=None)

Estimate pivotal points for the classifier block using Projected Gradient Descent (PGD).

Parameters:

block (BoundedModule): The neural network block for which to estimate pivotal points. (required)
input_lb (Tensor): Lower bound of the input to the network block. (required)
input_ub (Tensor): Upper bound of the input to the network block. (required)
pgd_steps (int): Number of PGD steps to perform. (required)
pgd_restarts (int): Number of PGD restarts to perform. (required)
pgd_step_size (float): Step size for PGD. (required)
pgd_decay_factor (float): Decay factor for the PGD step size. (required)
pgd_decay_checkpoints (list of int): Checkpoints at which to decay the PGD step size. (required)
n_classes (int): Number of classes in the classification task. (required)
C (Tensor): Matrix specifying the correct class for bound margin calculation. Must be provided. (default: None)

Returns:

list of torch.Tensor: List containing the concatenated pivotal points tensor.

Source code in CTRAIN/bound/taps.py
def _get_pivotal_points(block, input_lb, input_ub, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, n_classes, C=None):
    """
    Estimate pivotal points for the classifier block using Projected Gradient Descent (PGD).

    Parameters:
        block (autoLiRPA.BoundedModule): The neural network block for which to estimate pivotal points.
        input_lb (torch.Tensor): Lower bound of the input to the network block.
        input_ub (torch.Tensor): Upper bound of the input to the network block.
        pgd_steps (int): Number of PGD steps to perform.
        pgd_restarts (int): Number of PGD restarts to perform.
        pgd_step_size (float): Step size for PGD.
        pgd_decay_factor (float): Decay factor for PGD step size.
        pgd_decay_checkpoints (list of int): Checkpoints at which to decay the PGD step size.
        n_classes (int): Number of classes in the classification task.
        C (torch.Tensor, optional): Matrix specifying the correct class for bound margin calculation. Must be provided.

    Returns:
        (list of torch.Tensor): List containing the concatenated pivotal points tensor.
    """
    assert C is not None # Should only estimate for the final block
    lb, ub = input_lb.clone().detach(), input_ub.clone().detach()

    pt_list = []
    # split into batches
    # TODO: Can we keep this fixed batch size?
    bs = 128
    lb_batches = [lb[i*bs:(i+1)*bs] for i in range(math.ceil(len(lb) / bs))]
    ub_batches = [ub[i*bs:(i+1)*bs] for i in range(math.ceil(len(ub) / bs))]
    C_batches = [C[i*bs:(i+1)*bs] for i in range(math.ceil(len(C) / bs))]
    for lb_one_batch, ub_one_batch, C_one_batch in zip(lb_batches, ub_batches, C_batches):
        pt_list.append(_get_pivotal_points_one_batch(block, lb_one_batch, ub_one_batch, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, n_classes=n_classes, C=C_one_batch))
    pts = torch.cat(pt_list, dim=0)
    return [pts, ]

_get_pivotal_points_one_batch(block, lb, ub, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, C, n_classes, device='cuda')

Estimate pivotal points for a batch using Projected Gradient Descent (PGD).

Parameters:

block (BoundedModule): The neural network block for which to estimate pivotal points. (required)
lb (Tensor): Lower bound of the input. (required)
ub (Tensor): Upper bound of the input. (required)
pgd_steps (int): Number of PGD steps. (required)
pgd_restarts (int): Number of PGD restarts. (required)
pgd_step_size (float): Step size for PGD. (required)
pgd_decay_factor (float): Decay factor for the learning rate. (required)
pgd_decay_checkpoints (list): Checkpoints for learning rate decay. (required)
C (Tensor): Matrix specifying the correct class for bound margin calculation. Must be provided. (required)
n_classes (int): Number of classes. (required)
device (str): Device to perform computations on. (default: 'cuda')

Returns:

Tensor: Adversarial examples per class for the whole batch.

Source code in CTRAIN/bound/taps.py
def _get_pivotal_points_one_batch(block, lb, ub, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, C, n_classes, device='cuda'):
    """
    Estimate pivotal points for a batch using Projected Gradient Descent (PGD).

    Args:
        block (autoLiRPA.BoundedModule): The neural network block for which to estimate pivotal points.
        lb (torch.Tensor): Lower bound of the input.
        ub (torch.Tensor): Upper bound of the input.
        pgd_steps (int): Number of PGD steps.
        pgd_restarts (int): Number of PGD restarts.
        pgd_step_size (float): Step size for PGD.
        pgd_decay_factor (float): Decay factor for learning rate.
        pgd_decay_checkpoints (list): Checkpoints for learning rate decay.
        C (torch.Tensor): Matrix specifying the correct class for bound margin calculation. Must be provided.
        n_classes (int): Number of classes.
        device (str, optional): Device to perform computations on. Default is 'cuda'.

    Returns:
        (torch.Tensor): Adversarial examples per class for whole batch.
    """

    num_pivotal = n_classes - 1 # only need to estimate n_class - 1 dim for the final output

    def init_pts(input_lb, input_ub):
        rand_init = input_lb.unsqueeze(1) + (input_ub-input_lb).unsqueeze(1)*torch.rand(input_lb.shape[0], num_pivotal, *input_lb.shape[1:], device=device)
        return rand_init

    def select_schedule(num_steps):
        if num_steps >= 20 and num_steps <= 50:
            lr_decay_milestones = [int(num_steps*0.7)]
        elif num_steps > 50 and num_steps <= 80:
            lr_decay_milestones = [int(num_steps*0.4), int(num_steps*0.7)]
        elif num_steps > 80:
            lr_decay_milestones = [int(num_steps*0.3), int(num_steps*0.6), int(num_steps*0.8)]
        else:
            lr_decay_milestones = []
        return lr_decay_milestones

    lr_decay_milestones = pgd_decay_checkpoints
    lr_decay_factor = pgd_decay_factor
    init_lr = pgd_step_size

    retain_graph = False
    pts = init_pts(lb, ub)
    variety = (ub - lb).unsqueeze(1).detach()
    best_estimation = -1e5*torch.ones(pts.shape[:2], device=pts.device)
    best_pts = torch.zeros_like(pts)
    with torch.enable_grad():
        for re in range(pgd_restarts):
            lr = init_lr
            pts = init_pts(lb, ub)
            for it in range(pgd_steps+1):
                pts.requires_grad = True
                estimated_pseudo_bound = _get_bound_estimation_from_pts(block, pts, None, C=C)
                improve_idx = estimated_pseudo_bound[:, 1:] > best_estimation
                best_estimation[improve_idx] = estimated_pseudo_bound[:, 1:][improve_idx].detach()
                best_pts[improve_idx] = pts[improve_idx].detach()
                # wants to maximize the estimated bound
                if it != pgd_steps:
                    loss = - estimated_pseudo_bound.sum()
                    loss.backward(retain_graph=retain_graph)
                    new_pts = pts - pts.grad.sign() * lr * variety
                    pts = torch.max(torch.min(new_pts, ub.unsqueeze(1)), lb.unsqueeze(1)).detach()
                    if (it+1) in lr_decay_milestones:
                        lr *= lr_decay_factor
    return best_pts.detach()
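
Toy sketch on CPU (hypothetical shapes; a plain nn.Linear again stands in for the classifier block and C for the construct_c output): a few PGD steps search for points that maximise the estimated margins.

    import torch
    import torch.nn as nn
    from CTRAIN.bound.taps import _get_pivotal_points_one_batch

    n_class, batch, feat = 10, 4, 32
    block = nn.Linear(feat, n_class)                 # stand-in classifier head
    lb = torch.zeros(batch, feat)                    # lower bounds of the block input
    ub = torch.ones(batch, feat)                     # upper bounds of the block input
    C = torch.randn(batch, n_class - 1, n_class)     # margin specification matrix (stand-in)
    pts = _get_pivotal_points_one_batch(
        block, lb, ub,
        pgd_steps=5, pgd_restarts=1, pgd_step_size=0.2,
        pgd_decay_factor=0.2, pgd_decay_checkpoints=(3,),
        C=C, n_classes=n_class, device='cpu',
    )
    print(pts.shape)                                 # torch.Size([4, 9, 32]): one point per non-target class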

bound_taps(original_model, hardened_model, bounded_blocks, data, target, n_classes, ptb, device='cuda', pgd_steps=20, pgd_restarts=1, pgd_step_size=0.2, pgd_decay_factor=0.2, pgd_decay_checkpoints=(5, 7), gradient_link_thresh=0.5, gradient_link_tolerance=1e-05, propagation='IBP', sabr_args=None)

Compute the bounds of the model's output using the TAPS method.

Parameters:

original_model (Module): The original neural network model. (required)
hardened_model (BoundedModule): The auto_LiRPA model. (required)
bounded_blocks (list): List of bounded blocks of the model. (required)
data (Tensor): The input data tensor. (required)
target (Tensor): The target labels tensor. (required)
n_classes (int): The number of classes for classification. (required)
ptb (PerturbationLpNorm): The perturbation object defining the perturbation set. (required)
device (str): The device to run the computation on. (default: 'cuda')
pgd_steps (int): The number of steps for the PGD attack. (default: 20)
pgd_restarts (int): The number of restarts for the PGD attack. (default: 1)
pgd_step_size (float): The step size for the PGD attack. (default: 0.2)
pgd_decay_factor (float): The decay factor for the PGD attack. (default: 0.2)
pgd_decay_checkpoints (tuple): The decay checkpoints for the PGD attack. (default: (5, 7))
gradient_link_thresh (float): The threshold for gradient linking. (default: 0.5)
gradient_link_tolerance (float): The tolerance for gradient linking. (default: 1e-05)
propagation (str): The propagation method to use ('IBP' or 'SABR'). (default: 'IBP')
sabr_args (dict): The arguments for the SABR method. (default: None)

Returns:

taps_bound (Tuple[Tensor, Tensor]): The TAPS bounds of the model's output.

Source code in CTRAIN/bound/taps.py
def bound_taps(original_model, hardened_model, bounded_blocks, data, target, n_classes, ptb, device='cuda', pgd_steps=20, pgd_restarts=1, pgd_step_size=.2, 
               pgd_decay_factor=.2, pgd_decay_checkpoints=(5,7),
               gradient_link_thresh=.5, gradient_link_tolerance=1e-05, propagation="IBP", sabr_args=None):
    """
    Compute the bounds of the model's output using the TAPS method.

    Parameters:
        original_model (torch.nn.Module): The original neural network model.
        hardened_model (autoLiRPA.BoundedModule): The auto_LiRPA model.
        bounded_blocks (list): List of bounded blocks of the model.
        data (Tensor): The input data tensor.
        target (Tensor): The target labels tensor.
        n_classes (int): The number of classes for classification.
        ptb (auto_LiRPA.PerturbationLpNorm): The perturbation object defining the perturbation set.
        device (str, optional): The device to run the computation on. Default is 'cuda'.
        pgd_steps (int, optional): The number of steps for the PGD attack. Default is 20.
        pgd_restarts (int, optional): The number of restarts for the PGD attack. Default is 1.
        pgd_step_size (float, optional): The step size for the PGD attack. Default is 0.2.
        pgd_decay_factor (float, optional): The decay factor for the PGD attack. Default is 0.2.
        pgd_decay_checkpoints (tuple, optional): The decay checkpoints for the PGD attack. Default is (5, 7).
        gradient_link_thresh (float, optional): The threshold for gradient linking. Default is 0.5.
        gradient_link_tolerance (float, optional): The tolerance for gradient linking. Default is 1e-05.
        propagation (str, optional): The propagation method to use ('IBP' or 'SABR'). Default is 'IBP'.
        sabr_args (dict, optional): The arguments for the SABR method. Default is None.

    Returns:
        taps_bound(Tuple[Tensor, Tensor]): The TAPS bounds of the model's output.
    """
    assert len(bounded_blocks) == 2, "Split not supported!"

    if propagation == 'IBP':
        lb, ub = bound_ibp(
            model=bounded_blocks[0],
            ptb=ptb,
            data=data,
            target=None,
            n_classes=n_classes,
        )
    if propagation == 'SABR':
        assert sabr_args is not None, "Need to Provide SABR arguments if you choose SABR for propagation"
        lb, ub = bound_sabr(
            # Intermediate Bound model instructs to return bounds after the first network block
            **{**sabr_args, "intermediate_bound_model": bounded_blocks[0], "return_adv_output": False},
        )

    with torch.no_grad():
        hardened_model.eval()
        original_model.eval()
        for block in bounded_blocks:
            block.eval()
        c = construct_c(data, target, n_classes)
        with torch.no_grad():
            grad_cleaner = torch.optim.SGD(hardened_model.parameters())
            adv_samples = _get_pivotal_points(bounded_blocks[1], lb, ub, pgd_steps, pgd_restarts, pgd_step_size, pgd_decay_factor, pgd_decay_checkpoints, n_classes, C=c)
            grad_cleaner.zero_grad()

        hardened_model.train()
        original_model.train()
        for block in bounded_blocks:
            block.train()

    pts = adv_samples[0].detach()
    pts = torch.transpose(pts, 0, 1)
    pts = RectifiedLinearGradientLink.apply(lb.unsqueeze(0), ub.unsqueeze(0), pts, gradient_link_thresh, gradient_link_tolerance)
    pts = torch.transpose(pts, 0, 1)
    pgd_bounds = _get_bound_estimation_from_pts(bounded_blocks[1], pts, None, c)
    # NOTE: VERY IMPORTANT CHANGES TO TAPS BOUND TO BE COMPATIBLE WITH CTRAIN WORKFLOW
    pgd_bounds = pgd_bounds[:, 1:]
    pgd_bounds = -pgd_bounds


    ibp_lb, ibp_ub = bound_ibp(
        model=bounded_blocks[1],
        ptb=PerturbationLpNorm(x_L=lb, x_U=ub),
        data=data,
        target=target,
        n_classes=n_classes,
    )

    return pgd_bounds, ibp_lb