Base
BaseBackbone
Bases: nn.Module
Initializes the BaseBackbone class, which defines the JAX base class for neural networks. All models should inherit from this class. It inherits from flax.nn.Module and from the BaseCallback class, which provides callbacks for each stage of training, e.g., before and after training/validation steps, epochs, and tasks.
Attributes:

| Name | Type | Description |
|---|---|---|
| multihead | bool | Set to True if the backbone is multi-headed. Defaults to False. |
| classes_per_task | Optional[int] | The number of classes per task. Defaults to None. |
| masking_value | float | The value that replaces the logits of masked-out classes. Only used if multihead is set to True. Defaults to -10e10. |
Note
Currently, the BaseBackbone only considers tasks with an equal number of classes.
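To illustrate how `multihead`, `classes_per_task`, and `masking_value` can interact, the following is a minimal sketch in `jax.numpy`, not the library's implementation. The helper `mask_logits` is hypothetical, and it assumes that task ids are 0-indexed and that task `t` owns the contiguous output units `[t * classes_per_task, (t + 1) * classes_per_task)`. That layout only works because every task has the same number of classes, as the note above requires.

```python
import jax.numpy as jnp


def mask_logits(logits, task_ids, classes_per_task, masking_value=-10e10):
    """Hypothetical helper: keep only the logits of each sample's own task head.

    Assumptions (not taken from the library): task ids are 0-indexed and each
    task owns a contiguous block of `classes_per_task` output units.
    """
    num_classes = logits.shape[-1]
    class_ids = jnp.arange(num_classes)  # shape (num_classes,)
    # First class index of each sample's task, shape (batch, 1).
    start = jnp.asarray(task_ids)[:, None] * classes_per_task
    # Broadcasts to (batch, num_classes): True inside the sample's task head.
    in_task = (class_ids >= start) & (class_ids < start + classes_per_task)
    # Logits outside the task head are replaced by the masking value.
    return jnp.where(in_task, logits, masking_value)


# Usage: 3 tasks with 2 classes each; sample 0 belongs to task 0, sample 1 to task 2.
logits = jnp.ones((2, 6))
masked = mask_logits(logits, jnp.array([0, 2]), classes_per_task=2)
predictions = jnp.argmax(masked, axis=-1)
```

Because the masked entries are pushed to a very large negative value rather than removed, the output shape stays fixed across tasks, which keeps the function friendly to `jax.jit`.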
Source code in sequel/backbones/jax/base_backbone.py