Skip to content

How to get started

Examples

The API for both JAX and PyTorch is the same. In the following example, we only need to change pytorch to jax and define the optimizer in a framework-specific way.

from sequel import benchmarks, backbones, algos, loggers, callbacks
import torch

# define the Continual Learning benchmark.
benchmark = benchmarks.PermutedMNIST(num_tasks=3, batch_size=512)

# define the backbone model, i.e., the neural network, and the optimizer
backbone = backbones.pytorch.MLP(width=256, n_hidden_layers=2, num_classes=10)
optimizer = torch.optim.SGD(backbone.parameters(), lr=0.1)

# initialize the algorithm
algo = algos.pytorch.EWC(
    backbone=backbone,
    optimizer=optimizer,
    benchmark=benchmark,
    callbacks=[
        callbacks.PyTorchMetricCallback(),
        callbacks.TqdmCallback(),
    ],
    loggers=[loggers.WandbLogger()],
    # algorithm-specific arguments
    ewc_lambda=1,
)

# start training
algo.fit(epochs=1)
from sequel import benchmarks, backbones, algos, loggers, callbacks
import optax as tx

# define the Continual Learning benchmark.
benchmark = benchmarks.PermutedMNIST(num_tasks=3, batch_size=512)

# define the backbone model, i.e., the neural network, and the optimizer
backbone = backbones.jax.MLP(width=256, n_hidden_layers=2, num_classes=10)
optimizer = tx.inject_hyperparams(tx.sgd)(learning_rate=0.1)

# initialize the algorithm
algo = algos.jax.EWC(
    backbone=backbone,
    optimizer=optimizer,
    benchmark=benchmark,
    callbacks=[
        callbacks.JaxMetricCallback(),
        callbacks.TqdmCallback(),
    ],
    loggers=[loggers.WandbLogger()],
    # algorithm-specific arguments
    ewc_lambda=1,
)

# start training
algo.fit(epochs=1)