Skip to content

EWC

EWC

Bases: PytorchRegularizationBaseAlgorithm

Elastic Weight Consolidation Algorithm Class. Inherits from BaseAlgorithm.

The equivalent JAX implementation is EWC in JAX.

References

[1] Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural networks. PNAS 114, 3521-3526 (2017).

Source code in sequel/algos/pytorch/ewc.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class EWC(PytorchRegularizationBaseAlgorithm):
    """Elastic Weight Consolidation Algorithm Class. Inherits from BaseAlgorithm.

    The equivalent JAX implementation is [`EWC in JAX`][sequel.algos.jax.ewc.EWC].

    References:
        [1] Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural networks. PNAS 114, 3521-3526 (2017).
    """

    def __init__(self, ewc_lambda: float = 1.0, *args, **kwargs):
        """Inits the Elastic Weight Consolidation algorithm.

        Args:
            ewc_lambda (float): The lambda coefficient of EWC algorithm.
        """
        super().__init__(regularization_coefficient=ewc_lambda, *args, **kwargs)

    def __repr__(self) -> str:
        return f"EWC(ewc_lambda={self.regularization_coefficient})"

    def calculate_parameter_importance(self):
        train_loader = self.benchmark.train_dataloader(self.task_counter)
        self.backbone = self.backbone.to(self.device)

        importances = {}
        for ii, batch in enumerate(train_loader):
            self.unpack_batch(batch)
            outs = self.backbone(self.x, self.t)
            loss = super().compute_loss(outs, self.y, self.t)
            loss.backward()

            for (name, p) in self.backbone.named_parameters():
                name = name.replace(".", "_")
                if p.grad is not None:
                    if getattr(importances, name, None) is None:
                        importances[name] = p.grad.data.clone().pow(2) / len(train_loader)
                    else:
                        importances[name] += p.grad.data.clone().pow(2) / len(train_loader)

        return importances

__init__(ewc_lambda=1.0, *args, **kwargs)

Inits the Elastic Weight Consolidation algorithm.

Parameters:

Name Type Description Default
ewc_lambda float

The lambda coefficient of EWC algorithm.

1.0
Source code in sequel/algos/pytorch/ewc.py
13
14
15
16
17
18
19
def __init__(self, ewc_lambda: float = 1.0, *args, **kwargs):
    """Inits the Elastic Weight Consolidation algorithm.

    Args:
        ewc_lambda (float): The lambda coefficient of EWC algorithm.
    """
    super().__init__(regularization_coefficient=ewc_lambda, *args, **kwargs)