Skip to content

util

get_loss_from_bounds(lb, criterion)

Computes the certified loss from the given lower bounds using the specified criterion.

Parameters:

Name Type Description Default
lb Tensor

A tensor containing the lower bounds with shape (batch_size, num_classes - 1).

required
criterion callable

A loss function that takes two arguments: the input tensor and the target tensor.

required

Returns:

Type Description

torch.Tensor: The mean certified loss computed from the lower bounds.

Source code in CTRAIN/train/certified/losses/util.py
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def get_loss_from_bounds(lb, criterion):
    """
    Computes the certified loss from the given lower bounds using the specified criterion.

    Args:
        lb (torch.Tensor): A tensor containing the lower bounds with shape (batch_size, num_classes - 1).
        criterion (callable): A loss function that takes two arguments: the input tensor and the target tensor.

    Returns:
        torch.Tensor: The mean certified loss computed from the lower bounds.
    """
    lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)
    fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
    certified_loss = criterion(-lb_padded, fake_labels).mean()
    return certified_loss