
MC-SGD

MCSGD

Bases: JaxBaseAlgorithm

MC-SGD: Mode Connectivity-Stochastic Gradient Descent. Inherits from JaxBaseAlgorithm.

The equivalent PyTorch implementation is MCSGD in PyTorch.

References

[1] Mirzadeh, S.-I., Farajtabar, M., Görür, D., Pascanu, R. & Ghasemzadeh, H. Linear Mode Connectivity in Multitask and Continual Learning. in 9th International Conference on Learning Representations, ICLR 2021.
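
A minimal usage sketch, based only on the __init__ signature shown in the source below. Any arguments forwarded to JaxBaseAlgorithm (for example a backbone or benchmark) are hypothetical stand-ins for whatever your setup requires:

algo = MCSGD(
    per_task_memory_samples=100,
    memory_group_by="task",
    lmc_lr=0.05,
    lmc_momentum=0.8,
    lmc_batch_size=64,
    lmc_init_position=0.1,
    lmc_line_samples=10,
    lmc_epochs=1,
    # ...plus the base-class arguments required by JaxBaseAlgorithm,
    # e.g. a backbone and a benchmark (names here are illustrative).
)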

Source code in sequel/algos/jax/mcsgd.py
class MCSGD(JaxBaseAlgorithm):
    """MC-SGD: Mode Connectivity-Stochastic Gradient Descent. Inherits from BaseAlgorithm.

    The equivalent PyTorch implementation is [`MCSGD in PyTorch`][sequel.algos.pytorch.mcsgd.MCSGD].

    References:
        [1] Mirzadeh, S.-I., Farajtabar, M., Görür, D., Pascanu, R. & Ghasemzadeh, H. Linear Mode Connectivity in
            Multitask and Continual Learning. in 9th International Conference on Learning Representations, ICLR 2021.
    """

    state: TrainState

    def __init__(
        self,
        per_task_memory_samples: int = 100,
        memory_group_by: Literal["task", "class"] = "task",
        lmc_policy: str = "offline",
        lmc_interpolation: str = "linear",
        lmc_lr: float = 0.05,
        lmc_momentum: float = 0.8,
        lmc_batch_size: int = 64,
        lmc_init_position: float = 0.1,
        lmc_line_samples: int = 10,
        lmc_epochs: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.memory = MemoryMechanism(per_task_memory_samples=per_task_memory_samples, groupby=memory_group_by)
        self.w_bar_prev = None
        self.w_hat_curr = None

        # parse init arguments
        self.per_task_memory_samples = per_task_memory_samples
        self.lmc_policy = lmc_policy
        self.lmc_interpolation = lmc_interpolation
        self.lmc_lr = lmc_lr
        self.lmc_momentum = lmc_momentum
        self.lmc_batch_size = lmc_batch_size
        self.lmc_init_position = lmc_init_position
        self.lmc_line_samples = lmc_line_samples
        self.lmc_epochs = lmc_epochs

    def __repr__(self) -> str:
        return (
            "MCSGD("
            + f"per_task_memory_samples={self.per_task_memory_samples}, "
            + f"policy={self.lmc_policy}, "
            + f"interpolation={self.lmc_interpolation}, "
            + f"lr={self.lmc_lr}, "
            + f"momentum={self.lmc_momentum}, "
            + f"batch_size={self.lmc_batch_size}, "
            + f"init_position={self.lmc_init_position}, "
            + f"line_samples={self.lmc_line_samples}, "
            + f"epochs={self.lmc_epochs}"
            + ")"
        )

    def calculate_line_loss(self, w_start, w_end, loader):
        # Despite the name, this accumulates the *gradients* of the loss at
        # evenly spaced points on the segment between w_start and w_end.
        line_samples = np.arange(0.0, 1.01, 1.0 / float(self.lmc_line_samples))
        grads = tree_map(lambda x: 0 * x, w_start)
        for t in tqdm(line_samples, desc="Line samples"):
            # Linear interpolation: params = (1 - t) * w_start + t * w_end.
            params = tree_map(lambda a, b: a + (b - a) * t, w_start, w_end)
            g = self.calculate_point_loss(params, loader)
            grads = tree_map(lambda a, b: a + b, grads, g)
        return grads

    @partial(jax.jit, static_argnums=(0,))
    def simple_training_step(self, params, x, y, t, step):
        # Compute the cross-entropy gradients for a single batch; the loss
        # and logits auxiliaries are discarded.
        grad_fn = jax.value_and_grad(self.cross_entropy, has_aux=True, allow_int=True)
        (loss, logits), grads = grad_fn(params, x, y, t, self.is_training, step=step)
        return grads

    def calculate_point_loss(self, params, loader):
        # Despite the name, this returns the gradient of the loss at `params`,
        # accumulated over `loader` and normalized by the number of samples seen.
        total_count = 0.0
        grads = tree_map(lambda x: 0 * x, params)
        for batch in loader:
            self.unpack_batch(batch)
            g = self.simple_training_step(params, self.x, self.y, self.t, self.step_counter)
            grads = tree_map(lambda a, b: a + b, grads, g)
            total_count += self.bs

        return tree_map(lambda a: a / total_count, grads)

    def find_connected_minima(self, task):
        bs = self.lmc_batch_size
        loader_curr = self.benchmark.train_dataloader_subset(
            task, batch_size=bs, subset_size=self.per_task_memory_samples
        )
        loader_prev = self.benchmark.memory_dataloader(task, batch_size=bs, return_infinite_stream=False)

        # Initialize on the segment between the previous multi-task solution
        # (w_bar_prev) and the current continual solution (w_hat_curr).
        params = tree_map(lambda a, b: a + (b - a) * self.lmc_init_position, self.w_bar_prev, self.w_hat_curr)
        tx = optax.sgd(learning_rate=self.lmc_lr, momentum=self.lmc_momentum)
        state = TrainState.create(apply_fn=self.apply_fn, params=params, tx=tx)

        # Gradients along the lines connecting the candidate point to each endpoint.
        grads_prev = self.calculate_line_loss(self.w_bar_prev, state.params, loader_prev)
        grads_curr = self.calculate_line_loss(self.w_hat_curr, state.params, loader_curr)

        grads = tree_map(lambda a, b: a + b, grads_prev, grads_curr)
        state = state.apply_gradients(grads=grads)
        return state

    def on_after_training_epoch(self, *args, **kwargs):
        # Snapshot the current continual-learning solution and train state.
        self.w_hat_curr = copy.deepcopy(self.state.params)
        self.old_state = copy.deepcopy(self.state)

    def validate_algorithm_on_all_tasks(self) -> Dict[str, float]:
        # Standard validation applies only to the first task; for later tasks,
        # validation runs on the mode-connected weights in on_after_training_task.
        if self.task_counter == 1:
            return super().validate_algorithm_on_all_tasks()

    def on_after_validating_algorithm_on_all_tasks_callbacks(self):
        if self.task_counter == 1:
            return super().on_after_validating_algorithm_on_all_tasks_callbacks()

    def on_after_training_task(self, *args, **kwargs):
        self.memory.update_memory(self)
        if self.task_counter > 1:
            # save the backbone obtained from the mode-connectivity updates
            # as the Multi-Task approximate solution
            self.w_bar_prev = self.find_connected_minima(self.task_counter).params
            # perform the validation with the weights obtained after the mode-connectivity updates
            self.state = self.state.replace(params=self.w_bar_prev)
            super().on_before_validating_algorithm_on_all_tasks_callbacks()
            super().validate_algorithm_on_all_tasks()
            super().on_after_validating_algorithm_on_all_tasks_callbacks()
        else:
            self.w_bar_prev = copy.deepcopy(self.state.params)

        # revert the weights of the backbone to the Continual Learning solution
        self.state = copy.deepcopy(self.old_state)
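
To make the core idea of calculate_line_loss concrete, here is a self-contained sketch, independent of the class above, that interpolates between two parameter pytrees and accumulates the loss gradient at evenly spaced points along the connecting segment. The loss_fn, toy data, and endpoint pytrees are illustrative assumptions, not part of sequel:

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def loss_fn(params, x, y):
    # Toy squared-error loss; stands in for the model's cross-entropy.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def line_gradients(w_start, w_end, x, y, n_samples=10):
    # Accumulate the loss gradient at points t = 0, 1/n, ..., 1 on the
    # segment between w_start and w_end, mirroring calculate_line_loss.
    grads = tree_map(jnp.zeros_like, w_start)
    for t in jnp.linspace(0.0, 1.0, n_samples + 1):
        params = tree_map(lambda a, b: a + (b - a) * t, w_start, w_end)
        g = jax.grad(loss_fn)(params, x, y)
        grads = tree_map(lambda a, b: a + b, grads, g)
    return grads

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))
y = jax.random.normal(key, (32,))
w_start = {"w": jnp.zeros(4), "b": jnp.zeros(())}
w_end = {"w": jnp.ones(4), "b": jnp.ones(())}
grads = line_gradients(w_start, w_end, x, y)

Descending this accumulated line gradient is what encourages a low-loss linear path between the previous multi-task solution and the current continual solution, following Mirzadeh et al. [1].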