
MAS

Bases: PytorchRegularizationBaseAlgorithm

Memory Aware Synapses (MAS) algorithm class. Inherits from PytorchRegularizationBaseAlgorithm.

The equivalent JAX implementation is MAS in JAX.

References

[1] Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T. Memory Aware Synapses: Learning What (not) to Forget. in Computer Vision - ECCV 2018.
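
For context, the update implemented below corresponds to the MAS objective of [1]. Written out (a summary in simplified notation, not taken from the library's docs), the importance of parameter θ_ij and the regularized loss are:

```latex
% Importance of parameter \theta_{ij}, averaged over N inputs x_k, where
% g_{ij}(x_k) is the gradient of the squared L2 norm of the network output:
\Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} \bigl\| g_{ij}(x_k) \bigr\|,
\qquad
g_{ij}(x_k) = \frac{\partial\, \ell_2^2\bigl(F(x_k; \theta)\bigr)}{\partial \theta_{ij}}

% Loss for the next task, anchored at the previous optimum \theta^{*};
% mas_lambda plays the role of \lambda:
L(\theta) = L_{\mathrm{task}}(\theta)
          + \lambda \sum_{ij} \Omega_{ij} \bigl(\theta_{ij} - \theta^{*}_{ij}\bigr)^{2}
```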

Source code in sequel/algos/pytorch/mas.py
import torch


class MAS(PytorchRegularizationBaseAlgorithm):
    """Memory Aware Synapses algorithm class. Inherits from PytorchRegularizationBaseAlgorithm.

    The equivalent JAX implementation is [`MAS in JAX`][sequel.algos.jax.mas.MAS].

    References:
        [1] Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T. Memory Aware Synapses: Learning
            What (not) to Forget. in Computer Vision - ECCV 2018.
    """

    def __init__(self, mas_lambda: float = 1.0, *args, **kwargs):
        """Inits the Memory Aware Synapses algorithm.

        Args:
            mas_lambda (float): The regularization coefficient that scales the MAS penalty.
        """
        super().__init__(regularization_coefficient=mas_lambda, *args, **kwargs)

        # Anomaly detection surfaces autograd errors during the importance
        # computation; note that it adds runtime overhead.
        torch.autograd.set_detect_anomaly(True)
        # Register one zero-initialized importance buffer per parameter. Dots are
        # replaced so the buffer name is a valid attribute (e.g. "fc.weight" -> "fc_weight_w").
        for name, param in self.backbone.named_parameters():
            name = name.replace(".", "_")
            self.backbone.register_buffer(f"{name}_w", torch.zeros_like(param))

    def __repr__(self) -> str:
        return f"MAS(mas_lambda={self.regularization_coefficient})"

    def on_after_training_step(self, *args, **kwargs):
        # Perform the forward pass once again with the updated parameters.
        self.forward()
        self.optimizer_zero_grad()
        # Squared L2 norm of the model output; its parameter gradients define the
        # MAS importances. Use the out-of-place pow(2): an in-place pow_ would
        # overwrite self.y_hat and can break autograd.
        f_loss: torch.Tensor = self.y_hat.pow(2).mean()
        f_loss.backward()
        # Accumulate the absolute gradients, averaged over the training batches.
        for name, param in self.backbone.named_parameters():
            name = name.replace(".", "_")
            w = getattr(self.backbone, f"{name}_w")
            if param.grad is not None:
                setattr(self.backbone, f"{name}_w", w + param.grad.abs() / len(self.train_loader))
        return super().on_after_training_step(*args, **kwargs)

    def calculate_parameter_importance(self):
        # Expose the accumulated per-parameter importances to the regularization base class.
        importances = {}
        for name, param in self.backbone.named_parameters():
            name = name.replace(".", "_")
            importances[name] = getattr(self.backbone, f"{name}_w")

        return importances
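
To make the accumulation in `on_after_training_step` concrete, here is a minimal, framework-free sketch of the same computation. The model, loader, and `importance` dictionary are illustrative stand-ins, not part of the sequel API:

```python
import torch
import torch.nn as nn

# Stand-ins for the backbone and train_loader used by the algorithm class.
model = nn.Linear(4, 2)
loader = [torch.randn(8, 4) for _ in range(10)]

# One importance tensor per parameter, mirroring the "*_w" buffers above.
importance = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

for x in loader:
    model.zero_grad()
    y_hat = model(x)
    # Squared L2 norm of the output, as in the MAS objective of [1].
    loss = y_hat.pow(2).mean()
    loss.backward()
    for name, p in model.named_parameters():
        if p.grad is not None:
            # Accumulate absolute gradients, averaged over the number of batches.
            importance[name] += p.grad.abs() / len(loader)
```

The resulting `importance` values play the role of the `*_w` buffers that `calculate_parameter_importance` returns to the regularization base class.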

__init__(mas_lambda=1.0, *args, **kwargs)

Inits the Memory Aware Synapses algorithm.

Parameters:

| Name         | Type    | Description                                                  | Default |
| ------------ | ------- | ------------------------------------------------------------ | ------- |
| `mas_lambda` | `float` | The regularization coefficient that scales the MAS penalty. | `1.0`   |

Source code in sequel/algos/pytorch/mas.py
def __init__(self, mas_lambda: float = 1.0, *args, **kwargs):
    """Inits the Memory Aware Synapses algorithm.

    Args:
        mas_lambda (float): The regularization coefficient that scales the MAS penalty.
    """
    super().__init__(regularization_coefficient=mas_lambda, *args, **kwargs)

    # Anomaly detection surfaces autograd errors during the importance
    # computation; note that it adds runtime overhead.
    torch.autograd.set_detect_anomaly(True)
    # Register one zero-initialized importance buffer per parameter. Dots are
    # replaced so the buffer name is a valid attribute (e.g. "fc.weight" -> "fc_weight_w").
    for name, param in self.backbone.named_parameters():
        name = name.replace(".", "_")
        self.backbone.register_buffer(f"{name}_w", torch.zeros_like(param))
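
As a small illustration of the buffer bookkeeping in `__init__` (the toy module and parameter names below are made up for the example):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)  # parameters "fc.weight" and "fc.bias"

backbone = Net()
# Same renaming scheme as above: dots become underscores, "_w" suffix appended.
for name, param in backbone.named_parameters():
    backbone.register_buffer(f"{name.replace('.', '_')}_w", torch.zeros_like(param))

print([name for name, _ in backbone.named_buffers()])
# ['fc_weight_w', 'fc_bias_w']
```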