DER

Bases: JaxBaseAlgorithm

Dark Experience Replay algorithm implemented in JAX.

The equivalent PyTorch implementation is DER in PyTorch (sequel.algos.pytorch.der.DER).

References

[1] Buzzega, P., Boschini, M., Porrello, A., Abati, D. & Calderara, S. Dark experience for general continual learning: a strong, simple baseline. In Advances in Neural Information Processing Systems, 2020.
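
Concretely, DER augments the cross-entropy on the current batch with a logit-matching term on replayed examples, and DER++ adds a further cross-entropy term on a second replayed batch. In the notation of [1], with replay buffer $\mathcal{M}$ and pre-softmax outputs $f_\theta(\cdot)$:

$$
\mathcal{L}_{\mathrm{DER}} = \mathcal{L}_{CE}\big(f_\theta(x), y\big) + \alpha \, \mathbb{E}_{(x', z') \sim \mathcal{M}} \big[\, \lVert z' - f_\theta(x') \rVert_2^2 \,\big]
$$

$$
\mathcal{L}_{\mathrm{DER++}} = \mathcal{L}_{\mathrm{DER}} + \beta \, \mathbb{E}_{(x'', y'') \sim \mathcal{M}} \big[\, \mathcal{L}_{CE}\big(f_\theta(x''), y''\big) \,\big]
$$

where $z'$ are the logits that were stored alongside $x'$ when it entered the buffer. These two objectives correspond to der_loss and derpp_loss in the source below.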

Source code in sequel/algos/jax/der.py
class DER(JaxBaseAlgorithm):
    """Dark Experience Replay algorithm implemented in JAX.

    The equivalent PyTorch implementation is [`DER in PyTorch`][sequel.algos.pytorch.der.DER].

    References:
        [1] Buzzega, P., Boschini, M., Porrello, A., Abati, D. & Calderara, S. Dark experience for general continual
            learning: a strong, simple baseline. in Advances in neural information processing systems 2020.
    """

    def __init__(self, memory_size: int, alpha: float, beta: Optional[float] = None, *args, **kwargs):
        """Inits the DER class. Implements the Dark Experience Replay algorithm.

        Args:
            memory_size (int): The size of the memory.
            alpha (float): The regularization coefficient for the DER objective.
            beta (Optional[float], optional): The regularization coefficient for the DER++ objective. If set to None or
                zero, the algorithm corresponds to DER. Defaults to None.
        """

        super().__init__(*args, **kwargs)
        self.buffer = Buffer(memory_size=memory_size, return_logits=True)
        self.memory_size = memory_size
        self.alpha = alpha

        # Beta is used for DER++
        self.beta = beta

    def __repr__(self) -> str:
        if self.beta is None:
            return f"DER(memory_size={self.memory_size}, alpha={self.alpha})"
        else:
            return f"DER++(memory_size={self.memory_size}, alpha={self.alpha}, beta={self.beta})"

    @partial(jax.jit, static_argnums=(0,))
    def der_loss(self, params, x, y, t):
        # TODO: add task id support
        dropout_train_key = jax.random.fold_in(key=self.dropout_key, data=self.state.step)
        logits = self.apply_fn(params, x=x, training=self.is_training, rngs={"dropout": dropout_train_key})
        loss = cross_entropy_loss(logits=logits, labels=y)
        # DER loss: match the logits stored in the buffer with the model's
        # current outputs on the corresponding buffered inputs.
        dropout_key = jax.random.fold_in(key=dropout_train_key, data=self.state.step)
        mem_y_hat = self.apply_fn(params, x=self.mem_x, training=self.is_training, rngs={"dropout": dropout_key})
        der_loss = jnp.mean((self.mem_logits - mem_y_hat) ** 2)

        loss += self.alpha * der_loss
        return loss, logits

    @partial(jax.jit, static_argnums=(0,))
    def derpp_loss(self, params, x, y, t):
        # TODO: add task id support
        dropout_key = jax.random.fold_in(key=self.dropout_key, data=self.state.step)
        logits = self.apply_fn(params, x=x, t=t, training=self.is_training, rngs={"dropout": dropout_key})
        loss = cross_entropy_loss(logits=logits, labels=y)
        # DER loss: match buffered logits with the model's outputs on the
        # buffered inputs from the first replayed batch.
        dropout_key = jax.random.fold_in(key=dropout_key, data=self.state.step)
        mem_y_hat = self.apply_fn(params, x=self.mem_x, training=self.is_training, rngs={"dropout": dropout_key})
        der_loss = jnp.mean((self.mem_logits - mem_y_hat) ** 2)

        # DER++ loss: cross-entropy on a second, independently replayed batch.
        dropout_key = jax.random.fold_in(key=dropout_key, data=self.state.step)
        mem_y_hat2 = self.apply_fn(params, x=self.mem_x2, training=self.is_training, rngs={"dropout": dropout_key})
        derpp_loss = cross_entropy_loss(logits=mem_y_hat2, labels=self.mem_y2)

        loss += self.alpha * der_loss + self.beta * derpp_loss
        return loss, logits

    @partial(jax.jit, static_argnums=(0, 5))
    def custom_training_step(self, state: TrainState, x, y, t, loss_fn):
        """Train for a single step."""
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True, allow_int=True)
        (loss, logits), grads = grad_fn(state.params, x=x, y=y, t=t)
        state = state.apply_gradients(grads=grads)
        return dict(state=state, logits=logits, loss=loss, grads=grads)

    def training_step(self, *args, **kwargs):
        # On the first task the buffer is still empty, so fall back to the plain training step.
        if self.task_counter == 1:
            self.batch_outputs = self.base_training_step(self.state, self.x, self.y, self.t)
        else:
            # Sample a replayed batch (inputs, labels, task ids, stored logits) for the DER term.
            x, y, t, logits = self.buffer.sample_from_buffer(batch_size=self.benchmark.batch_size)
            self.mem_x, self.mem_y, self.mem_t, self.mem_logits = x, y, t, logits
            if self.beta is None:
                self.batch_outputs = self.custom_training_step(self.state, self.x, self.y, self.t, self.der_loss)
            else:
                # DER++ draws a second, independent replayed batch for the extra cross-entropy term.
                x, y, t, _ = self.buffer.sample_from_buffer(batch_size=self.benchmark.batch_size)
                self.mem_x2, self.mem_y2, self.mem_t2 = x, y, t
                self.batch_outputs = self.custom_training_step(self.state, self.x, self.y, self.t, self.derpp_loss)
        self.register_batch_outputs(self.batch_outputs)

    def on_after_training_step(self, *args, **kwargs):
        # Store the current batch together with its freshly computed logits for future replay.
        self.buffer.add_data(self.x, self.y, self.t, self.y_hat)
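
The class only relies on the Buffer exposing add_data(x, y, t, logits) and sample_from_buffer(batch_size), with return_logits=True so that stored logits come back alongside each replayed batch. The sketch below is a hypothetical stand-in for that interface, using reservoir sampling as in [1]; the actual sequel Buffer implementation may differ.

import numpy as np

class ReplayBuffer:
    """Hypothetical stand-in for the Buffer interface used by DER above."""

    def __init__(self, memory_size: int):
        self.memory_size = memory_size
        self.data = []       # one (x, y, t, logits) tuple per stored example
        self.num_seen = 0    # number of examples offered to the buffer so far

    def add_data(self, x, y, t, logits):
        # Reservoir sampling: each example seen so far is kept with equal probability.
        for i in range(len(x)):
            if len(self.data) < self.memory_size:
                self.data.append((x[i], y[i], t[i], logits[i]))
            else:
                j = np.random.randint(0, self.num_seen + 1)
                if j < self.memory_size:
                    self.data[j] = (x[i], y[i], t[i], logits[i])
            self.num_seen += 1

    def sample_from_buffer(self, batch_size: int):
        # Uniform sampling (with replacement) over the stored examples.
        idx = np.random.choice(len(self.data), size=batch_size, replace=True)
        xs, ys, ts, zs = zip(*(self.data[i] for i in idx))
        return np.stack(xs), np.stack(ys), np.stack(ts), np.stack(zs)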

__init__(memory_size, alpha, beta=None, *args, **kwargs)

Inits the DER class. Implements the Dark Experience Replay algorithm.

Parameters:

    memory_size (int): The size of the memory. Required.

    alpha (float): The regularization coefficient for the DER objective. Required.

    beta (Optional[float]): The regularization coefficient for the DER++ objective. If set to None or zero, the algorithm corresponds to DER. Defaults to None.

custom_training_step(state, x, y, t, loss_fn)

Train for a single step.
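
The step follows the standard Flax pattern: jax.value_and_grad with has_aux=True returns ((loss, logits), grads), and TrainState.apply_gradients applies the optimizer update. A minimal, self-contained illustration of that pattern (the toy linear model and data are placeholders, not part of sequel):

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

def apply_fn(params, x):
    # Toy "model": a single linear layer producing 2-class logits.
    return x @ params["w"]

state = TrainState.create(apply_fn=apply_fn, params={"w": jnp.ones((3, 2))}, tx=optax.sgd(1e-2))
x = jnp.ones((4, 3))
y = jnp.array([0, 1, 0, 1])

def loss_fn(params, x, y):
    logits = apply_fn(params, x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    return loss, logits  # (value, aux) pair expected when has_aux=True

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params, x, y)
state = state.apply_gradients(grads=grads)  # one optimizer step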
