DER

Bases: PytorchBaseAlgorithm

Dark Experience Replay algorithm implemented in PyTorch.

The equivalent JAX implementation is DER in JAX.

References

[1] Buzzega, P., Boschini, M., Porrello, A., Abati, D. & Calderara, S. Dark experience for general continual learning: a strong, simple baseline. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
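
As a sketch of the objective implemented below (notation ours, not taken from the source): the current-task loss is augmented with an MSE term that pulls the network's outputs on replayed inputs towards the logits recorded when those inputs were first seen, weighted by alpha. When beta is set (DER++), a second replayed batch additionally contributes the standard classification loss on its stored labels:

\mathcal{L} = \mathcal{L}_{\mathrm{task}}(x, y) + \alpha \, \lVert f_\theta(x') - z' \rVert_2^2 + \beta \, \mathcal{L}_{\mathrm{task}}(x'', y'')

where (x', z') and (x'', y'') are batches sampled from the replay buffer, z' are the stored logits, and the beta term is absent for plain DER.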

Source code in sequel/algos/pytorch/der.py
class DER(PytorchBaseAlgorithm):
    """Dark Experience Replay algorithm implemented in PyTorch.

    The equivalent JAX implementation is [`DER in JAX`][sequel.algos.jax.der.DER].

    References:
        [1] Buzzega, P., Boschini, M., Porrello, A., Abati, D. & Calderara, S. Dark experience for general continual
            learning: a strong, simple baseline. in Advances in neural information processing systems 2020.
    """

    def __init__(self, memory_size: int, alpha: float, beta: Optional[float] = None, *args, **kwargs):
        """Inits the DER algorithm.

        Args:
            memory_size (int): the size of the replay buffer.
            alpha (float): the weight of the MSE loss on the buffered logits.
            beta (Optional[float]): the weight of the extra replay loss on buffered labels.
                If given, the algorithm corresponds to DER++. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        self.buffer = Buffer(memory_size=memory_size, return_logits=True)
        self.memory_size = memory_size
        self.alpha = alpha

        # Beta is used for DER++
        self.beta = beta

    def on_before_backward(self, *args, **kwargs):
        if len(self.buffer) > 0:
            # DER: distill towards the logits stored alongside past examples.
            x, y, t, logits = self.buffer.sample_from_buffer(batch_size=self.benchmark.batch_size)
            self.mem_x, self.mem_y, self.mem_t, self.mem_logits = x, y, t, logits
            self.mem_y_hat = self.backbone(self.mem_x, self.mem_t)
            loss = F.mse_loss(self.mem_y_hat, self.mem_logits)
            self.loss += self.alpha * loss

            if self.beta is not None:
                # DER++: replay a second batch and add the standard classification loss.
                x, y, t, _ = self.buffer.sample_from_buffer(batch_size=self.benchmark.batch_size)
                self.mem_x2, self.mem_y2, self.mem_t2 = x.to(self.device), y.to(self.device), t.to(self.device)
                self.mem_y_hat2 = self.backbone(self.mem_x2, self.mem_t2)
                self.loss += self.beta * self.compute_loss(self.mem_y_hat2, self.mem_y2, self.mem_t2)

    def on_before_optimizer_step(self, *args, **kwargs):
        if self.task_counter > 1:
            return super().on_before_optimizer_step(*args, **kwargs)

    def on_after_training_step(self, *args, **kwargs):
        # Store the current batch together with its logits for future replay.
        self.buffer.add_data(self.x, self.y, self.t, self.y_hat.data)

    def __repr__(self) -> str:
        if self.beta is None:
            return f"DER(memory_size={self.memory_size}, alpha={self.alpha})"
        else:
            return f"DER++(memory_size={self.memory_size}, alpha={self.alpha}, beta={self.beta})"