crown_ibp

get_crown_ibp_loss(hardened_model, ptb, data, target, n_classes, criterion, beta, return_bounds=False, return_stats=True)

Compute the CROWN-IBP loss: a certified loss derived from lower bounds that interpolate between CROWN-IBP and IBP bounds.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hardened_model | BoundedModule | The bounded model to be trained. | required |
| ptb | PerturbationLpNorm | The perturbation applied to the input data. | required |
| data | Tensor | The input data. | required |
| target | Tensor | The target labels. | required |
| n_classes | int | The number of classes in the classification task. | required |
| criterion | callable | The loss function to be used. | required |
| beta | float | The interpolation weight between CROWN-IBP and IBP lower bounds: 1 uses only the CROWN-IBP bounds, 0 only the IBP bounds. For beta < 1e-5 the CROWN-IBP bounds are not computed at all. | required |
| return_bounds | bool | If True, append the lower bounds and a None placeholder to the returned tuple. | False |
| return_stats | bool | If True, append the robust error statistic to the returned tuple. | True |
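
The role of beta can be written as a formula. This is a reading of the source listing further down, not an official definition; the symbols are introduced here for illustration only: \underline{m}_{IBP} and \underline{m}_{CROWN-IBP} stand for the lower bounds returned by bound_ibp and bound_crown_ibp, and \underline{m} is the bound that is passed on to the loss.

```latex
\underline{m} =
\begin{cases}
\underline{m}_{\mathrm{IBP}} & \text{if } \beta < 10^{-5}, \\
\beta \, \underline{m}_{\mathrm{CROWN\text{-}IBP}} + (1 - \beta) \, \underline{m}_{\mathrm{IBP}} & \text{otherwise.}
\end{cases}
```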

Returns:

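| Type | Description |
| --- | --- |
| tuple | A tuple whose first element is the certified loss. If return_bounds is True, the lower bounds and a None placeholder follow. If return_stats is True, the robust error (the fraction of samples in the batch with at least one negative lower bound) is appended last. |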
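
Because the length of the returned tuple depends on the two flags, callers have to unpack it accordingly. The following is a minimal sketch of one way to do that, assuming the ordering seen in the source listing (loss first, then bounds, then statistics). The helper unpack_crown_ibp_result and its dictionary keys are hypothetical and not part of CTRAIN; plain Python values stand in for tensors.

```python
from typing import Any, Dict, Tuple


def unpack_crown_ibp_result(
    result: Tuple[Any, ...], return_bounds: bool = False, return_stats: bool = True
) -> Dict[str, Any]:
    """Hypothetical helper: name the entries of the tuple returned by get_crown_ibp_loss."""
    items = list(result)
    named = {"certified_loss": items.pop(0)}
    if return_bounds:
        # Two entries are appended for bounds: the lower bounds and a None placeholder.
        named["lower_bounds"] = items.pop(0)
        named["placeholder"] = items.pop(0)
    if return_stats:
        named["robust_err"] = items.pop(0)
    if items:
        raise ValueError("tuple has more entries than the flags imply")
    return named


# Stand-in values (plain Python objects instead of tensors), for illustration only.
example = (0.73, [[0.2, -0.1]], None, 1.0)
named = unpack_crown_ibp_result(example, return_bounds=True, return_stats=True)
print(named["certified_loss"], named["robust_err"])
```

With the default flags the tuple has two entries, so `loss, robust_err = get_crown_ibp_loss(...)` is enough; with return_bounds=True and return_stats=True it has four.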

Source code in CTRAIN/train/certified/losses/crown_ibp.py
def get_crown_ibp_loss(hardened_model, ptb, data, target, n_classes, criterion, beta, return_bounds=False, return_stats=True):
    """
    Compute the CROWN-IBP loss.

    Parameters:
        hardened_model (auto_LiRPA.BoundedModule): The bounded model to be trained.
        ptb (auto_LiRPA.PerturbationLpNorm): The perturbation applied to the input data.
        data (torch.Tensor): The input data.
        target (torch.Tensor): The target labels.
        n_classes (int): The number of classes in the classification task.
        criterion (callable): The loss function to be used.
        beta (float): The interpolation weight between CROWN-IBP and IBP bounds (1 uses only CROWN-IBP, 0 only IBP).
        return_bounds (bool, optional): If True, return the lower bounds. Default is False.
        return_stats (bool, optional): If True, return the robust error statistics. Default is True.

    Returns:
        (tuple): A tuple whose first element is the certified loss. If return_bounds is True, the lower bounds and a
            None placeholder follow. If return_stats is True, the robust error statistic is appended last.
    """
    ilb, iub = bound_ibp(
        model=hardened_model,
        ptb=ptb,
        data=data,
        target=target,
        n_classes=n_classes,
        bound_upper=False
    )
    if beta < 1e-5:
        lb = ilb
    else:
        # Attention: We have to reuse the input here. Otherwise the memory requirements become too large!
        # Input is reused from above bound_ibp call!
        clb, cub = bound_crown_ibp(
            model=hardened_model,
            ptb=ptb,
            data=data,
            target=target,
            n_classes=n_classes,
            reuse_input=True,
            bound_upper=False
        )

        lb = clb * beta + ilb * (1 - beta)

    certified_loss = get_loss_from_bounds(lb, criterion)

    return_tuple = (certified_loss,)

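    if return_bounds:
        # Append the interpolated lower bounds followed by a None placeholder.
        return_tuple = return_tuple + (lb, None)
    if return_stats:
        # Robust error: fraction of samples in the batch with at least one negative lower bound.
        robust_err = torch.sum((lb < 0).any(dim=1)).item() / data.size(0)
        return_tuple = return_tuple + (robust_err,)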

    return return_tuple
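
The robust error statistic computed at the end of the listing can be illustrated without any tensor library. The sketch below re-implements only that one line in plain Python, with nested lists standing in for the lb tensor; the function name robust_error and the numbers are made up for illustration and are not part of CTRAIN.

```python
from typing import List


def robust_error(lower_bounds: List[List[float]]) -> float:
    """Fraction of samples with at least one negative lower bound."""
    if not lower_bounds:
        raise ValueError("need at least one sample")
    # Mirrors (lb < 0).any(dim=1): a sample counts as an error if any entry is negative.
    violated = [any(value < 0 for value in row) for row in lower_bounds]
    # Mirrors torch.sum(...).item() / data.size(0): violating samples over batch size.
    return sum(violated) / len(lower_bounds)


# Illustrative batch of 4 samples with 2 bound entries each (made-up numbers).
lb = [
    [0.5, 1.2],    # all bounds non-negative
    [-0.1, 0.8],   # one negative bound
    [0.0, 0.3],    # zero is not negative
    [-0.4, -0.2],  # both bounds negative
]
print(robust_error(lb))  # 0.5
```

Two of the four samples have a negative entry, so the statistic is 0.5. Note that a bound of exactly zero does not count as a violation, since the comparison is strict.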