util

split_network(model, block_sizes, network_input, device)

Splits a neural network model into smaller sequential blocks based on the specified block sizes. Needed for TAPS/STAPS.

Args:

- model (torch.nn.Module): The neural network model to be split. Its layers must be accessible through an indexable `layers` attribute (e.g. an `nn.Sequential`). The model itself is not moved, so it should already be on `device`.
- block_sizes (list of int): The number of consecutive layers in each block.
- network_input (torch.Tensor): An example input tensor. It is only used to infer the output shape of each block.
- device (torch.device): The device to which the tensors should be moved (e.g., 'cpu' or 'cuda').

Returns:

- list of torch.nn.Sequential: A list of sequential blocks representing the split network.
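A small sketch of what the split produces: consecutive slices of `model.layers` that, when the blocks cover every layer and the model's `forward` simply applies `self.layers`, compute the same output as the full model when chained. It assumes a hypothetical `ToyNet` and uses `split_into_blocks`, a simplified stand-in for `split_network` that keeps only the slicing step (no shape propagation or device handling). Neither name is part of CTRAIN.

```python
import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical model exposing its layers via a `layers` attribute, as split_network expects."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.layers(x)


def split_into_blocks(model, block_sizes):
    """Simplified stand-in for split_network: slice model.layers into consecutive blocks."""
    blocks, start = [], 0
    for size in block_sizes:
        # Slicing returns the same module objects, so blocks share parameters with the model.
        blocks.append(nn.Sequential(*model.layers[start:start + size]))
        start += size
    return blocks


torch.manual_seed(0)
model = ToyNet()
x = torch.randn(2, 1, 28, 28)
blocks = split_into_blocks(model, [2, 2, 2])  # 6 layers -> 3 blocks of 2

# Chaining the blocks reproduces the full forward pass.
out = x
for block in blocks:
    out = block(out)
```

Because slicing does not copy modules, training or bounding the blocks acts on the original model's weights.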

Source code in CTRAIN/train/certified/util.py
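```python
import torch
from torch import nn


def split_network(model, block_sizes, network_input, device):
    """
    Splits a neural network model into smaller sequential blocks based on the specified block sizes. Needed for TAPS/STAPS.
    Args:
        model (torch.nn.Module): The neural network model to be split. Its layers must be accessible
            through an indexable `layers` attribute (e.g. an nn.Sequential). The model is not moved,
            so it should already be on `device`.
        block_sizes (list of int): The number of consecutive layers in each block.
        network_input (torch.Tensor): An example input tensor, only used to infer each block's output shape.
        device (torch.device): The device to which the tensors should be moved (e.g., 'cpu' or 'cuda').
    Returns:
        list of torch.nn.Sequential: A list of sequential blocks representing the split network.
    """
    # TODO: Add assertions for robustness
    start = 0
    original_blocks = []
    network_input = network_input.to(device)
    for size in block_sizes:
        end = start + size
        # Unpack the slice so this works whether model.layers is an nn.Sequential or an nn.ModuleList.
        abs_block = nn.Sequential(*model.layers[start:end])
        original_blocks.append(abs_block)

        # Feed a zero tensor of this block's output shape to the next block; only shapes matter here,
        # so no autograd graph is needed.
        with torch.no_grad():
            output_shape = abs_block(network_input).shape
        network_input = torch.zeros(output_shape, device=device)

        start = end
    return original_blocks
```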
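The TODO above leaves input validation open. A minimal sketch of the kind of checks it could add, assuming the intent is that the blocks cover every layer exactly once, is shown here. `check_block_sizes` and `Wrapper` are hypothetical names, not part of CTRAIN.

```python
from torch import nn


def check_block_sizes(model: nn.Module, block_sizes: list[int]) -> None:
    """Hypothetical pre-condition checks for split_network (not part of CTRAIN)."""
    # split_network slices model.layers, so the attribute must exist.
    assert hasattr(model, "layers"), "model must expose its layers via a `layers` attribute"
    # A zero or negative size would produce an empty or ill-defined block.
    assert all(isinstance(s, int) and s > 0 for s in block_sizes), "block sizes must be positive integers"
    # Python slicing silently truncates, so a mismatch would otherwise go unnoticed:
    # too small a total drops the tail of the network, too large a total yields empty blocks.
    n_layers = len(model.layers)
    total = sum(block_sizes)
    assert total == n_layers, f"block sizes sum to {total}, but the model has {n_layers} layers"


class Wrapper(nn.Module):
    """Hypothetical three-layer model used to exercise the checks."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.layers(x)


model = Wrapper()
check_block_sizes(model, [2, 1])  # passes: 2 + 1 == 3 layers
```

Plain `assert` statements are skipped when Python runs with `-O`; raising `ValueError` instead would keep the checks active in every mode.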