
ibp

get_ibp_loss(hardened_model, ptb, data, target, n_classes, criterion, return_bounds=False, return_stats=False)

Compute the Interval Bound Propagation (IBP) loss for a given model.
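The idea behind IBP can be shown with a minimal, framework-free sketch of interval arithmetic for a small dense ReLU network. This is not CTRAIN's `bound_ibp`. The nested-list weights, the unclipped L-infinity input box [x - eps, x + eps], and the helper names `affine_interval`, `relu_interval` and `margin_lower_bounds` are illustrative assumptions. The sketch propagates the box layer by layer and returns lower bounds on the margins z_target - z_j. A sample counts as certified when all of these bounds are positive.

```python
# Minimal IBP sketch (assumption: dense layers given as (W, b) with W as a
# list of rows, ReLU between layers, no ReLU after the logit layer).
# Illustrative only; this is not CTRAIN's bound_ibp.


def affine_interval(W, b, lo, hi):
    """Propagate the box [lo, hi] through y = W x + b."""
    out_lo, out_hi = [], []
    for row, bias in zip(W, b):
        # Center/radius form: W moves the center, |W| scales the radius.
        c = sum(w * (l + h) / 2 for w, l, h in zip(row, lo, hi)) + bias
        r = sum(abs(w) * (h - l) / 2 for w, l, h in zip(row, lo, hi))
        out_lo.append(c - r)
        out_hi.append(c + r)
    return out_lo, out_hi


def relu_interval(lo, hi):
    """ReLU is monotone, so it can be applied to each bound separately."""
    return [max(l, 0.0) for l in lo], [max(h, 0.0) for h in hi]


def margin_lower_bounds(layers, x, eps, target):
    """Lower-bound z_target - z_j (j != target) over the L-inf ball around x.

    The margin rows are folded into the last layer before bounding it,
    which is tighter than subtracting separately computed logit bounds.
    """
    lo = [v - eps for v in x]
    hi = [v + eps for v in x]
    *hidden, (W, b) = layers
    for Wh, bh in hidden:
        lo, hi = relu_interval(*affine_interval(Wh, bh, lo, hi))
    others = [j for j in range(len(W)) if j != target]
    # Each row of W_m computes one margin z_target - z_j directly.
    W_m = [[wt - wj for wt, wj in zip(W[target], W[j])] for j in others]
    b_m = [b[target] - b[j] for j in others]
    m_lo, _ = affine_interval(W_m, b_m, lo, hi)
    return m_lo
```

With an identity network and x = [1.0, 0.0], a radius of 0.1 yields a margin lower bound of about 0.8 (certified), while a radius of 1.0 drives it to -1.0 (not certified).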

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hardened_model | BoundedModule | The bounded model to be trained. | required |
| ptb | PerturbationLpNorm | The perturbation applied to the input data. | required |
| data | Tensor | Input data. | required |
| target | Tensor | Target labels. | required |
| n_classes | int | Number of classes. | required |
| criterion | callable | Loss function to be used. | required |
| return_bounds | bool | If True, also return the lower and upper bounds. | False |
| return_stats | bool | If True, also return additional statistics (the robust error). | False |

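Returns:

| Type | Description |
|---|---|
| tuple | A tuple whose first element is the certified loss. If return_bounds is True, the lower and upper bounds follow. If return_stats is True, the robust error (the fraction of samples in the batch with at least one negative lower bound, i.e. samples that could not be certified) is appended last. |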
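The variable-length return value and the robust-error statistic can be mirrored on plain Python lists. This is a hedged sketch. `robust_error` re-implements the expression `torch.sum((lb < 0).any(dim=1)).item() / data.size(0)` row by row, assuming each row of `lb` holds one sample's lower bounds. `unpack_ibp_result` is a hypothetical convenience helper, not part of CTRAIN.

```python
# Sketch of the robust-error statistic and the tuple layout of get_ibp_loss,
# using lists instead of tensors (assumption: one row of lb per sample).


def robust_error(lb_rows):
    """Fraction of samples with at least one negative lower bound (not certified)."""
    uncertified = sum(1 for row in lb_rows if any(v < 0 for v in row))
    return uncertified / len(lb_rows)


def unpack_ibp_result(result, return_bounds=False, return_stats=False):
    """Hypothetical helper: name the fields of the tuple returned by get_ibp_loss."""
    fields = ["certified_loss"]
    if return_bounds:
        fields += ["lb", "ub"]
    if return_stats:
        fields.append("robust_err")
    if len(result) != len(fields):
        raise ValueError(f"expected {len(fields)} values, got {len(result)}")
    return dict(zip(fields, result))
```

Even with both flags left at False, `get_ibp_loss` returns a one-element tuple, so a caller needs `loss, = get_ibp_loss(...)` or `get_ibp_loss(...)[0]` rather than treating the result as a tensor. Note also that a lower bound of exactly 0 is not counted as an error, because the check is `lb < 0`.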

Source code in CTRAIN/train/certified/losses/ibp.py
```python
import torch

# bound_ibp and get_loss_from_bounds are CTRAIN helpers, presumably imported
# at the top of ibp.py (not part of this excerpt).


def get_ibp_loss(hardened_model, ptb, data, target, n_classes, criterion, return_bounds=False, return_stats=False):
    """
    Compute the Interval Bound Propagation (IBP) loss for a given model.

    Args:
        hardened_model (auto_LiRPA.BoundedModule): The bounded model to be trained.
        ptb (auto_LiRPA.PerturbationLpNorm): The perturbation applied to the input data.
        data (torch.Tensor): Input data.
        target (torch.Tensor): Target labels.
        n_classes (int): Number of classes.
        criterion (callable): Loss function to be used.
        return_bounds (bool, optional): If True, return the lower and upper bounds. Default is False.
        return_stats (bool, optional): If True, return additional statistics. Default is False.

    Returns:
        (tuple): A tuple containing the certified loss. If `return_bounds` is True, the tuple also contains the lower and upper bounds.
               If `return_stats` is True, the tuple also contains the robust error.
    """
    # Propagate the input perturbation through the network with IBP to obtain
    # lower/upper bounds, computed with respect to the target labels.
    lb, ub = bound_ibp(
        model=hardened_model,
        ptb=ptb,
        data=data,
        target=target,
        n_classes=n_classes,
    )
    # Turn the lower bounds into a scalar training loss using `criterion`.
    certified_loss = get_loss_from_bounds(lb, criterion)

    return_tuple = (certified_loss,)

    if return_bounds:
        return_tuple = return_tuple + (lb, ub)
    if return_stats:
        # Fraction of samples with at least one negative lower bound,
        # i.e. samples whose robustness could not be certified.
        robust_err = torch.sum((lb < 0).any(dim=1)).item() / data.size(0)
        return_tuple = return_tuple + (robust_err,)

    return return_tuple
```
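The body of `get_loss_from_bounds` is not part of this excerpt. A common way to turn margin lower bounds into a training loss in certified training (for example in CROWN-IBP-style training code) is to treat `-lb` as logits, prepend a zero column for the true class, and apply cross-entropy with label 0. The pure-Python sketch below assumes that construction and hard-codes cross-entropy where `get_ibp_loss` would use its `criterion` argument. Whether CTRAIN does exactly this is an assumption, and `loss_from_margin_bounds` is a hypothetical name.

```python
import math

# Hedged sketch of a loss built from margin lower bounds (assumption: each row
# holds lower bounds on z_target - z_j for the non-target classes j).


def loss_from_margin_bounds(lb_rows):
    """Mean cross-entropy over [0, -lb_1, ..., -lb_{K-1}] with label 0.

    For one sample this equals log(1 + sum_j exp(-lb_j)), which approaches 0
    as every margin lower bound becomes large and positive.
    """
    total = 0.0
    for row in lb_rows:
        logits = [0.0] + [-v for v in row]  # zero column stands in for the true class
        m = max(logits)                      # subtract the max for a stable log-sum-exp
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[0]             # cross-entropy for label 0
    return total / len(lb_rows)
```

Because the loss only depends on the lower bounds, minimizing it pushes every worst-case margin upward, which is the same quantity the robust error checks for negativity.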