MetricCallback

Bases: AlgoCallback

MetricCallback is the parent class for the PyTorch and Jax metric callbacks. Handles the computation of metrics during training, validation, and testing.

Source code in sequel/utils/callbacks/metrics/metric_callback.py
class MetricCallback(AlgoCallback):
    """MetricCallback is the parent clas for the PyTorch and Jax metric callbacks. Handles the computation of metrics
    during training, validation etc."""

    forgetting: List[Union[torchmetrics.Metric, JaxMetric]]

    def __init__(self, metrics, logging_freq=10):
        super().__init__()
        self.logging_freq = logging_freq
        self.original_metrics = metrics.clone()

    def on_before_fit(self, algo: "BaseAlgorithm", *args, **kwargs):
        self.num_tasks = algo.num_tasks
        return super().on_before_fit(algo, *args, **kwargs)

    def _reset_metrics(self, prefix):
        raise NotImplementedError

    def log(self, algo: "BaseAlgorithm", key, value):
        algo.log({key: value})

    def get_task_id(self, i):
        return f"task-{i}"

    def register_metric_callback_message(self, msg: dict, algo: "BaseAlgorithm"):
        # wandb quirk: it adds _timestamp and _runtime keys to the msg dict, so drop them
        msg = {k: v for k, v in msg.items() if not k.startswith("_")}
        msg = {k.split("/")[1]: safe_conversion(v) for k, v in msg.items()}
        setattr(algo, "metric_callback_msg", msg)
        return msg

    def identify_seen_tasks(self, algo: "BaseAlgorithm") -> List[int]:
        raise NotImplementedError

    def compute_mask(self, algo, task_id):
        return (algo.t - 1) == task_id

    # ------- STEPS -------
    def on_after_training_step(self, algo: "BaseAlgorithm"):
        self.avg_loss(algo.loss)
        tasks_seen = self.identify_seen_tasks(algo)
        for task_id in tasks_seen:
            mask = self.compute_mask(algo, task_id)
            self.metrics[task_id](algo.y_hat[mask], algo.y[mask])

        if (algo.batch_idx + 1) % self.logging_freq == 0:
            msg: dict = self.metrics[algo.task_counter - 1].compute()
            msg.update({"train/avg_loss": self.avg_loss.compute()})
            algo.log(copy.deepcopy(msg))
            msg = self.register_metric_callback_message(msg, algo)

    def on_after_val_step(self, algo: "BaseAlgorithm"):
        self.avg_loss(algo.loss)
        tasks_seen = self.identify_seen_tasks(algo)
        for task_id in tasks_seen:
            mask = self.compute_mask(algo, task_id)
            self.metrics[task_id](algo.y_hat[mask], algo.y[mask])

        if (algo.batch_idx + 1) % self.logging_freq == 0:
            msg: dict = self.metrics[algo.current_val_task - 1].compute()
            msg.update({"val/avg_loss": self.avg_loss.compute()})
            msg = self.register_metric_callback_message(msg, algo)

    # ------- EPOCHS - before -------
    def on_before_training_epoch(self, *args, **kwargs):
        self._reset_metrics(prefix="train/")

    def on_before_val_epoch(self, *args, **kwargs):
        self._reset_metrics(prefix="val/")

    def on_before_testing_epoch(self, *args, **kwargs):
        self._reset_metrics(prefix="test/")

    def on_before_validating_algorithm_on_all_tasks(self, algo: "BaseAlgorithm", *args, **kwargs):
        self.metric_results = []

    # ------- EPOCHS - after -------
    def on_after_val_epoch(self, algo: "BaseAlgorithm", *args, **kwargs):
        res = self.metrics[algo.current_val_task - 1].compute()
        self.task_counter_metrics = res
        self.metric_results.append(res)

        to_log = copy.deepcopy(res)
        to_log["epoch"] = algo.epoch_counter
        algo.log(to_log)

        key = list(filter(lambda x: "acc" in x, list(self.task_counter_metrics.keys())))[0]
        self.forgetting[algo.current_val_task - 1](self.task_counter_metrics[key])

    def on_after_validating_algorithm_on_all_tasks(self, algo: "BaseAlgorithm", *args, **kwargs):
        forgetting = {
            f"val/forgetting/task-{i+1}": safe_conversion(k.compute())
            for i, k in enumerate(self.forgetting)
            if i < algo.task_counter
        }
        algo.log(forgetting)

        # compute averages
        avg = {}
        for key in self.original_metrics:
            temp = [[v for k, v in m.items() if key in k][0] for m in self.metric_results]
            avg[key] = safe_conversion(sum(temp)) / len(temp)

        avg = {f"avg/{k}": v for k, v in avg.items()}
        avg["avg/forgetting"] = sum(forgetting.values()) / len(forgetting)
        algo.log(avg)

        # only print table at the end of fitting one task
        if algo.epoch_counter % algo.epochs == 0:
            self.print_task_metrics(self.metric_results, epoch=algo.epoch_counter)
        else:
            logging.info({k: round(v, 3) for k, v in avg.items()})
            logging.info({k: round(safe_conversion(v), 3) for k, v in self.metric_results[-1].items()})

        _metrics = dict(ChainMap(*self.metric_results))
        _metrics = {k: round(safe_conversion(v), 3) for k, v in _metrics.items()}
        self.register_results_to_algo(algo, "val_metrics", _metrics)

    def on_after_fit(self, algo: "BaseAlgorithm", *args, **kwargs):
        if algo.loggers is not None:
            for logger in algo.loggers:
                logger.log_all_results()

    def register_results_to_algo(self, algo, results_name, results_dict):
        setattr(algo, results_name, results_dict)

    def print_task_metrics(self, metrics: list, epoch):
        table = BeautifulTable(default_alignment=BeautifulTable.ALIGN_LEFT, default_padding=1, maxwidth=250)
        table.set_style(BeautifulTable.STYLE_BOX_ROUNDED)
        keys = list(metrics[0].keys())
        keys = [k.split("/")[1] for k in keys]
        table.rows.header = keys
        for i, m in enumerate(metrics):
            column_name = f"Task-{i+1}"
            table.columns.append(m.values(), header=column_name)
            table.columns.alignment[column_name] = BeautifulTable.ALIGN_RIGHT

        avg = {}
        for key in keys:
            temp = [[v for k, v in m.items() if key in k][0] for m in metrics]
            avg[key] = sum(temp) / len(temp)

        table.columns.append(avg.values(), header="AVG")
        table.columns.alignment["AVG"] = BeautifulTable.ALIGN_RIGHT
        f = [safe_conversion(k.compute()) for i, k in enumerate(self.forgetting) if i < len(metrics)]
        f.append(sum(f) / len(f))
        table.rows.append(f, header="Forgetting")
        logging.info(f"EVAL METRICS for epoch {epoch}:\n{table}")

BackwardTranferMetric

Bases: ForgettingMetric

How much does learning the current experience improve performance on previous experiences?

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
class BackwardTranferMetric(ForgettingMetric):
    """How much learning the current experience improves my performance on previous experiences?"""

    def compute(self):
        return -super().compute()
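
In other words, backward transfer is forgetting with the sign flipped: a positive value means that training on later tasks improved performance on earlier ones. A minimal numeric sketch, assuming ForgettingMetric follows the usual continual-learning convention of best-accuracy-so-far minus latest accuracy (its exact definition is not shown on this page):

# Hypothetical usage sketch; ForgettingMetric's semantics and import location are assumed.
import torch
from sequel.utils.callbacks.metrics.pytorch_metric_callback import BackwardTranferMetric, ForgettingMetric

fm = ForgettingMetric()
bwt = BackwardTranferMetric()
for metric in (fm, bwt):
    metric(torch.tensor(0.90))  # accuracy on task 1 right after it was learned
    metric(torch.tensor(0.80))  # accuracy on task 1 after training on task 2

print(fm.compute())   # assumed: 0.90 - 0.80 = 0.10 (performance dropped)
print(bwt.compute())  # negated: -0.10, i.e. negative backward transfer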

CrossEntropyLossMetric

Bases: MeanMetric

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
class CrossEntropyLossMetric(MeanMetric):
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Wraps CrossEntropy into a torchmetrics MeanMetric.

        Args:
            preds (torch.Tensor): the logits of the current batch.
            target (torch.Tensor): the targets of the current batch.

        Returns:
            torch.Tensor: the computed cross-entropy loss.
        """
        value = F.cross_entropy(input=preds, target=target)
        return super().update(value)

update(preds, target)

Wraps CrossEntropy into a torchmetrics MeanMetric.

Parameters:

    preds (torch.Tensor): the logits of the current batch. Required.
    target (torch.Tensor): the targets of the current batch. Required.

Returns:

    torch.Tensor: the computed cross-entropy loss.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
def update(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Wraps CrossEntropy into a torchmetrics MeanMetric.

    Args:
        preds (torch.Tensor): the logits of the current batch.
        target (torch.Tensor): the targets of the current batch.

    Returns:
        torch.Tensor: the computed cross-entropy loss.
    """
    value = F.cross_entropy(input=preds, target=target)
    return super().update(value)
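
A short usage sketch of the metric on its own, with made-up tensors (the import path follows the source file shown above):

# Accumulate the mean cross-entropy over two hypothetical batches.
import torch
from sequel.utils.callbacks.metrics.pytorch_metric_callback import CrossEntropyLossMetric

metric = CrossEntropyLossMetric()
logits = torch.randn(8, 10)              # batch of 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))
metric(logits, targets)                  # first batch
metric(torch.randn(8, 10), torch.randint(0, 10, (8,)))  # second batch
print(metric.compute())                  # running mean of the per-batch losses

Note that each call contributes with equal weight, since MeanMetric defaults to a weight of 1.0 per update; batches of different sizes are therefore averaged per batch, not per sample.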

PytorchMetricCallback

Bases: MetricCallback

Base class for the MetricCallback in case of PyTorch.

Inherits from MetricCallback.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
class PytorchMetricCallback(MetricCallback):
    """Base class for the MetricCallback in case of PyTorch.

    Inherits from `MetricCallback`.
    """

    def __init__(self, metrics: MetricCollection, logging_freq=10):
        super().__init__(metrics, logging_freq)

    def on_before_fit(self, algo: "BaseAlgorithm", *args, **kwargs):
        super().on_before_fit(algo, *args, **kwargs)
        self.forgetting = [ForgettingMetric() for i in range(self.num_tasks)]

    def connect(self, algo, *args, **kwargs):
        self.device = algo.device
        super().connect(algo, *args, **kwargs)

    def identify_seen_tasks(self, algo) -> List[int]:
        return torch.unique(algo.t - 1)

    def _reset_metrics(self, prefix):
        self.metrics = nn.ModuleList(
            self.original_metrics.clone(postfix=f"/{self.get_task_id(i)}", prefix=prefix).to(self.device)
            for i in range(self.num_tasks)
        ).to(self.device)
        self.avg_loss = MeanMetric(prefix=prefix).to(self.device)
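
For reference, cloning the collection with a prefix and postfix only renames the metric keys, so each task logs under its own namespace. A quick standalone sketch with plain torchmetrics (using the newer Accuracy signature that requires the task argument):

# Naming scheme produced by _reset_metrics: one namespaced clone per task.
import torchmetrics

base = torchmetrics.MetricCollection({"acc": torchmetrics.Accuracy(task="multiclass", num_classes=10)})
per_task = [base.clone(prefix="val/", postfix=f"/task-{i}") for i in range(3)]
print(list(per_task[0].keys()))  # ['val/acc/task-0']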

StandardMetricCallback

Bases: PytorchMetricCallback

Defines the standard Metric Callback used for classification.

Source code in sequel/utils/callbacks/metrics/pytorch_metric_callback.py
class StandardMetricCallback(PytorchMetricCallback):
    """Defines the standard Metric Callback used for classificaiton."""

    def __init__(self, logging_freq=1):
        metrics = torchmetrics.MetricCollection(
            {
                "acc": torchmetrics.Accuracy(),
                "loss": CrossEntropyLossMetric(),
            },
        )
        super().__init__(metrics, logging_freq)
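
The same pattern extends to any torchmetrics collection. As a hedged illustration, here is a hypothetical variant that also tracks calibration error; the class name and metric choice are not part of the library, and the "acc" key is kept because the forgetting bookkeeping above looks up the first metric whose name contains "acc":

# Hypothetical variant of StandardMetricCallback with an extra metric.
import torchmetrics

from sequel.utils.callbacks.metrics.pytorch_metric_callback import (
    CrossEntropyLossMetric,
    PytorchMetricCallback,
)


class CalibratedMetricCallback(PytorchMetricCallback):
    """Like StandardMetricCallback, but also tracks expected calibration error."""

    def __init__(self, num_classes: int, logging_freq: int = 1):
        metrics = torchmetrics.MetricCollection(
            {
                "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes),
                "ece": torchmetrics.CalibrationError(task="multiclass", num_classes=num_classes),
                "loss": CrossEntropyLossMetric(),
            },
        )
        super().__init__(metrics, logging_freq)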

JaxMetricCallback

Bases: MetricCallback

Handles the computation and logging of metrics.

Implements the callback hooks for the train/val/test steps and epochs. Inherits from MetricCallback.

Source code in sequel/utils/callbacks/metrics/jax_metric_callback.py
class JaxMetricCallback(MetricCallback):
    """Handles the computation and logging of metrics.

    Implements the callback hooks for the train/val/test steps and epochs. Inherits from MetricCallback.
    """

    def __init__(self, metrics, logging_freq=10):
        super().__init__(metrics, logging_freq)

    def on_before_fit(self, algo: "JaxBaseAlgorithm", *args, **kwargs):
        super().on_before_fit(algo, *args, **kwargs)
        self.forgetting = [ForgettingMetric() for i in range(self.num_tasks)]

    def identify_seen_tasks(self, algo) -> List[int]:
        return jnp.unique(algo.t - 1).tolist()

    def _reset_metrics(self, prefix):
        self.metrics = [
            self.original_metrics.clone(postfix=f"/{self.get_task_id(i)}", prefix=prefix)
            for i in range(self.num_tasks)
        ]
        self.avg_loss = MeanMetric()

    def compute_mask(self, algo, task_id):
        mask = super().compute_mask(algo, task_id)
        return np.array(mask)

    def on_after_val_step(self, algo):
        self.avg_loss(algo.loss)
        tasks_seen = self.identify_seen_tasks(algo)
        assert len(tasks_seen) == 1
        task_id = tasks_seen[0]
        self.metrics[task_id](algo.y_hat, algo.y)

        if (algo.batch_idx + 1) % self.logging_freq == 0:
            msg: dict = self.metrics[algo.current_val_task - 1].compute()
            msg.update({"val/avg_loss": self.avg_loss.compute()})
            msg = self.register_metric_callback_message(msg, algo)
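
The JAX variant keeps the same bookkeeping but moves task ids and masks to host-side Python/NumPy objects before indexing into the per-task metric list. A standalone sketch of that conversion:

# Standalone illustration of the task-id and mask handling used above.
import jax.numpy as jnp
import numpy as np

t = jnp.array([1, 1, 2])                  # 1-based task ids of a mixed batch
seen = jnp.unique(t - 1).tolist()         # identify_seen_tasks -> [0, 1]
for task_id in seen:
    mask = np.array((t - 1) == task_id)   # compute_mask -> boolean NumPy mask
    print(task_id, mask)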