AGEM
Bases: JaxBaseAlgorithm
A-GEM: Averaged Gradient Episodic Memory [1]. Maintains a memory of samples from past tasks. Before each update, the gradient of the current batch is projected so that it does not conflict with the average gradient computed over a batch drawn from this memory.
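Concretely, writing $g$ for the current-batch gradient and $g_{\text{ref}}$ for the average gradient over the sampled memory batch, the projected gradient used by A-GEM [1] is:

$$
\tilde{g} =
\begin{cases}
g - \dfrac{g^\top g_{\text{ref}}}{g_{\text{ref}}^\top g_{\text{ref}}}\, g_{\text{ref}} & \text{if } g^\top g_{\text{ref}} < 0, \\
g & \text{otherwise.}
\end{cases}
$$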
The equivalent PyTorch implementation is A-GEM in PyTorch.
References
[1] Chaudhry, A., Ranzato, M., Rohrbach, M., & Elhoseiny, M. Efficient Lifelong Learning with A-GEM. In 7th International Conference on Learning Representations (ICLR 2019), New Orleans, LA, USA, May 6–9, 2019.
Source code in sequel/algos/jax/agem.py
__init__(per_task_memory_samples, memory_batch_size, memory_group_by, *args, **kwargs)
Initializes the A-GEM algorithm class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `per_task_memory_samples` | `int` | Number of exemplars per experience in the memory. | *required* |
| `memory_batch_size` | `int` | The batch size of the memory samples used to modify the gradient update. | *required* |
| `memory_group_by` | `Literal['task', 'class']` | Determines the selection process of samples for the memory. | *required* |
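As a rough usage sketch (the import path is inferred from the source file noted below; the concrete values are illustrative, and any extra `*args`/`**kwargs` required by the base class are omitted):

```python
# Illustrative instantiation; model/benchmark/optimizer wiring depends
# on the rest of the sequel library and is not shown here.
from sequel.algos.jax.agem import AGEM

algo = AGEM(
    per_task_memory_samples=100,  # exemplars stored per experience
    memory_batch_size=64,         # memory samples drawn per gradient update
    memory_group_by="task",       # populate the memory per task (or "class")
)
```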
Source code in sequel/algos/jax/agem.py
agem_training_step(state, x, y, t, mem_x, mem_y, mem_t, step)
The A-GEM training step that uses the memory samples to modify the gradient.
Note
This implementation is suboptimal since it computes `mem_norm` and applies the `tree_map` projection even when they are not needed (the case where `dotg` is non-negative). It is written this way so that the whole gradient update can be jitted as a single function.
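A minimal self-contained sketch of such a single-jit step, assuming a toy `loss_fn` and an `optax` optimizer (both are stand-ins, not the library's actual API). The projection is expressed with `jnp.minimum` instead of Python branching, so `dotg`, `mem_norm`, and the `tree_map` are always computed and the traced graph stays static:

```python
import jax
import jax.numpy as jnp
import optax

optimizer = optax.sgd(1e-2)  # stand-in optimizer

def loss_fn(params, x, y):
    # Toy linear classifier; the real model lives elsewhere in the library.
    logits = x @ params["w"] + params["b"]
    labels = jax.nn.one_hot(y, logits.shape[-1])
    return optax.softmax_cross_entropy(logits, labels).mean()

@jax.jit
def agem_training_step(params, opt_state, x, y, mem_x, mem_y):
    grads = jax.grad(loss_fn)(params, x, y)
    mem_grads = jax.grad(loss_fn)(params, mem_x, mem_y)
    # dotg and mem_norm are computed unconditionally: jit traces a single
    # static graph, so there is no data-dependent way to skip them.
    leaves = jax.tree_util.tree_leaves(grads)
    mem_leaves = jax.tree_util.tree_leaves(mem_grads)
    dotg = sum(jnp.vdot(g, m) for g, m in zip(leaves, mem_leaves))
    mem_norm = sum(jnp.vdot(m, m) for m in mem_leaves)
    # scale is 0 when dotg >= 0, so tree_map leaves the gradient untouched.
    scale = jnp.minimum(dotg, 0.0) / mem_norm
    grads = jax.tree_util.tree_map(lambda g, m: g - scale * m, grads, mem_grads)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```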
Source code in sequel/algos/jax/agem.py