
pgd

pgd_attack(model, data, target, x_L, x_U, restarts=1, step_size=0.2, n_steps=200, early_stopping=True, device='cuda', decay_factor=0.1, decay_checkpoints=())

Performs the Projected Gradient Descent (PGD) attack on the given model and data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The neural network model to attack. | *required* |
| `data` | `Tensor` | The input data to perturb. If `None`, the midpoint of `x_L` and `x_U` is used. | *required* |
| `target` | `Tensor` | The target labels for the input data. | *required* |
| `x_L` | `Tensor` | The element-wise lower bound of the perturbed input. | *required* |
| `x_U` | `Tensor` | The element-wise upper bound of the perturbed input. | *required* |
| `restarts` | `int` | The number of random restarts. | `1` |
| `step_size` | `float` | The step size for each gradient update. | `0.2` |
| `n_steps` | `int` | The number of attack steps per restart. | `200` |
| `early_stopping` | `bool` | Whether to stop early once adversarial examples are found. | `True` |
| `device` | `str` | The device on which to run the attack. | `'cuda'` |
| `decay_factor` | `float` | The factor by which the step size is decayed at each checkpoint. | `0.1` |
| `decay_checkpoints` | `tuple` | The step indices at which the step size is decayed. | `()` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The generated adversarial examples. |
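A minimal usage sketch follows. The tensor names (`images`, `labels`), the model, and the epsilon value are assumptions for illustration only; the import path is inferred from the source location shown below. The bounds `x_L`/`x_U` are built as an L∞ ball of radius `eps` around the inputs, clipped to the valid pixel range.

    import torch
    from CTRAIN.attacks.pgd import pgd_attack

    # Hypothetical example: `model`, `images` and `labels` are assumed to exist;
    # eps = 8/255 is an illustrative L-inf perturbation budget, not a CTRAIN default.
    eps = 8 / 255
    x_L = torch.clamp(images - eps, 0.0, 1.0)  # lower bound of the perturbation box
    x_U = torch.clamp(images + eps, 0.0, 1.0)  # upper bound of the perturbation box

    adv = pgd_attack(
        model, images, labels, x_L, x_U,
        restarts=2,
        step_size=0.2,
        n_steps=200,
        early_stopping=True,
        device='cuda',
        decay_factor=0.1,
        decay_checkpoints=(100, 150),
    )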

Source code in CTRAIN/attacks/pgd.py
def pgd_attack(model, data, target, x_L, x_U, restarts=1, step_size=.2, n_steps=200, early_stopping=True, device='cuda', decay_factor=.1, decay_checkpoints=()):
    """
    Performs the Projected Gradient Descent (PGD) attack on the given model and data.

    Args:
        model (torch.nn.Module): The neural network model to attack.
        data (torch.Tensor): The input data to perturb.
        target (torch.Tensor): The target labels for the input data.
        x_L (torch.Tensor): The lower bound of the input data.
        x_U (torch.Tensor): The upper bound of the input data.
        restarts (int, optional): The number of random restarts. Default is 1.
        step_size (float, optional): The step size for each gradient update. Default is 0.2.
        n_steps (int, optional): The number of steps for the attack. Default is 200.
        early_stopping (bool, optional): Whether to stop early if adversarial examples are found. Default is True.
        device (str, optional): The device to perform the attack on. Default is 'cuda'.
        decay_factor (float, optional): The factor by which to decay the step size at each checkpoint. Default is 0.1.
        decay_checkpoints (tuple, optional): The checkpoints at which to decay the step size. Default is ().

    Returns:
        torch.Tensor: The generated adversarial examples.
    """
    x_L, x_U = x_L.to(device), x_U.to(device)
    if data is None:
        data = ((x_L + x_U) / 2).to(device)

    lr_scale = torch.max((x_U-x_L)/2)

    adversarial_examples = data.detach().clone()
    example_found = torch.zeros(data.shape[0], dtype=torch.bool, device=device)
    best_loss = torch.ones(data.shape[0], dtype=torch.float32, device=device)*(-np.inf)

    # TODO: Also support margin loss (although not used in TAPS/SABR/MTL-IBP)
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    for restart_idx in range(restarts):

        if early_stopping and example_found.all():
            break

        random_noise = (x_L + torch.rand(data.shape, device=device) * (x_U - x_L)).to(device)
        attack_input = data.detach().clone().to(device) + random_noise            

        grad_cleaner = optim.SGD([attack_input], lr=1e-3)
        with torch.enable_grad():
            for step in range(n_steps):
                grad_cleaner.zero_grad()

                if early_stopping:
                    attack_input = attack_input[~example_found]

                attack_input.requires_grad = True

                model_out = model(attack_input)

                loss = loss_fn(model_out, target)

                loss.sum().backward(retain_graph=False)

                if len(decay_checkpoints) > 0:
                    no_passed_checkpoints = len([checkpoint for checkpoint in decay_checkpoints if step >= checkpoint])
                    decay = decay_factor ** no_passed_checkpoints
                else:
                    decay = 1

                step_input_change = step_size * lr_scale * decay * attack_input.grad.data.sign()

                attack_input = torch.clamp(attack_input.detach() + step_input_change, x_L, x_U)
                adv_out = model(attack_input)

                adv_loss = loss_fn(adv_out, target)

                if early_stopping:
                    improvement_idx = adv_loss > best_loss[~example_found]
                    best_loss[~example_found & improvement_idx] = adv_loss[improvement_idx].detach()
                    adversarial_examples[~example_found & improvement_idx] = attack_input[improvement_idx].detach()

                    example_found[~example_found][~torch.argmax(adv_out.detach(), dim=1).eq(target)] = True

                else:
                    improvement_idx = adv_loss > best_loss
                    best_loss[improvement_idx] = adv_loss[improvement_idx].detach()
                    adversarial_examples[improvement_idx] = attack_input[improvement_idx].detach()

                    example_found[~torch.argmax(adv_out.detach(), dim=1).eq(target)] = True

                if early_stopping and example_found.all():
                    break

    return adversarial_examples.detach()
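The decay schedule inside the loop multiplies the step size by `decay_factor` once for every checkpoint in `decay_checkpoints` that has already been passed; the result is further scaled by `lr_scale`, half the width of the perturbation box. A small stand-alone sketch of that schedule (the checkpoint values are chosen purely for illustration, not library defaults):

    # Illustrative sketch of the step-size decay computed inside pgd_attack.
    step_size = 0.2
    decay_factor = 0.1
    decay_checkpoints = (100, 150)

    for step in (0, 99, 100, 149, 150, 199):
        passed = len([c for c in decay_checkpoints if step >= c])
        print(step, step_size * decay_factor ** passed)
    # steps 0-99 use 0.2, steps 100-149 use 0.02, steps 150-199 use 0.002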