Bases: MetricCallback
Handles the computation and logging of metrics.
Callback hooks after train/val/test steps/epochs etc. Inherits from Callback.
Source code in sequel/utils/callbacks/metrics/jax_metric_callback.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51 | class JaxMetricCallback(MetricCallback):
"""Handles the computation and logging of metrics.
Callback hooks after train/val/test steps/epochs etc. Inherits from Callback.
"""
def __init__(self, metrics, logging_freq=10):
super().__init__(metrics, logging_freq)
def on_before_fit(self, algo: "JaxBaseAlgorithm", *args, **kwargs):
super().on_before_fit(algo, *args, **kwargs)
self.forgetting = [ForgettingMetric() for i in range(self.num_tasks)]
def identify_seen_tasks(self, algo) -> List[int]:
return jnp.unique(algo.t - 1).tolist()
def _reset_metrics(self, prefix):
self.metrics = [
self.original_metrics.clone(postfix=f"/{self.get_task_id(i)}", prefix=prefix)
for i in range(self.num_tasks)
]
self.avg_loss = MeanMetric()
def compute_mask(self, algo, task_id):
mask = super().compute_mask(algo, task_id)
return np.array(mask)
def on_after_val_step(self, algo):
self.avg_loss(algo.loss)
tasks_seen = self.identify_seen_tasks(algo)
assert len(tasks_seen) == 1
task_id = tasks_seen[0]
self.metrics[task_id](algo.y_hat, algo.y)
if (algo.batch_idx + 1) % self.logging_freq == 0:
msg: dict = self.metrics[algo.current_val_task - 1].compute()
msg.update({"val/avg_loss": self.avg_loss.compute()})
msg = self.register_metric_callback_message(msg, algo)
|