Base
BaseBackbone
Bases: nn.Module
The PyTorch base class for neural networks.
Inherits from torch.nn.Module and the BaseCallback class, which provides callbacks for each stage of training, e.g., before and after training/validation steps, epochs, and tasks.
Source code in sequel/backbones/pytorch/base_backbone.py
__init__(multihead=False, classes_per_task=None, masking_value=-100000000000.0)
Initializes the BaseBackbone class. This class defines the PyTorch base class for neural networks, and all models should inherit from it. It inherits from torch.nn.Module and the BaseCallback class, which provides callbacks for each stage of training, e.g., before and after training/validation steps, epochs, and tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
multihead | bool | Set to True if the backbone is multi-headed. Defaults to False. | False |
classes_per_task | Optional[int] | The number of classes per task. Defaults to None. | None |
masking_value | float | The value that replaces the logits. Only used if multihead is set to True. Defaults to -10e10. | -100000000000.0 |
Note
Currently, the BaseBackbone only considers tasks with an equal number of classes.
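Since every task has the same number of classes, one plausible layout is that each task owns a contiguous block of `classes_per_task` output columns. The source does not spell this layout out, so the sketch below is an assumption. The helper `head_columns` is hypothetical (not part of the library), and task ids are assumed to be 0-based.

```python
def head_columns(task_id: int, classes_per_task: int) -> range:
    """Return the output columns assumed to belong to one task's head.

    Hypothetical helper, not part of the library. Assumes 0-based task ids
    and contiguous heads of equal size.
    """
    if task_id < 0 or classes_per_task <= 0:
        raise ValueError("task_id must be >= 0 and classes_per_task must be > 0")
    start = task_id * classes_per_task
    return range(start, start + classes_per_task)


# With 5 classes per task, task 2 would own columns 10 through 14.
columns = list(head_columns(2, 5))
```

Under this assumed layout, a backbone serving `T` tasks needs `T * classes_per_task` output units in its final layer.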
forward(x, task_ids)
Implements the forward function of the backbone. Models must override this method.
Example
    # Perform the forward pass.
    x = ...
    # Select the correct output head.
    if self.multihead:
        return self.select_output_head(x, task_ids)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The batch inputs. | required |
task_ids | torch.Tensor | The batch task ids. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The batch predictions. |
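To make the override pattern concrete, here is a minimal sketch of a model whose `forward` follows the two documented steps: compute the logits, then select the output head when `multihead` is set. `ToyBackbone` is a hypothetical stand-in that subclasses `torch.nn.Module` directly rather than `BaseBackbone`, so it runs without the library installed. The `in_features` and `num_tasks` arguments, the 0-based task ids, and the masking logic are all assumptions for illustration, not the library's implementation.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    def __init__(self, in_features, num_tasks, classes_per_task,
                 multihead=False, masking_value=-1e11):
        super().__init__()
        self.multihead = multihead
        self.classes_per_task = classes_per_task
        self.masking_value = masking_value
        # One shared linear layer; each task is assumed to own a contiguous
        # block of `classes_per_task` output columns.
        self.classifier = nn.Linear(in_features, num_tasks * classes_per_task)

    def select_output_head(self, x, task_ids):
        # Assumption: task ids are 0-based and heads have equal size.
        columns = torch.arange(x.shape[1], device=x.device)
        start = (task_ids * self.classes_per_task).unsqueeze(1)
        keep = (columns >= start) & (columns < start + self.classes_per_task)
        # Logits outside the sample's own head are replaced by masking_value.
        return x.masked_fill(~keep, self.masking_value)

    def forward(self, x, task_ids):
        # Perform the forward pass.
        x = self.classifier(x)
        # Select the correct output head.
        if self.multihead:
            return self.select_output_head(x, task_ids)
        return x


model = ToyBackbone(in_features=8, num_tasks=3, classes_per_task=2, multihead=True)
inputs = torch.randn(4, 8)
task_ids = torch.tensor([0, 1, 2, 1])
logits = model(inputs, task_ids)
predictions = logits.argmax(dim=1)
```

In this sketch, each prediction falls inside the block of columns belonging to the sample's task, because every other column holds the masking value.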
select_output_head(x, task_ids)
Utility function used when multihead=True. It replaces the original logits with a low value (masking_value) so that the corresponding classes receive almost zero probability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The original logits. | required |
task_ids | torch.Tensor | The task id for each sample in the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The manipulated logits. |
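The effect of a very low logit on the resulting probabilities can be checked with a small softmax computation. This is an illustrative sketch in plain Python, not code from the library. The logit values are made up, and `masking_value` uses the same magnitude as the documented default.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


masking_value = -1e11  # same magnitude as the documented default

# Two heads of two classes each; the first head has been masked out.
masked_logits = [masking_value, masking_value, 0.3, -1.2]
probabilities = softmax(masked_logits)
```

Here the masked entries underflow to zero in floating point, so all probability mass is shared among the classes of the remaining head.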