
eps_scheduler

BaseScheduler

Source code in CTRAIN/train/certified/eps_scheduler.py
import torch


class BaseScheduler():

    def __init__(self, num_epochs, eps, mean, std, start_eps=0, start_kappa=1, end_kappa=0, start_beta=0, end_beta=0, eps_schedule_unit='batch', eps_schedule=(0, 20), batches_per_epoch=None, start_epoch=-1):
        """
        Initializes the Base EpsScheduler.

        Args:
            num_epochs (int): The number of epochs for training.
            eps (float): The epsilon value for the scheduler.
            mean (float): The mean value for normalization.
            std (float): The standard deviation value for normalization.
            start_eps (float, optional): The starting epsilon value. Defaults to 0.
            start_kappa (float, optional): The starting kappa value. Defaults to 1.
            end_kappa (float, optional): The ending kappa value. Defaults to 0.
            start_beta (float, optional): The starting beta value. Defaults to 0.
            end_beta (float, optional): The ending beta value. Defaults to 0.
            eps_schedule_unit (str, optional): The unit for epsilon scheduling ('batch' or 'epoch'). Defaults to 'batch'.
            eps_schedule (tuple, optional): Lengths of the schedule phases, either (warm_up, ramp_up) or (warm_up, ramp_up, final). Defaults to (0, 20).
            batches_per_epoch (int, optional): The number of batches per epoch. Defaults to None.
            start_epoch (int, optional): The starting epoch number. Defaults to -1.

        Raises:
            AssertionError: If num_epochs is None and eps_schedule_unit is not 'epoch'.
            AssertionError: If the length of eps_schedule is not 2 or 3.
            AssertionError: If num_epochs is incompatible with eps_schedule.
        """

        if num_epochs is None and eps_schedule_unit == 'epoch':
            num_epochs = sum(eps_schedule)
        elif num_epochs is None:
            assert False, "Please provide the number of epochs!"
        if eps_schedule_unit == 'epoch':
            if len(eps_schedule) == 3:
                assert num_epochs == sum(eps_schedule), "Eps schedule is incompatible with the specified number of epochs. Please adjust!"
            elif len(eps_schedule) == 2:
                assert num_epochs >= sum(eps_schedule), "Eps schedule is incompatible with the specified number of epochs. Please adjust!"
            else:
                assert False, "eps_schedule must be of length 2 or 3. Please adjust!"

        self.num_epochs = num_epochs
        if len(eps_schedule) == 2:
            self.warm_up, self.ramp_up = eps_schedule
        elif len(eps_schedule) == 3:
            self.warm_up, self.ramp_up, _ = eps_schedule

        self.cur_eps = start_eps
        self.cur_kappa = self.start_kappa = start_kappa
        self.end_kappa = end_kappa
        self.eps = eps
        self.start_eps = start_eps
        self.batches_per_epoch = batches_per_epoch
        self.start_beta = self.cur_beta = start_beta
        self.end_beta = end_beta
        self.mean = mean
        self.std = std

        if eps_schedule_unit == 'epoch':
            self.warm_up *= batches_per_epoch
            self.ramp_up *= batches_per_epoch

        self.training_steps = num_epochs * batches_per_epoch

        self.no_batches = 0

        if start_epoch > 0:
            self.no_batches = self.batches_per_epoch * start_epoch - 1

    def get_cur_eps(self, normalise=True):
        """
        Get the current epsilon value, optionally normalised.

        Args:
            normalise (bool): If True, the returned epsilon value will be normalised by the standard deviation.

        Returns:
            torch.Tensor: The current epsilon value, normalised if specified.

        Notes:
            - If the current epsilon is within 1e-7 of the maximum epsilon, it is snapped to the maximum to mitigate numerical instabilities.
        """
        # Snap to the maximum epsilon to mitigate numerical instabilities
        if (torch.tensor(self.get_max_eps(normalise=False) - self.cur_eps) < 1e-7).all():
            self.cur_eps = self.get_max_eps(normalise=False)
        return torch.tensor(self.cur_eps) / torch.tensor(self.std) if normalise else self.cur_eps

    def get_cur_kappa(self):
        """Return the current kappa value."""
        return self.cur_kappa

    def get_cur_beta(self):
        """Return the current beta value."""
        return self.cur_beta

    def get_max_eps(self, normalise=True):
        """Return the target (maximum) epsilon, optionally normalised by the standard deviation."""
        return torch.tensor(self.eps) / torch.tensor(self.std) if normalise else self.eps

    def batch_step(self):
        """Advance the schedule by one batch. Must be implemented by subclasses."""
        raise NotImplementedError
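
BaseScheduler only maintains the schedule state; concrete schedulers override batch_step(). A minimal sketch of a custom subclass, assuming the import path inferred from the source location shown above:

from CTRAIN.train.certified.eps_scheduler import BaseScheduler

class StepScheduler(BaseScheduler):
    """Toy scheduler that jumps straight to the target values once warm-up ends."""

    def batch_step(self):
        if self.no_batches >= self.warm_up:
            self.cur_eps = self.eps
            self.cur_kappa = self.end_kappa
            self.cur_beta = self.end_beta
        self.no_batches += 1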

__init__(num_epochs, eps, mean, std, start_eps=0, start_kappa=1, end_kappa=0, start_beta=0, end_beta=0, eps_schedule_unit='batch', eps_schedule=(0, 20), batches_per_epoch=None, start_epoch=-1)

Initializes the Base EpsScheduler.

Parameters:

    num_epochs (int, required): The number of epochs for training.
    eps (float, required): The epsilon value for the scheduler.
    mean (float, required): The mean value for normalization.
    std (float, required): The standard deviation value for normalization.
    start_eps (float, optional): The starting epsilon value. Defaults to 0.
    start_kappa (float, optional): The starting kappa value. Defaults to 1.
    end_kappa (float, optional): The ending kappa value. Defaults to 0.
    start_beta (float, optional): The starting beta value. Defaults to 0.
    end_beta (float, optional): The ending beta value. Defaults to 0.
    eps_schedule_unit (str, optional): The unit for epsilon scheduling ('batch' or 'epoch'). Defaults to 'batch'.
    eps_schedule (tuple, optional): Lengths of the schedule phases, either (warm_up, ramp_up) or (warm_up, ramp_up, final). Defaults to (0, 20).
    batches_per_epoch (int, optional): The number of batches per epoch. Defaults to None.
    start_epoch (int, optional): The starting epoch number. Defaults to -1.

Raises:

    AssertionError: If num_epochs is None and eps_schedule_unit is not 'epoch'.
    AssertionError: If the length of eps_schedule is not 2 or 3.
    AssertionError: If num_epochs is incompatible with eps_schedule.
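
For example (illustrative values), with eps_schedule_unit='epoch' a 3-tuple (warm_up, ramp_up, final) must sum exactly to num_epochs, while a 2-tuple only needs to fit within it:

from CTRAIN.train.certified.eps_scheduler import BaseScheduler

# mean/std are placeholders for dataset statistics
ok = BaseScheduler(num_epochs=60, eps=8 / 255, mean=0.5, std=0.5,
                   eps_schedule=(5, 25, 30), eps_schedule_unit='epoch',
                   batches_per_epoch=100)   # passes: 5 + 25 + 30 == 60

also_ok = BaseScheduler(num_epochs=60, eps=8 / 255, mean=0.5, std=0.5,
                        eps_schedule=(5, 25), eps_schedule_unit='epoch',
                        batches_per_epoch=100)  # passes: 5 + 25 <= 60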


get_cur_eps(normalise=True)

Get the current epsilon value, optionally normalised.

Parameters:

    normalise (bool, optional): If True, the returned epsilon value is normalised by the standard deviation. Defaults to True.

Returns:

    torch.Tensor: The current epsilon value, normalised if specified.

Notes:

    If the current epsilon is within 1e-7 of the maximum epsilon, it is snapped to the maximum to mitigate numerical instabilities.
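Because perturbations are applied to normalised inputs, an epsilon defined in pixel space is divided by the per-channel standard deviation. A small sketch of what get_cur_eps(normalise=True) computes, using illustrative CIFAR-10-style statistics:

import torch

std = (0.2470, 0.2435, 0.2616)  # illustrative per-channel std
eps = 8 / 255                   # perturbation radius in pixel space

eps_normalised = torch.tensor(eps) / torch.tensor(std)
print(eps_normalised)  # per-channel radius in normalised input space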

LinearScheduler

Bases: BaseScheduler

A scheduler that linearly adjusts the epsilon, kappa, and beta values during the ramp-up phase of training.

Parameters:

    num_epochs (int, required): Total number of epochs for training.
    eps (float, required): The target epsilon value.
    mean (float, required): The mean value for normalization.
    std (float, required): The standard deviation value for normalization.
    start_eps (float, optional): The starting epsilon value. Defaults to 0.
    start_kappa (float, optional): The starting kappa value. Defaults to 1.
    end_kappa (float, optional): The ending kappa value. Defaults to 0.
    start_beta (float, optional): The starting beta value. Defaults to 1.
    end_beta (float, optional): The ending beta value. Defaults to 0.
    eps_schedule_unit (str, optional): The unit for epsilon scheduling ('batch' or 'epoch'). Defaults to 'batch'.
    eps_schedule (tuple, optional): The schedule for epsilon adjustment. Defaults to (0, 20).
    batches_per_epoch (int, optional): Number of batches per epoch. Defaults to None.
    start_epoch (int, optional): The epoch to start the scheduler. Defaults to -1.

Methods:

    batch_step(): Adjusts the current epsilon, kappa, and beta values based on the current batch number.

Source code in CTRAIN/train/certified/eps_scheduler.py
class LinearScheduler(BaseScheduler):
    """
    A scheduler that linearly adjusts the epsilon, kappa, and beta values during the ramp-up phase of training.

    Args:
        num_epochs (int): Total number of epochs for training.
        eps (float): The target epsilon value.
        mean (float): The mean value for normalization.
        std (float): The standard deviation value for normalization.
        start_eps (float, optional): The starting epsilon value. Defaults to 0.
        start_kappa (float, optional): The starting kappa value. Defaults to 1.
        end_kappa (float, optional): The ending kappa value. Defaults to 0.
        start_beta (float, optional): The starting beta value. Defaults to 1.
        end_beta (float, optional): The ending beta value. Defaults to 0.
        eps_schedule_unit (str, optional): The unit for epsilon scheduling ('batch' or 'epoch'). Defaults to 'batch'.
        eps_schedule (tuple, optional): The schedule for epsilon adjustment. Defaults to (0, 20).
        batches_per_epoch (int, optional): Number of batches per epoch. Defaults to None.
        start_epoch (int, optional): The epoch to start the scheduler. Defaults to -1.

    Methods:
        batch_step():
            Adjusts the current epsilon, kappa, and beta values based on the current batch number.
    """
    def __init__(self, num_epochs, eps, mean, std, start_eps=0, start_kappa=1, end_kappa=0, start_beta=1, end_beta=0, eps_schedule_unit='batch', eps_schedule=(0, 20), batches_per_epoch=None, start_epoch=-1):
        super().__init__(
            num_epochs=num_epochs, 
            eps=eps, 
            mean=mean, 
            std=std,
            start_eps=start_eps, 
            start_kappa=start_kappa, 
            end_kappa=end_kappa, 
            eps_schedule_unit=eps_schedule_unit, 
            eps_schedule=eps_schedule, 
            batches_per_epoch=batches_per_epoch, 
            start_beta=start_beta, 
            end_beta=end_beta
        )

        if start_epoch > 0:
            # Fast-forward the incremental schedule to the given start epoch
            for _ in range(start_epoch * batches_per_epoch):
                self.batch_step()

    def batch_step(self):
        """Advance the schedule by one batch, linearly interpolating eps, kappa and beta during ramp-up."""
        if self.warm_up < self.no_batches < (self.warm_up + self.ramp_up):
            self.cur_eps += (self.eps / self.ramp_up)
            kappa_step = (self.start_kappa - self.end_kappa) / self.ramp_up
            self.cur_kappa -= kappa_step
            beta_step = (self.start_beta - self.end_beta) / self.ramp_up
            self.cur_beta -= beta_step
        self.cur_eps = min(self.cur_eps, self.eps)
        self.cur_kappa = max(self.cur_kappa, self.end_kappa)
        self.cur_beta = max(self.cur_beta, self.end_beta)
        self.no_batches += 1
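
An illustrative training-loop sketch (the dataset statistics, epoch counts and loss computation below are placeholders, not taken from CTRAIN):

from CTRAIN.train.certified.eps_scheduler import LinearScheduler  # path inferred from the source location above

scheduler = LinearScheduler(
    num_epochs=100,
    eps=8 / 255,
    mean=(0.4914, 0.4822, 0.4465),   # illustrative CIFAR-10-style statistics
    std=(0.2470, 0.2435, 0.2616),
    eps_schedule=(1, 20),            # 1 warm-up epoch, 20 ramp-up epochs
    eps_schedule_unit='epoch',
    batches_per_epoch=391,
)

for epoch in range(100):
    for batch in range(391):
        cur_eps = scheduler.get_cur_eps()      # normalised eps for this batch
        cur_kappa = scheduler.get_cur_kappa()  # current kappa value
        # ... compute the certified training loss with cur_eps and cur_kappa ...
        scheduler.batch_step()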

SmoothedScheduler

Bases: BaseScheduler

A scheduler that smoothly transitions epsilon, kappa, and beta values over the course of training.

Parameters:

    num_epochs (int, required): Number of epochs for training.
    eps (float, required): Final epsilon value.
    mean (float, required): Mean value for normalization.
    std (float, required): Standard deviation value for normalization.
    start_eps (float, optional): Initial epsilon value. Default is 0.
    start_kappa (float, optional): Initial kappa value. Default is 1.
    end_kappa (float, optional): Final kappa value. Default is 0.
    start_beta (float, optional): Initial beta value. Default is 1.
    end_beta (float, optional): Final beta value. Default is 0.
    eps_schedule_unit (str, optional): Unit for epsilon scheduling ('batch' or 'epoch'). Default is 'batch'.
    batches_per_epoch (int, optional): Number of batches per epoch; required if eps_schedule_unit is 'epoch'. Default is None.
    start_epoch (int, optional): Epoch to start the scheduling. Default is -1.
    eps_schedule (tuple, optional): Tuple giving the warm-up and ramp-up lengths of the epsilon schedule. Default is (0, 20).
    midpoint (float, optional): Midpoint for the transition from exponential to linear schedule. Default is 0.25.
    exponent (float, optional): Exponent for the exponential schedule. Default is 4.0.

Methods:

    batch_step(): Updates the current epsilon, kappa, and beta values based on the current batch number.

Source code in CTRAIN/train/certified/eps_scheduler.py
class SmoothedScheduler(BaseScheduler):
    """
    A scheduler that smoothly transitions epsilon, kappa, and beta values over the course of training.

    Args:
        num_epochs (int): Number of epochs for training.
        eps (float): Final epsilon value.
        mean (float): Mean value for normalization.
        std (float): Standard deviation value for normalization.
        start_eps (float, optional): Initial epsilon value. Default is 0.
        start_kappa (float, optional): Initial kappa value. Default is 1.
        end_kappa (float, optional): Final kappa value. Default is 0.
        start_beta (float, optional): Initial beta value. Default is 1.
        end_beta (float, optional): Final beta value. Default is 0.
        eps_schedule_unit (str, optional): Unit for epsilon scheduling ('batch' or 'epoch'). Default is 'batch'.
        batches_per_epoch (int, optional): Number of batches per epoch; required if eps_schedule_unit is 'epoch'. Default is None.
        start_epoch (int, optional): Epoch to start the scheduling. Default is -1.
        eps_schedule (tuple, optional): Tuple giving the warm-up and ramp-up lengths of the epsilon schedule. Default is (0, 20).
        midpoint (float, optional): Midpoint for the transition from exponential to linear schedule. Default is 0.25.
        exponent (float, optional): Exponent for the exponential schedule. Default is 4.0.

    Methods:
        batch_step():
            Updates the current epsilon, kappa, and beta values based on the current batch number.
    """
    def __init__(self, num_epochs, eps, mean, std, start_eps=0, start_kappa=1, end_kappa=0, start_beta=1, end_beta=0, eps_schedule_unit='batch', batches_per_epoch=None, start_epoch=-1, eps_schedule=(0, 20), midpoint=.25, exponent=4.0):
        super().__init__(
            num_epochs=num_epochs, 
            eps=eps, 
            mean=mean,
            std=std,
            start_eps=start_eps, 
            start_kappa=start_kappa, 
            end_kappa=end_kappa, 
            eps_schedule_unit=eps_schedule_unit, 
            eps_schedule=eps_schedule, 
            batches_per_epoch=batches_per_epoch, 
            start_epoch=start_epoch,
            start_beta=start_beta, 
            end_beta=end_beta)
        self.midpoint = midpoint
        self.exponent = exponent
        if start_epoch > 0:
            self.batch_step()


    def batch_step(self):
        """Advance the schedule by one batch, following the smoothed (exponential-then-linear) ramp."""
        init_value = self.start_eps
        final_value = self.eps
        beta = self.exponent  # schedule exponent (unrelated to self.cur_beta)
        step = self.no_batches
        # Batch number for schedule start
        init_step = self.warm_up + 1
        # Batch number for schedule end
        final_step = self.warm_up + self.ramp_up
        # Batch number for switching from exponential to linear schedule
        mid_step = int((final_step - init_step) * self.midpoint) + init_step
        t = (mid_step - init_step) ** (beta - 1.)
        # find coefficient for exponential growth, such that at mid point the gradient is the same as a linear ramp to final value
        alpha = (final_value - init_value) / ((final_step - mid_step) * beta * t + (mid_step - init_step) * t)
        # value at switching point
        mid_value = init_value + alpha * (mid_step - init_step) ** beta
        # return init_value when we have not started
        is_ramp = float(step > init_step)
        # linear schedule after mid step
        is_linear = float(step >= mid_step)
        exp_value = init_value + alpha * float(step - init_step) ** beta
        linear_value = min(mid_value + (final_value - mid_value) * (step - mid_step) / (final_step - mid_step), final_value)
        self.cur_eps = is_ramp * ((1.0 - is_linear) * exp_value + is_linear * linear_value) + (1.0 - is_ramp) * init_value
        self.cur_kappa = self.start_kappa * (1 - (self.cur_eps * (1 - self.end_kappa)) / self.eps)
        self.cur_beta = self.start_beta * (1 - (self.cur_eps * (1 - self.end_beta)) / self.eps)
        self.cur_eps = min(self.cur_eps, self.eps)
        self.cur_kappa = max(self.cur_kappa, self.end_kappa)
        self.cur_beta = max(self.cur_beta, self.end_beta)
        self.no_batches += 1
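
The coefficient alpha in batch_step is chosen so that the exponential and linear pieces of the ramp join with matching value and slope at the midpoint and reach the final epsilon exactly at the end of the ramp-up. A sketch of the algebra in the notation of the code, writing t_0 = init_step, t_m = mid_step, t_f = final_step:

\epsilon(t) =
\begin{cases}
\epsilon_0 + \alpha \, (t - t_0)^{\beta} & t_0 \le t \le t_m \\
\epsilon(t_m) + \alpha \beta \, (t_m - t_0)^{\beta - 1} \, (t - t_m) & t_m < t \le t_f
\end{cases}

\alpha = \frac{\epsilon_f - \epsilon_0}{(t_f - t_m) \, \beta \, (t_m - t_0)^{\beta - 1} + (t_m - t_0)^{\beta}}

Substituting t = t_f into the linear piece recovers \epsilon(t_f) = \epsilon_f, so for exponent \beta > 1 the ramp grows gently at first and then climbs linearly to the target epsilon.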