Welcome to the repo!
## Getting started

### Installation
```bash
conda create -n sequel python=3.10 -y
conda activate sequel
pip install -r requirements.txt
```
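If the installation succeeded, the package should be importable. Here is a minimal sanity check; it only assumes that the top-level `sequel` namespaces used throughout these docs are available:

```python
# Post-install sanity check: import the namespaces used in the examples below.
from sequel import algos, backbones, benchmarks, callbacks, loggers

print("sequel imported successfully")
```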
### Launching the docs
```bash
# navigate to the root of the repo, i.e.,
# where the file `mkdocs.yml` resides
mkdocs serve
# the docs are served at http://127.0.0.1:8000/
```

Alternatively, build the static site with `mkdocs build` and serve the generated `site/` directory:

```bash
# navigate to the `site/` directory created by `mkdocs build`;
# it contains the file `index.html`
# serve the website
python3 -m http.server
# the website is hosted at http://127.0.0.1:8000/
```
## Project layout
```
mkdocs.yml        # The configuration file for the documentation.
docs/
    index.md      # The documentation homepage.
    ...           # Other markdown pages, images and other files.
sequel/           # The source code lies here.
    algos/        # The Continual Learning algorithms, e.g. EWC.
    backbones/    # The neural network classes.
    benchmarks/   # The benchmarks, such as SplitMNIST.
    utils/        # Utility functions such as logging, callbacks, etc.
```
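For orientation, the directories above map onto the namespaces used in the examples below. A small sketch; the attribute names are exactly the ones appearing in those examples, nothing further is assumed:

```python
# How the source layout maps to the public namespaces used in the examples below.
from sequel import algos, backbones, benchmarks, callbacks, loggers

algos.pytorch.EWC         # sequel/algos/      -> Continual Learning algorithms (per framework)
backbones.pytorch.MLP     # sequel/backbones/  -> neural network classes (per framework)
benchmarks.PermutedMNIST  # sequel/benchmarks/ -> Continual Learning benchmarks
callbacks.TqdmCallback    # utilities such as callbacks and loggers (see sequel/utils/)
loggers.WandbLogger
```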
## Examples
The API for both JAX and PyTorch is the same. In the following examples, we only need to change `pytorch` to `jax` and define the optimizer in a framework-specific way; a short framework-switching sketch follows the two examples.
**PyTorch**

```python
from sequel import benchmarks, backbones, algos, loggers, callbacks
import torch

# define the Continual Learning benchmark.
benchmark = benchmarks.PermutedMNIST(num_tasks=3, batch_size=512)

# define the backbone model, i.e., the neural network, and the optimizer
backbone = backbones.pytorch.MLP(width=256, n_hidden_layers=2, num_classes=10)
optimizer = torch.optim.SGD(backbone.parameters(), lr=0.1)

# initialize the algorithm
algo = algos.pytorch.EWC(
    backbone=backbone,
    optimizer=optimizer,
    benchmark=benchmark,
    callbacks=[
        callbacks.PyTorchMetricCallback(),
        callbacks.TqdmCallback(),
    ],
    loggers=[loggers.WandbLogger()],
    # algorithm-specific arguments
    ewc_lambda=1,
)

# start training
algo.fit(epochs=1)
```
**JAX**

```python
from sequel import benchmarks, backbones, algos, loggers, callbacks
import optax as tx

# define the Continual Learning benchmark.
benchmark = benchmarks.PermutedMNIST(num_tasks=3, batch_size=512)

# define the backbone model, i.e., the neural network, and the optimizer
backbone = backbones.jax.MLP(width=256, n_hidden_layers=2, num_classes=10)
optimizer = tx.inject_hyperparams(tx.sgd)(learning_rate=0.1)

# initialize the algorithm
algo = algos.jax.EWC(
    backbone=backbone,
    optimizer=optimizer,
    benchmark=benchmark,
    callbacks=[
        callbacks.JaxMetricCallback(),
        callbacks.TqdmCallback(),
    ],
    loggers=[loggers.WandbLogger()],
    # algorithm-specific arguments
    ewc_lambda=1,
)

# start training
algo.fit(epochs=1)
```
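Because only the framework-specific namespaces, the optimizer, and the metric callback differ, the two examples can be folded into one script that selects the framework at runtime. The following is a minimal sketch: the `getattr`-based switch is an illustration and assumes nothing beyond the classes and namespaces shown in the examples above.

```python
from sequel import algos, backbones, benchmarks, callbacks, loggers

framework = "pytorch"  # change to "jax" for the JAX version

# select the framework-specific namespaces; the rest of the script is identical
algos_ns = getattr(algos, framework)          # algos.pytorch or algos.jax
backbones_ns = getattr(backbones, framework)  # backbones.pytorch or backbones.jax

benchmark = benchmarks.PermutedMNIST(num_tasks=3, batch_size=512)
backbone = backbones_ns.MLP(width=256, n_hidden_layers=2, num_classes=10)

# the optimizer and the metric callback are the only framework-specific objects
if framework == "pytorch":
    import torch
    optimizer = torch.optim.SGD(backbone.parameters(), lr=0.1)
    metric_callback = callbacks.PyTorchMetricCallback()
else:
    import optax as tx
    optimizer = tx.inject_hyperparams(tx.sgd)(learning_rate=0.1)
    metric_callback = callbacks.JaxMetricCallback()

algo = algos_ns.EWC(
    backbone=backbone,
    optimizer=optimizer,
    benchmark=benchmark,
    callbacks=[metric_callback, callbacks.TqdmCallback()],
    loggers=[loggers.WandbLogger()],
    ewc_lambda=1,
)
algo.fit(epochs=1)
```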