Joint
JointTraining
Bases: JaxBaseAlgorithm
The JointTraining algorithm. It is a variant of multi-task learning in which the model is trained on an increasing number of samples: during the t-th task, the model sees the samples of all previous tasks as well as those of the current task.
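The growth in training-set size can be illustrated with a minimal pure-Python sketch. It assumes tasks are numbered from 1 and that each task contributes a fixed, hypothetical number of samples. It only counts samples and does not touch SequeL's data pipeline.

```python
from itertools import accumulate
from typing import List


def joint_training_sizes(task_sizes: List[int]) -> List[int]:
    """Training-set size during each task under joint training.

    task_sizes[i] is the (hypothetical) number of samples in task i + 1.
    During task t the model trains on tasks 1..t, so each entry is a running sum.
    """
    return list(accumulate(task_sizes))


sizes = [100, 200, 150]  # hypothetical per-task sample counts
print(joint_training_sizes(sizes))  # [100, 300, 450]
```

Because the training set is a running sum, the cost of each task grows with the number of tasks seen so far.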
Inherits from JaxBaseAlgorithm. Only the prepare_train_loader method is overridden.
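The "override a single method" design can be sketched with toy classes. ToyBaseAlgorithm, ToyJointTraining, the task_data dictionary layout, and the fit method are all hypothetical stand-ins, not SequeL's actual classes. The sketch assumes the base class's loader returns only the current task's samples, and it represents a "loader" as a plain list.

```python
from typing import Dict, List, Optional


class ToyBaseAlgorithm:
    """Hypothetical stand-in for a base algorithm (not SequeL's real code)."""

    def __init__(self, task_data: Dict[int, List[int]]):
        # Assumed layout: 1-based task id -> that task's samples.
        self.task_data = task_data

    def prepare_train_loader(self, task_id: int, batch_size: Optional[int] = None) -> List[int]:
        # Default behaviour (assumed): train on the current task only.
        return list(self.task_data[task_id])

    def fit(self) -> Dict[int, int]:
        # Hypothetical training loop: record how many samples each task trains on.
        seen = {}
        for task_id in sorted(self.task_data):
            loader = self.prepare_train_loader(task_id)
            seen[task_id] = len(loader)
        return seen


class ToyJointTraining(ToyBaseAlgorithm):
    """Only prepare_train_loader is overridden; the training loop is inherited."""

    def prepare_train_loader(self, task_id: int, batch_size: Optional[int] = None) -> List[int]:
        samples: List[int] = []
        for t in range(1, task_id + 1):  # all previous tasks plus the current one
            samples.extend(self.task_data[t])
        return samples


data = {1: [0, 1], 2: [2, 3, 4], 3: [5]}
print(ToyBaseAlgorithm(data).fit())  # {1: 2, 2: 3, 3: 1}
print(ToyJointTraining(data).fit())  # {1: 2, 2: 5, 3: 6}
```

Keeping the change confined to the data-loading hook means the training step, evaluation, and callbacks of the base class can be reused unchanged.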
The equivalent PyTorch implementation is JointTraining in PyTorch.
Source code in sequel/algos/jax/joint.py
prepare_train_loader(task_id, batch_size=None)
Prepares the train_loader for joint training. The dataloader consists of all samples from the tasks up to and including task_id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task_id | int | The last task to be loaded. | required |
batch_size | Optional[int] | The dataloader batch size. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The JointTraining train dataloader. |
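The roles of task_id and batch_size can be illustrated with a toy analogue that returns a list of batches instead of a real DataLoader. The function name prepare_joint_batches, the DEFAULT_BATCH_SIZE fallback used when batch_size is None, the shuffling, and the 1-based task numbering are all assumptions for illustration. The source does not say what SequeL does when batch_size is None.

```python
import random
from typing import Dict, List, Optional, Sequence

DEFAULT_BATCH_SIZE = 4  # hypothetical fallback used when batch_size is None


def prepare_joint_batches(
    task_data: Dict[int, Sequence[int]],
    task_id: int,
    batch_size: Optional[int] = None,
    seed: int = 0,
) -> List[List[int]]:
    """Toy analogue of a joint-training train loader.

    Concatenates the samples of tasks 1..task_id (inclusive), shuffles them,
    and splits them into batches of batch_size (the last batch may be smaller).
    """
    if batch_size is None:
        batch_size = DEFAULT_BATCH_SIZE
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    samples = [s for t in range(1, task_id + 1) for s in task_data[t]]
    random.Random(seed).shuffle(samples)  # mix old and new tasks within batches
    return [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]


data = {1: list(range(0, 5)), 2: list(range(5, 10)), 3: list(range(10, 12))}
batches = prepare_joint_batches(data, task_id=2, batch_size=3)
print(len(batches))  # 4, since 10 samples split into batches of 3 gives ceil(10 / 3)
print(sorted(s for b in batches for s in b) == list(range(10)))  # True
```

Samples from task 3 never appear when task_id=2, matching the "up to and including task_id" semantics. Shuffling across tasks is a design choice of this sketch, intended so that each batch mixes earlier and current tasks.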
Source code in sequel/algos/jax/joint.py