MAS

Bases: JaxRegularizationBaseAlgorithm

Memory Aware Synapses (MAS) algorithm class. Inherits from BaseAlgorithm.

The equivalent PyTorch implementation is MAS in PyTorch.

References

[1] Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T. Memory Aware Synapses: Learning What (not) to Forget. in Computer Vision - ECCV 2018.
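
MAS estimates a per-parameter importance by measuring how sensitive the squared L2 norm of the network output is to each parameter: after a task, the importance is the batch-averaged absolute gradient of that norm, and a quadratic penalty then anchors important parameters on subsequent tasks. Restating the objective from [1] (the exact penalty scaling is applied by the base class, so treat the constants below as indicative):

$$
\Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} \left| \frac{\partial\, \ell_2^2\big(F(x_k; \theta)\big)}{\partial \theta_{ij}} \right|,
\qquad
L(\theta) = L_{\text{task}}(\theta) + \lambda \sum_{i,j} \Omega_{ij} \big(\theta_{ij} - \theta_{ij}^{*}\big)^2,
$$

where $\lambda$ corresponds to `mas_lambda`, $\theta^{*}$ are the parameters saved after the previous task, and $N$ is the number of samples. In the source below, `secondary_loss` implements $\ell_2^2(F(x; \theta))$ and `self.w` accumulates the absolute gradients.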

Source code in sequel/algos/jax/mas.py
# Imports required by this snippet (the sequel base-class import at the
# top of mas.py is omitted from this listing).
import copy

import jax
import jax.numpy as jnp


class MAS(JaxRegularizationBaseAlgorithm):
    """Memory Aware Synapses (MAS) algorithm class. Inherits from BaseAlgorithm.

    The equivalent PyTorch implementation is [`MAS in PyTorch`][sequel.algos.pytorch.mas.MAS].

    References:
        [1] Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T. Memory Aware Synapses: Learning
            What (not) to Forget. in Computer Vision - ECCV 2018.
    """

    def __init__(self, mas_lambda: float = 1.0, *args, **kwargs):
        super().__init__(regularization_coefficient=mas_lambda, *args, **kwargs)
        # Running accumulator of absolute output-gradients for the current task,
        # initialized to zeros with the same pytree structure as the parameters.
        self.w = jax.tree_map(lambda x: 0 * x, self.state.params)

    def __repr__(self) -> str:
        return f"MAS(mas_lambda={self.regularization_coefficient})"

    def calculate_parameter_importance(self):
        # On the first task there is no previous importance; start from zeros.
        if self.task_counter == 1:
            importance = jax.tree_map(lambda x: 0 * x, self.state.params)
        else:
            importance = self.importance
        # Fold the accumulator of the task just finished into the importance,
        # then reset the accumulator for the next task.
        importance = jax.tree_map(lambda i, w: i + w, importance, self.w)
        self.w = jax.tree_map(lambda x: 0 * x, self.state.params)
        return importance

    def on_before_training_step(self, *args, **kwargs):
        # Snapshot the parameters before the optimizer update.
        self.old_params = copy.deepcopy(self.state.params)

    def on_after_training_step(self, *args, **kwargs):
        @jax.jit
        def secondary_loss(params, x, t, training=True):
            # Surrogate objective: squared L2 norm of the model outputs.
            logits = self.apply_fn(params, x, t, training=training)
            loss = jnp.mean(jnp.square(logits))
            return loss, logits

        grad_fn = jax.value_and_grad(secondary_loss, has_aux=True, allow_int=True)
        _, grads = grad_fn(self.state.params, self.x, self.t, self.is_training)
        # Accumulate the batch-normalized absolute gradients as importance weights.
        self.w = jax.tree_map(lambda w, g: w + jnp.abs(g) / len(self.y), self.w, grads)

    def on_after_training_task(self, *args, **kwargs):
        # Anchor parameters and consolidated importance used by the quadratic
        # regularization penalty of JaxRegularizationBaseAlgorithm.
        self.old_params = copy.deepcopy(self.state.params)
        self.importance = self.calculate_parameter_importance()
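
For orientation, here is a minimal standalone sketch of the same importance computation, outside the sequel API; the toy `apply_fn`, parameter names, and shapes are assumptions chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for self.apply_fn: a single linear layer.
def apply_fn(params, x):
    return x @ params["w"] + params["b"]

# Same surrogate objective as secondary_loss above:
# squared L2 norm (here, mean of squares) of the outputs.
def output_norm(params, x):
    return jnp.mean(jnp.square(apply_fn(params, x)))

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 4))  # batch of 8 four-dimensional inputs
params = {"w": jax.random.normal(key_w, (4, 3)), "b": jnp.zeros(3)}

grads = jax.grad(output_norm)(params, x)
# One batch's contribution to the importance accumulator self.w:
# absolute gradients, normalized by the batch size.
w = jax.tree_util.tree_map(lambda g: jnp.abs(g) / x.shape[0], grads)
print(jax.tree_util.tree_map(jnp.shape, w))  # same pytree structure as params
```

Summed over the batches of a task and folded into `self.importance` by `calculate_parameter_importance`, this is the quantity that weights the quadratic penalty on later tasks.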