LFL
Bases: JaxBaseAlgorithm
Less-Forgetting Learning (LFL) [1], implemented in JAX.
The equivalent PyTorch implementation is LFL in PyTorch.
References
[1] Jung, H., Ju, J., Jung, M. & Kim, J. Less-forgetful learning for domain expansion in deep neural networks. Proceedings of the AAAI Conference on Artificial Intelligence 32 (2018).
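For orientation, the objective described in [1] can be summarized as follows, as we read the paper. Here F(x; θ) is the full network, f(x; θ) the output of its last hidden layer (the features fed to the classifier), and θ_old the frozen parameters learned on the previous domain. [1] weights each term with its own coefficient; folding them into a single λ that plays the role of `lfl_lambda` is our simplification, and how this implementation reduces the distance over batch and feature dimensions is not shown here.

```latex
\mathcal{L}(\theta) =
  \mathcal{L}_{\mathrm{CE}}\big(y,\, F(x; \theta)\big)
  + \lambda \,\big\lVert f(x; \theta) - f(x; \theta_{\mathrm{old}}) \big\rVert_2^2
```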
Source code in sequel/algos/jax/lfl.py
__init__(lfl_lambda, *args, **kwargs)
Initializes the LFL class.
Parameters:
Name | Type | Description | Default
---|---|---|---
lfl_lambda | float | The regularization coefficient. | required
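To illustrate what `lfl_lambda` trades off, here is a minimal sketch of an LFL-style loss in JAX. It is not this class's code: the function name `lfl_loss`, its arguments, and the use of the mean squared difference between current and previous-model features are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def lfl_loss(logits, labels, feats_new, feats_old, lfl_lambda):
    """Hypothetical LFL-style loss: cross-entropy plus a feature-drift penalty.

    logits:     (batch, num_classes) outputs of the current model
    labels:     (batch,) integer class labels
    feats_new:  (batch, dim) last-hidden-layer features of the current model
    feats_old:  (batch, dim) the same features from the frozen previous model
    lfl_lambda: weight of the drift penalty
    """
    # Mean cross-entropy computed from the log-softmax of the logits.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    ce = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))
    # Treat the old features as constants if this loss is differentiated
    # with respect to its inputs.
    feats_old = jax.lax.stop_gradient(feats_old)
    # Assumed distance: mean squared difference over batch and feature dims.
    drift = jnp.mean((feats_new - feats_old) ** 2)
    return ce + lfl_lambda * drift


logits = jnp.array([[2.0, 0.5], [0.1, 1.5]])
labels = jnp.array([0, 1])
new = jnp.array([[1.0, 2.0], [0.0, 1.0]])
old = jnp.array([[1.0, 1.0], [0.0, 0.0]])
print(lfl_loss(logits, labels, new, old, 0.0))  # plain cross-entropy
print(lfl_loss(logits, labels, new, old, 1.0))  # cross-entropy + 1.0 * drift
```

With these inputs the drift term is 0.5, so each unit of `lfl_lambda` adds 0.5 to the loss; setting `lfl_lambda` to 0 recovers ordinary fine-tuning.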
lfl_training_step(state, x, y, t)
Train for a single step.
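A minimal sketch of what one LFL training step might look like in plain JAX, under stated assumptions. Everything here is hypothetical and not taken from this class: the one-layer encoder and linear head, the `params` dictionary layout, the helper names (`features`, `logits_fn`, `loss_fn`, `lfl_step`), plain SGD in place of whatever optimizer `state` carries, and the mean squared feature distance. The task id `t` has no counterpart here.

```python
import jax
import jax.numpy as jnp


def features(params, x):
    # Hypothetical encoder standing in for the backbone's feature extractor.
    return jnp.tanh(x @ params["enc"])


def logits_fn(params, x):
    # Hypothetical linear classification head on top of the features.
    return features(params, x) @ params["head"]


def loss_fn(params, old_params, x, y, lfl_lambda):
    log_probs = jax.nn.log_softmax(logits_fn(params, x), axis=-1)
    ce = -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=-1))
    # Features of the frozen previous model act as fixed targets. The
    # stop_gradient is redundant here (only `params` is differentiated)
    # but makes the intent explicit.
    old_feats = jax.lax.stop_gradient(features(old_params, x))
    drift = jnp.mean((features(params, x) - old_feats) ** 2)
    return ce + lfl_lambda * drift


@jax.jit
def lfl_step(params, old_params, x, y, lfl_lambda, lr):
    # Differentiate w.r.t. the current params only; old_params stay fixed.
    loss, grads = jax.value_and_grad(loss_fn)(params, old_params, x, y, lfl_lambda)
    # Plain SGD update applied leaf by leaf over the parameter pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key_enc, key_head, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
old_params = {
    "enc": jax.random.normal(key_enc, (4, 8)) * 0.5,
    "head": jax.random.normal(key_head, (8, 3)) * 0.5,
}
params = old_params  # start the new task from the previous task's weights
x = jax.random.normal(key_x, (16, 4))
y = jnp.arange(16) % 3
for _ in range(50):
    params, loss = lfl_step(params, old_params, x, y, 1.0, 0.1)
print(float(loss))
```

Because `old_params` is passed separately and never updated, the penalty keeps pulling the current features back toward those of the previous model while the cross-entropy term adapts the network to the new data.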