Skip to content

InputVisualizationCallback

InputVisualizationCallback

Bases: AlgoCallback

Visualizes random samples from each task and uses the loggers to save the plots.

Source code in sequel/utils/callbacks/input_visualization_callback.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class InputVisualizationCallback(AlgoCallback):
    """Visualizes random samples from each task and uses the loggers to save the plots."""

    def __init__(self, samples_per_task: int = 5) -> None:
        """Inits the InputVisualizationCallback.

        Args:
            samples_per_task (int, optional): number of samples to be plotted for each task. Defaults to 5.
        """
        super().__init__()
        self.samples_per_task = samples_per_task

    def select_random_samples(self, dataset: torch.utils.data.Dataset) -> List[torch.Tensor]:
        """Selects a predefined number of distinct random items from one CL dataset. Each task corresponds to a
        different dataset.

        At most `self.samples_per_task` items are drawn. If the dataset holds fewer items than that, every item is
        returned (in random order) instead of raising.

        Args:
            dataset (torch.utils.data.Dataset): The PyTorch Dataset; it must support `len()` and integer indexing.

        Returns:
            List[torch.Tensor]: The selected items, exactly as returned by `dataset[i]`. `on_before_fit` unpacks
                each item as an `(x, y, t)` triple.
        """
        # Sampling without replacement cannot draw more items than the dataset contains.
        num_samples = min(self.samples_per_task, len(dataset))
        if num_samples <= 0:
            return []

        indices = np.random.choice(len(dataset), num_samples, replace=False)
        # Cast to built-in int so datasets that reject numpy integer indices still work.
        return [dataset[int(i)] for i in indices]

    def on_before_fit(self, algo: "BaseAlgorithm", *args, **kwargs) -> None:
        """Retrieves and displays in a single plot the input images from all tasks of the benchmark that the algorithm
        has been initialized with. The final plot is saved via the loggers.

        The figure has one row per task and `samples_per_task` columns. Cells for which no sample is available
        (dataset smaller than `samples_per_task`) are left empty.

        Args:
            algo (BaseAlgorithm): The BaseAlgorithm instance. Its `benchmark.trains` mapping, `num_tasks` attribute
                and `log_figure` method are used.
        """
        datasets = algo.benchmark.trains
        num_tasks = algo.num_tasks

        samples = [self.select_random_samples(dataset) for dataset in datasets.values()]

        cell_size = 2  # side length of each subplot, in inches
        figure, axes = plt.subplots(
            nrows=num_tasks,
            ncols=self.samples_per_task,
            figsize=(cell_size * self.samples_per_task, cell_size * num_tasks),
            # Always return a 2D array of Axes so that axes[i][j] also works for a single row or column.
            squeeze=False,
        )

        for i, task_samples in enumerate(samples):
            for j, (x, y, t) in enumerate(task_samples):
                if x.dim() == 2:
                    # add a leading channel axis so the permute below applies to 2D inputs as well
                    x = x.unsqueeze(0)
                # move the leading axis last before handing the image to imshow
                axes[i][j].imshow(x.permute(1, 2, 0))
                axes[i][j].title.set_text(f"t={t}: y={y}")

        plt.setp(axes, xticks=[], yticks=[])
        figure.subplots_adjust(wspace=0.5)

        # save the plot via the algorithm loggers
        algo.log_figure(name="input/viz", figure=figure)

__init__(samples_per_task=5)

Inits the InputVisualizationCallback.

Parameters:

Name Type Description Default
samples_per_task int

Number of samples to be saved for each task.

5
Source code in sequel/utils/callbacks/input_visualization_callback.py
16
17
18
19
20
21
22
23
def __init__(self, samples_per_task: int = 5) -> None:
    """Inits the InputVisualizationCallback.

    Args:
        samples_per_task (int, optional): number of samples to be saved for each task. Defaults to 5.
    """
    super().__init__()
    # Stored for later use when sampling from each task's dataset.
    self.samples_per_task = samples_per_task

on_before_fit(algo, *args, **kwargs)

Retrieves and displays in a single plot the input images from all tasks of the benchmark that the algorithm has been initialized with. The final plot is saved via the loggers.

Parameters:

Name Type Description Default
algo BaseAlgorithm

The BaseAlgorithm instance.

required
Source code in sequel/utils/callbacks/input_visualization_callback.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def on_before_fit(self, algo: "BaseAlgorithm", *args, **kwargs) -> None:
    """Retrieves and displays in a single plot the input images from all tasks of the benchmark that the algorithm
    has been initialized with. The final plot is saved via the loggers.

    The figure has one row per task and `samples_per_task` columns.

    Args:
        algo (BaseAlgorithm): The BaseAlgorithm instance. Its `benchmark.trains` mapping, `num_tasks` attribute
            and `log_figure` method are used.
    """
    datasets = algo.benchmark.trains
    num_tasks = algo.num_tasks

    samples = [self.select_random_samples(dataset) for dataset in datasets.values()]

    cell_size = 2  # side length of each subplot, in inches
    figure, axes = plt.subplots(
        nrows=num_tasks,
        ncols=self.samples_per_task,
        figsize=(cell_size * self.samples_per_task, cell_size * num_tasks),
        # Always return a 2D array of Axes so that axes[i][j] also works for a single row or column.
        squeeze=False,
    )

    for i, task_samples in enumerate(samples):
        for j, (x, y, t) in enumerate(task_samples):
            if x.dim() == 2:
                # add a leading channel axis so the permute below applies to 2D inputs as well
                x = x.unsqueeze(0)
            # move the leading axis last before handing the image to imshow
            axes[i][j].imshow(x.permute(1, 2, 0))
            axes[i][j].title.set_text(f"t={t}: y={y}")

    plt.setp(axes, xticks=[], yticks=[])
    figure.subplots_adjust(wspace=0.5)

    # save the plot via the algorithm loggers
    algo.log_figure(name="input/viz", figure=figure)

select_random_samples(dataset)

Selects a predefined number of samples from each CL dataset. Each task corresponds to a different dataset.

Parameters:

Name Type Description Default
dataset torch.utils.data.Dataset

The PyTorch Dataset.

required

Returns:

Type Description
List[torch.Tensor]

List[torch.Tensor]: The Tensors corresponding to the selected input samples.

Source code in sequel/utils/callbacks/input_visualization_callback.py
25
26
27
28
29
30
31
32
33
34
35
36
def select_random_samples(self, dataset: torch.utils.data.Dataset) -> List[torch.Tensor]:
    """Selects a predefined number of distinct random items from one CL dataset. Each task corresponds to a
    different dataset.

    At most `self.samples_per_task` items are drawn. If the dataset holds fewer items than that, every item is
    returned (in random order) instead of raising.

    Args:
        dataset (torch.utils.data.Dataset): The PyTorch Dataset; it must support `len()` and integer indexing.

    Returns:
        List[torch.Tensor]: The selected items, exactly as returned by `dataset[i]`.
    """
    # Sampling without replacement cannot draw more items than the dataset contains.
    num_samples = min(self.samples_per_task, len(dataset))
    if num_samples <= 0:
        return []

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    # Cast to built-in int so datasets that reject numpy integer indices still work.
    return [dataset[int(i)] for i in indices]