KCL

Amortized

Bases: nn.Module

Source code in sequel/algos/pytorch/kcl.py
class Amortized(nn.Module):
    def __init__(self, input_units: int, d_theta: int, output_units: int):
        """Inits the inference block used by the Kernel Continual Learning algorithm.

        Args:
            input_units (int): dimensionality of the input.
            d_theta (int): dimensionality of the intermediate hidden layers.
            output_units (int): dimensionality of the output.
        """
        super(Amortized, self).__init__()
        self.output_units = output_units
        self.weight_mean = InferenceBlock(input_units, d_theta, output_units)
        self.weight_log_variance = InferenceBlock(input_units, d_theta, output_units)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        weight_mean = self.weight_mean(x)
        weight_log_variance = self.weight_log_variance(x)
        return weight_mean, weight_log_variance

__init__(input_units, d_theta, output_units)

Inits the amortized inference block used by the Kernel Continual Learning algorithm.

Parameters:

    input_units (int): dimensionality of the input. Required.
    d_theta (int): dimensionality of the intermediate hidden layers. Required.
    output_units (int): dimensionality of the output. Required.

Source code in sequel/algos/pytorch/kcl.py
def __init__(self, input_units: int, d_theta: int, output_units: int):
    """Inits the inference block used by the Kernel Continual Learning algorithm.

    Args:
        input_units (int): dimensionality of the input.
        d_theta (int): dimensionality of the intermediate hidden layers.
        output_units (int): dimensionality of the output.
    """
    super(Amortized, self).__init__()
    self.output_units = output_units
    self.weight_mean = InferenceBlock(input_units, d_theta, output_units)
    self.weight_log_variance = InferenceBlock(input_units, d_theta, output_units)
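
Example

A minimal usage sketch (not part of the library's own examples), assuming Amortized is importable from sequel/algos/pytorch/kcl.py as shown above and using illustrative dimensionalities:

import torch

from sequel.algos.pytorch.kcl import Amortized

# mean and log-variance heads over a 128-dimensional feature space
block = Amortized(input_units=128, d_theta=128, output_units=128)

features = torch.randn(1, 128)      # e.g. the mean encoder feature of a batch
mu, log_var = block(features)       # outputs of the two InferenceBlock heads
print(mu.shape, log_var.shape)      # torch.Size([1, 128]) torch.Size([1, 128])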

InferenceBlock

Bases: nn.Module

Source code in sequel/algos/pytorch/kcl.py
class InferenceBlock(nn.Module):
    def __init__(self, input_units: int, d_theta: int, output_units: int):
        """Inits the inference block used by the Kernel Continual Learning algorithm.

        Args:
            input_units (int): dimensionality of the input.
            d_theta (int): dimensionality of the intermediate hidden layers.
            output_units (int): dimensionality of the output.
        """
        super(InferenceBlock, self).__init__()
        self.module = nn.Sequential(
            nn.Linear(input_units, d_theta, bias=True),
            nn.ELU(inplace=True),
            nn.Linear(d_theta, d_theta, bias=True),
            nn.ELU(inplace=True),
            nn.Linear(d_theta, output_units, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)

__init__(input_units, d_theta, output_units)

Inits the inference block used by the Kernel Continual Learning algorithm.

Parameters:

    input_units (int): dimensionality of the input. Required.
    d_theta (int): dimensionality of the intermediate hidden layers. Required.
    output_units (int): dimensionality of the output. Required.

Source code in sequel/algos/pytorch/kcl.py
def __init__(self, input_units: int, d_theta: int, output_units: int):
    """Inits the inference block used by the Kernel Continual Learning algorithm.

    Args:
        input_units (int): dimensionality of the input.
        d_theta (int): dimensionality of the intermediate hidden layers.
        output_units (int): dimensionality of the output.
    """
    super(InferenceBlock, self).__init__()
    self.module = nn.Sequential(
        nn.Linear(input_units, d_theta, bias=True),
        nn.ELU(inplace=True),
        nn.Linear(d_theta, d_theta, bias=True),
        nn.ELU(inplace=True),
        nn.Linear(d_theta, output_units, bias=True),
    )
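
Example

A quick shape check for the three-layer ELU MLP above (sketch only; the dimensions are arbitrary and the import path is the module shown above):

import torch

from sequel.algos.pytorch.kcl import InferenceBlock

# Linear(64 -> 128) -> ELU -> Linear(128 -> 128) -> ELU -> Linear(128 -> 32)
block = InferenceBlock(input_units=64, d_theta=128, output_units=32)
out = block(torch.randn(16, 64))    # a batch of 16 feature vectors
print(out.shape)                    # torch.Size([16, 32])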

KCL

Bases: PytorchBaseAlgorithm

Kernel Continual Learning algorithm. The code is adapted from https://github.com/mmderakhshani/KCL/blob/main/stable_sgd/main.py

KCL is not yet implemented in JAX.

References

[1] Derakhshani, M. M., Zhen, X., Shao, L. & Snoek, C. Kernel Continual Learning. In Proceedings of the 38th International Conference on Machine Learning (ICML), 2021.

Source code in sequel/algos/pytorch/kcl.py
class KCL(PytorchBaseAlgorithm):
    """Kernel Continual Learning algorithm. The code is adapted from https://github.com/mmderakhshani/KCL/blob/main/stable_sgd/main.py

    KCL is not yet implemented in JAX.

    References:
        [1] Derakhshani, M. M., Zhen, X., Shao, L. & Snoek, C. Kernel Continual Learning. in Proceedings of the 38th
            International Conference on Machine Learning, ICML 2021.
    """

    def __init__(
        self,
        lmd: float,
        core_size: int,
        d_rn_f: int,
        kernel_type: Literal["rbf", "rff", "linear", "poly"],
        tau: float,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.__check_valid__()

        self.core_size = core_size
        self.kernel_type = kernel_type
        self.tau = tau
        self.coresets = {}

        device = next(self.backbone.parameters()).device
        embedding = self.backbone.encoder(torch.ones(self.input_dimensions).unsqueeze(0).to(device))

        self.backbone = KernelBackboneWrapper(
            self.backbone, hiddens=embedding.numel(), lmd=lmd, num_tasks=self.num_tasks, d_rn_f=d_rn_f,
            kernel_type=kernel_type,
        ).to(device)

    def __check_valid__(self):
        if getattr(self.backbone, "encoder", None) is None:
            raise AttributeError(
                "The backbone must have an encoder subnetwork to be compatible with the implementation of Kernel "
                "Continual Learning. The encoder consists of the entire original backbone except from the last Linear "
                "layer, i.e., the classifier."
            )

    def count_parameters(self) -> int:
        if not isinstance(self.backbone, KernelBackboneWrapper):
            return super().count_parameters()
        return sum([p.numel() for p in self.backbone.parameters() if p.requires_grad])

    def prepare_train_loader(self, task: int) -> DataLoader:
        """Splits the training dataset of the given `task` to training and coreset."""
        dataset = self.benchmark.get_train_dataset(task)
        dataset, coreset = random_split(dataset, lengths=[len(dataset) - self.core_size, self.core_size])
        self.coresets[task] = coreset
        self.register_coreset(coreset)
        return DataLoader(dataset, self.benchmark.batch_size, shuffle=True, **self.benchmark.dl_kwargs)

    def register_coreset(self, coreset):
        num_classes = self.benchmark.num_classes
        x = [sample[0] for sample in coreset]
        y = [sample[1] for sample in coreset]
        self.coreset_input = torch.stack(x).to(self.device)
        self.coreset_target = F.one_hot(torch.tensor(y), num_classes=num_classes).to(self.device).float()

    def forward(self):
        """Performs the forward for the Kernel Continual Learning backbone."""
        self.y_hat = self.backbone.forward(self.x, self.t, self.coreset_input, self.coreset_target)
        return self.y_hat

    def kl_div(self, m: torch.Tensor, log_v: torch.Tensor, m0: torch.Tensor, log_v0: torch.Tensor) -> torch.Tensor:
        """Computes the Kullback-Leibler divergence assuming two normal distributions parameterized by the arguments."""
        v = log_v.exp()
        v0 = log_v0.exp()

        dout, din = m.shape
        const_term = -0.5 * dout * din

        log_std_diff = 0.5 * torch.sum(torch.log(v0) - torch.log(v))
        mu_diff_term = 0.5 * torch.sum((v + (m0 - m) ** 2) / v0)
        kl = const_term + log_std_diff + mu_diff_term
        return kl

    def training_step(self, *args, **kwargs):
        self.optimizer_zero_grad()

        self.y_hat = self.forward()
        self.loss = F.cross_entropy(self.y_hat, self.y)

        if self.kernel_type == "rff":
            r_mu, r_log_var = self.backbone.r_mu, self.backbone.r_log_var
            p_mu, p_log_var = self.backbone.p_mu, self.backbone.p_log_var
            self.loss += self.tau * self.kl_div(r_mu, r_log_var, p_mu, p_log_var)
        self.loss.backward()
        self.optimizer.step()

    def on_before_val_epoch(self, *args, **kwargs):
        logging.info(f"Setting the coreset for validating task {self.current_val_task}.")
        self.register_coreset(self.coresets[self.current_val_task])
        return super().on_before_val_epoch(*args, **kwargs)

forward()

Performs the forward pass of the Kernel Continual Learning backbone.

Source code in sequel/algos/pytorch/kcl.py
def forward(self):
    """Performs the forward for the Kernel Continual Learning backbone."""
    self.y_hat = self.backbone.forward(self.x, self.t, self.coreset_input, self.coreset_target)
    return self.y_hat

kl_div(m, log_v, m0, log_v0)

Computes the Kullback-Leibler divergence between two normal distributions parameterized by the given means and log-variances.

Source code in sequel/algos/pytorch/kcl.py
def kl_div(self, m: torch.Tensor, log_v: torch.Tensor, m0: torch.Tensor, log_v0: torch.Tensor) -> torch.Tensor:
    """Computes the Kullback-Leibler divergence assuming two normal distributions parameterized by the arguments."""
    v = log_v.exp()
    v0 = log_v0.exp()

    dout, din = m.shape
    const_term = -0.5 * dout * din

    log_std_diff = 0.5 * torch.sum(torch.log(v0) - torch.log(v))
    mu_diff_term = 0.5 * torch.sum((v + (m0 - m) ** 2) / v0)
    kl = const_term + log_std_diff + mu_diff_term
    return kl
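
Example

The closed form above is the KL divergence KL(N(m, v) || N(m0, v0)) for diagonal Gaussians, summed over all dout * din entries; the constant term -0.5 * dout * din accounts for the -1/2 per element. The sketch below (illustrative, not part of the library) checks the formula against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
m, log_v = torch.randn(3, 4), torch.randn(3, 4)
m0, log_v0 = torch.randn(3, 4), torch.randn(3, 4)

# same terms as KCL.kl_div
v, v0 = log_v.exp(), log_v0.exp()
kl_manual = (
    -0.5 * m.numel()
    + 0.5 * torch.sum(log_v0 - log_v)
    + 0.5 * torch.sum((v + (m0 - m) ** 2) / v0)
)

# reference value (std = sqrt(var)); the two numbers should agree
kl_ref = kl_divergence(Normal(m, v.sqrt()), Normal(m0, v0.sqrt())).sum()
print(kl_manual.item(), kl_ref.item())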

prepare_train_loader(task)

Splits the training dataset of the given task into a training set and a coreset.

Source code in sequel/algos/pytorch/kcl.py
def prepare_train_loader(self, task: int) -> DataLoader:
    """Splits the training dataset of the given `task` to training and coreset."""
    dataset = self.benchmark.get_train_dataset(task)
    dataset, coreset = random_split(dataset, lengths=[len(dataset) - self.core_size, self.core_size])
    self.coresets[task] = coreset
    self.register_coreset(coreset)
    return DataLoader(dataset, self.benchmark.batch_size, shuffle=True, **self.benchmark.dl_kwargs)
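
Example

A standalone sketch of the same split (the dataset below is synthetic; core_size corresponds to the KCL constructor argument):

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(10, (1000,)))
core_size = 100

# the coreset is held out from the training data and stored per task
train_set, coreset = random_split(dataset, lengths=[len(dataset) - core_size, core_size])
print(len(train_set), len(coreset))  # 900 100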

KernelBackboneWrapper

Bases: BaseBackbone

Source code in sequel/algos/pytorch/kcl.py
class KernelBackboneWrapper(BaseBackbone):
    def __init__(
        self,
        model: BaseBackbone,
        hiddens: int,
        lmd: float,
        num_tasks: int,
        d_rn_f: int,
        kernel_type: Literal["rbf", "rff", "linear", "poly"] = "rff",
    ):
        """Model Wrapper for Kernel Continual Learning. Extracts the encoder of the original backbone and performs the k
        ernel computations outlined in [1].


        Notes:
            The `hiddens` argument can be removed and instead inferred.

        Args:
            model (BaseBackbone): the original backbone. The model must have an encoder component.
            hiddens (int): the dimensionality of the hidden dimensions for the kernel-specific modules.
            lmd (float): The initial value for the lmd Parameter.
            num_tasks (int): the number of tasks to be solved.
            d_rn_f (int): dimensionality of the Random Fourier Features (RFFs). Applicable only if `kernel_type` is 'rff'.
            kernel_type (str, optional): the kernel to use; one of "rbf", "rff", "linear", or "poly". Defaults to "rff".
        """
        multihead, classes_per_task, masking_value = model.multihead, model.classes_per_task, model.masking_value
        super().__init__(multihead=multihead, classes_per_task=classes_per_task, masking_value=masking_value)
        self.encoder = model.encoder
        self.d_rn_f = d_rn_f

        self.post = Amortized(hiddens, hiddens, hiddens)
        self.prior = Amortized(hiddens, hiddens, hiddens)

        device = next(model.parameters()).device
        self.lmd = nn.Parameter(torch.tensor([lmd for _ in range(num_tasks)])).to(device)
        self.gma = nn.Parameter(torch.tensor([1.0 for _ in range(num_tasks)])).to(device)
        self.bta = nn.Parameter(torch.tensor([0.0 for _ in range(num_tasks)])).to(device)
        self.kernel_type = kernel_type
        self.bias = 2 * math.pi * torch.rand(d_rn_f, 1).to(device)

    def inner_forward(self, x: torch.Tensor, post: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        out = self.encoder(x)
        out_features = self.normalize(out)
        out_mean = torch.mean(out_features, dim=0, keepdim=True)
        if post:
            mu, logvar = self.post(out_mean)
        else:
            mu, logvar = self.prior(out_mean)
        return out_features, mu, logvar

    def kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if self.kernel_type == "rbf":
            support_kernel = torch.exp(-0.25 * torch.norm(x.unsqueeze(1) - y, dim=2, p=1))
        elif self.kernel_type == "linear":
            support_kernel = x @ y.T
        elif self.kernel_type == "poly":
            support_kernel = (torch.matmul(x, y.T) + 1).pow(3)
        elif self.kernel_type == "rff":
            support_kernel = x.T @ y
        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}. Supported types: rbf, rff, linear, poly.")
        return support_kernel

    @staticmethod
    def sample(mu: torch.Tensor, logvar: torch.Tensor, L: int, device) -> torch.Tensor:
        shape = (L,) + mu.size()
        eps = torch.randn(shape).to(device)
        return mu.unsqueeze(0) + eps * logvar.exp().sqrt().unsqueeze(0)

    def rand_features(self, bases: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        return math.sqrt(2 / self.bias.shape[0]) * torch.cos(torch.matmul(bases, features) + self.bias)

    def compute_kernels(
        self, features_train: torch.Tensor, features_coreset: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = features_coreset.device
        if self.kernel_type == "rff":
            # project to random features
            rs = self.sample(self.r_mu, self.r_log_var, self.d_rn_f, device).squeeze()
            features_coreset = self.rand_features(rs, torch.transpose(features_coreset, 1, 0))
            features_train = self.rand_features(rs, torch.transpose(features_train, 1, 0))

        support_kernel = self.kernel(features_coreset, features_coreset)
        cross_kernel = self.kernel(features_coreset, features_train)
        return support_kernel, cross_kernel

    def forward(
        self, x: torch.Tensor, task_ids: torch.Tensor, coreset_input: torch.Tensor, coreset_target: torch.Tensor
    ) -> torch.Tensor:
        current_task = torch.unique(task_ids)
        assert len(current_task) == 1
        features_train, self.p_mu, self.p_log_var = self.inner_forward(x, post=False)
        features_coreset, self.r_mu, self.r_log_var = self.inner_forward(coreset_input, post=True)

        support_kernel, cross_kernel = self.compute_kernels(features_train, features_coreset)

        alpha = torch.matmul(
            torch.inverse(
                support_kernel
                + (torch.abs(self.lmd[current_task - 1]) + 0.01) * torch.eye(support_kernel.shape[0]).to(x.device)
            ),
            coreset_target,
        )

        out = self.gma[current_task - 1] * torch.matmul(cross_kernel.T, alpha) + self.bta[current_task - 1]

        if self.multihead:
            out = self.select_output_head(out, task_ids)

        return out

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        max_val = x.max()
        min_val = x.min()
        return (x - min_val) / (max_val - min_val)
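
Example

The forward method above performs kernel ridge regression on the coreset: alpha = (K_ss + (|lmd| + 0.01) I)^(-1) Y_s, followed by the prediction gma * K_sx^T alpha + bta for the current batch. The sketch below reproduces the core computation for the linear-kernel branch on toy tensors; all names and sizes are illustrative, not library attributes:

import torch

torch.manual_seed(0)
n_core, n_batch, d, n_classes = 20, 5, 16, 4
lmd = 0.1                              # plays the role of the per-task lmd parameter

feats_core = torch.randn(n_core, d)    # coreset features after the encoder
feats_batch = torch.randn(n_batch, d)  # current-batch features
targets_core = torch.nn.functional.one_hot(
    torch.randint(n_classes, (n_core,)), n_classes
).float()

# linear kernel: K_ss = X_s X_s^T, K_sx = X_s X_b^T
K_ss = feats_core @ feats_core.T
K_sx = feats_core @ feats_batch.T

# ridge solution on the coreset, then prediction for the batch
alpha = torch.linalg.solve(K_ss + lmd * torch.eye(n_core), targets_core)
logits = K_sx.T @ alpha                # shape (n_batch, n_classes)
print(logits.shape)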

__init__(model, hiddens, lmd, num_tasks, d_rn_f, kernel_type='rff')

Model Wrapper for Kernel Continual Learning. Extracts the encoder of the original backbone and performs the kernel computations outlined in [1].

Notes

The hiddens argument can be removed and instead inferred.

Parameters:

    model (BaseBackbone): the original backbone. The model must have an encoder component. Required.
    hiddens (int): the dimensionality of the hidden dimensions for the kernel-specific modules. Required.
    lmd (float): the initial value for the lmd Parameter. Required.
    num_tasks (int): the number of tasks to be solved. Required.
    d_rn_f (int): dimensionality of the Random Fourier Features (RFFs). Applicable only if kernel_type is 'rff'. Required.
    kernel_type (str, optional): the kernel to use; one of "rbf", "rff", "linear", or "poly". Defaults to "rff".

Source code in sequel/algos/pytorch/kcl.py
def __init__(
    self,
    model: BaseBackbone,
    hiddens: int,
    lmd: float,
    num_tasks: int,
    d_rn_f: int,
    kernel_type: Literal["rbf", "rff", "linear", "poly"] = "rff",
):
    """Model Wrapper for Kernel Continual Learning. Extracts the encoder of the original backbone and performs the k
    ernel computations outlined in [1].


    Notes:
        The `hiddens` argument can be removed and instead inferred.

    Args:
        model (BaseBackbone): the original backbone. The model must have an encoder component.
        hiddens (int): the dimensionality of the hidden dimensions for the kernel-specific modules.
        lmd (float): The initial value for the lmd Parameter.
        num_tasks (int): the number of tasks to be solved.
        d_rn_f (int): dimensionality of the Random Fourier Features (RFFs). Applicable only if `kernel_type` is 'rff'.
        kernel_type (str, optional): the kernel to use; one of "rbf", "rff", "linear", or "poly". Defaults to "rff".
    """
    multihead, classes_per_task, masking_value = model.multihead, model.classes_per_task, model.masking_value
    super().__init__(multihead=multihead, classes_per_task=classes_per_task, masking_value=masking_value)
    self.encoder = model.encoder
    self.d_rn_f = d_rn_f

    self.post = Amortized(hiddens, hiddens, hiddens)
    self.prior = Amortized(hiddens, hiddens, hiddens)

    device = next(model.parameters()).device
    self.lmd = nn.Parameter(torch.tensor([lmd for _ in range(num_tasks)])).to(device)
    self.gma = nn.Parameter(torch.tensor([1.0 for _ in range(num_tasks)])).to(device)
    self.bta = nn.Parameter(torch.tensor([0.0 for _ in range(num_tasks)])).to(device)
    self.kernel_type = kernel_type
    self.bias = 2 * math.pi * torch.rand(d_rn_f, 1).to(device)
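
Example

When kernel_type is "rff", rand_features maps features to z(v) = sqrt(2 / d_rn_f) * cos(W v + b) and the kernel becomes an inner product of these projections. The standalone sketch below illustrates the underlying random-Fourier-feature idea: with Gaussian bases, z(x) . z(y) approximates an RBF kernel. W, b, and sigma are local names chosen for the example, not library attributes:

import math
import torch

torch.manual_seed(0)
d, D, sigma = 8, 4096, 1.0            # input dim, number of random features, RBF bandwidth

x, y = torch.randn(d), torch.randn(d)

# random bases and phases, mirroring the form used by rand_features()
W = torch.randn(D, d) / sigma
b = 2 * math.pi * torch.rand(D)

def z(v):
    return math.sqrt(2 / D) * torch.cos(W @ v + b)

exact = torch.exp(-((x - y) ** 2).sum() / (2 * sigma ** 2))
approx = z(x) @ z(y)
print(exact.item(), approx.item())    # close for large D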