
PytorchMetricCallback

BackwardTranferMetric

Bases: ForgettingMetric

How much does learning the current experience improve performance on previous experiences? Computed as the negation of ForgettingMetric.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
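class BackwardTranferMetric(ForgettingMetric):
    """How much does learning the current experience improve performance on previous experiences?"""

    def compute(self):
        # Backward transfer is reported as negative forgetting,
        # so a positive value means earlier experiences improved.
        return -super().compute()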
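To make the sign convention concrete, here is a dependency-free sketch. It assumes, hypothetically, that ForgettingMetric measures a task's forgetting as its best earlier accuracy minus its latest accuracy (the max-based definition common in continual learning); the sequel implementation is not shown on this page and may differ. `forgetting` and `backward_transfer` are illustrative helpers, not sequel APIs.

```python
from typing import List


def forgetting(history: List[float]) -> float:
    """Hypothetical per-task forgetting: best earlier accuracy minus the latest accuracy.

    `history[k]` is the task's accuracy after the k-th experience since the task was learned.
    """
    if len(history) < 2:
        return 0.0  # nothing can have been forgotten yet
    return max(history[:-1]) - history[-1]


def backward_transfer(history: List[float]) -> float:
    # Mirrors BackwardTranferMetric.compute: simply the negation of forgetting.
    return -forgetting(history)


acc_task0 = [0.90, 0.85, 0.80]  # accuracy on task 0 after experiences 0, 1 and 2
print(round(forgetting(acc_task0), 2))         # 0.1
print(round(backward_transfer(acc_task0), 2))  # -0.1
```

A positive backward transfer (for example `backward_transfer([0.5, 0.7])`, about 0.2) means later experiences improved accuracy on the earlier task; a negative value means that task was partly forgotten.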

CrossEntropyLossMetric

Bases: MeanMetric

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
import torch
import torch.nn.functional as F
from torchmetrics import MeanMetric


class CrossEntropyLossMetric(MeanMetric):
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Wraps CrossEntropy into a torchmetrics MeanMetric.

        Args:
            preds (torch.Tensor): the logits of the current batch.
            target (torch.Tensor): the targets of the current batch.

        Returns:
            torch.Tensor: the computed cross-entropy loss.
        """
        # Batch-mean cross-entropy (default reduction="mean"), fed to MeanMetric as a single value.
        value = F.cross_entropy(input=preds, target=target)
        return super().update(value)

update(preds, target)

Wraps CrossEntropy into a torchmetrics MeanMetric.

Parameters:

Name    Type          Description                        Default
preds   torch.Tensor  the logits of the current batch.   required
target  torch.Tensor  the targets of the current batch.  required

Returns:

Type          Description
torch.Tensor  the computed cross-entropy loss.

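The following dependency-free sketch shows what CrossEntropyLossMetric accumulates, assuming F.cross_entropy's default `reduction="mean"` and that MeanMetric gives each scalar passed to `update()` its default weight of 1. `cross_entropy` and `RunningMean` are hypothetical stand-ins, not torch or torchmetrics APIs.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Batch-mean cross-entropy from raw logits, like reduction="mean"."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # shift by the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]  # equals -log(softmax(row)[target])
    return total / len(targets)


class RunningMean:
    """Hypothetical stand-in for MeanMetric: every update() adds one value with weight 1."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count


loss = RunningMean()
loss.update(cross_entropy([[0.0, 0.0]], [0]))                    # 1 sample, 2 classes: log(2)
loss.update(cross_entropy([[0.0, 0.0, 0.0]] * 4, [0, 1, 2, 0]))  # 4 samples, 3 classes: log(3)
print(round(loss.compute(), 4))  # 0.8959
```

Because each batch contributes its mean loss with equal weight, the one-sample batch above counts as much as the four-sample batch; a sample-weighted average over the same five samples would instead be (log 2 + 4 log 3) / 5 ≈ 1.0175.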

PytorchMetricCallback

Bases: MetricCallback

Base class for MetricCallback implementations that use PyTorch.

Inherits from MetricCallback.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
import torch
from torch import nn
from torchmetrics import MeanMetric, MetricCollection


class PytorchMetricCallback(MetricCallback):
    """Base class for MetricCallback implementations that use PyTorch.

    Inherits from `MetricCallback`.
    """

    def __init__(self, metrics: MetricCollection, logging_freq=10):
        super().__init__(metrics, logging_freq)

    def on_before_fit(self, algo: "BaseAlgorithm", *args, **kwargs):
        super().on_before_fit(algo, *args, **kwargs)
        # One forgetting tracker per task.
        self.forgetting = [ForgettingMetric() for _ in range(self.num_tasks)]

    def connect(self, algo, *args, **kwargs):
        # Store the device first: _reset_metrics moves every metric onto it.
        self.device = algo.device
        super().connect(algo, *args, **kwargs)

    def identify_seen_tasks(self, algo) -> torch.Tensor:
        # Assuming `algo.t` holds 1-based task ids, returns the sorted unique 0-based ids as a tensor.
        return torch.unique(algo.t - 1)

    def _reset_metrics(self, prefix):
        # Clone the metric collection once per task, suffixing each metric name with "/<task id>".
        self.metrics = nn.ModuleList(
            self.original_metrics.clone(postfix=f"/{self.get_task_id(i)}", prefix=prefix).to(self.device)
            for i in range(self.num_tasks)
        ).to(self.device)
        self.avg_loss = MeanMetric(prefix=prefix).to(self.device)
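A pure-Python sketch of the bookkeeping in `identify_seen_tasks` and `_reset_metrics`, assuming `algo.t` holds 1-based task ids and that MetricCollection.clone names each key as prefix + metric name + postfix. The `"val/"` prefix and the use of the raw 0-based index as the task label are hypothetical; the real `get_task_id` may format task ids differently.

```python
from typing import Dict, List


def identify_seen_tasks(task_ids: List[int]) -> List[int]:
    # Pure-Python analogue of torch.unique(algo.t - 1): shift to 0-based, de-duplicate, sort.
    return sorted({t - 1 for t in task_ids})


def per_task_metric_names(names: List[str], prefix: str, num_tasks: int) -> Dict[int, List[str]]:
    # Mirrors clone(prefix=prefix, postfix=f"/{task}") for each task: key = prefix + name + postfix.
    # Hypothetical: the task label is the 0-based index itself.
    return {i: [f"{prefix}{name}/{i}" for name in names] for i in range(num_tasks)}


print(identify_seen_tasks([2, 1, 2, 3]))  # [0, 1, 2]
print(per_task_metric_names(["acc", "loss"], "val/", 2))
# {0: ['val/acc/0', 'val/loss/0'], 1: ['val/acc/1', 'val/loss/1']}
```

Keeping one cloned collection per task lets each metric be logged separately per task, while `avg_loss` keeps a single running loss across all tasks.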

StandardMetricCallback

Bases: PytorchMetricCallback

Defines the standard Metric Callback used for classification.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
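import torchmetrics


class StandardMetricCallback(PytorchMetricCallback):
    """Defines the standard Metric Callback used for classification."""

    def __init__(self, logging_freq=1):
        # Accuracy and cross-entropy loss, cloned per task by _reset_metrics.
        # Note: torchmetrics.Accuracy() without a `task` argument requires torchmetrics < 0.11.
        metrics = torchmetrics.MetricCollection(
            {
                "acc": torchmetrics.Accuracy(),
                "loss": CrossEntropyLossMetric(),
            },
        )
        super().__init__(metrics, logging_freq)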
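The `"acc"` entry reports top-1 accuracy. As a rough, dependency-free sketch, assuming raw logits with integer class targets and torchmetrics' default micro averaging (correct predictions divided by total samples), it reduces to comparing each row's argmax with its target. This `accuracy` helper is illustrative, not the torchmetrics implementation.

```python
from typing import List


def accuracy(logits: List[List[float]], targets: List[int]) -> float:
    """Micro-averaged top-1 accuracy: fraction of samples whose argmax equals the target."""
    correct = 0
    for row, target in zip(logits, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == target)
    return correct / len(targets)


logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.8, 2.2], [1.0, 3.0, 0.0]]
targets = [0, 1, 2, 2]
print(accuracy(logits, targets))  # 0.75 (the last sample predicts class 1, not 2)
```

Because both entries live in one MetricCollection, a single `update(preds, target)` call feeds the same logits and targets to accuracy and to CrossEntropyLossMetric.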