The equivalent PyTorch implementation is SI in PyTorch.
References
[1] Zenke, F., Poole, B. & Ganguli, S. Continual Learning Through Synaptic Intelligence. In Proceedings of the
34th International Conference on Machine Learning, ICML 2017.
class SI(JaxRegularizationBaseAlgorithm):
    """Synaptic Intelligence Algorithm.

    The equivalent PyTorch implementation is [`SI in PyTorch`][sequel.algos.pytorch.si.SI].

    During training, a per-parameter path integral ``w`` is accumulated over optimizer
    steps as ``w -= grad * step_delta``. At the end of each task it is folded into the
    importance estimate

        importance += w / (task_delta**2 + xi)

    where ``task_delta`` is the total parameter displacement over that task. Both ``w``
    and the task-start snapshot are then reset for the next task.

    Args:
        si_lambda: Regularization strength, forwarded to the base class as
            ``regularization_coefficient``.
        xi: Damping term added to the squared task displacement so the division is
            well defined for parameters that barely moved.

    References:
        [1] Zenke, F., Poole, B. & Ganguli, S. Continual Learning Through Synaptic Intelligence. In Proceedings of
            the 34th International Conference on Machine Learning, ICML 2017.
    """

    def __init__(self, si_lambda: float = 1.0, xi: float = 0.1, *args, **kwargs):
        super().__init__(regularization_coefficient=si_lambda, *args, **kwargs)
        self.xi = xi
        # Running path integral for the current task; same tree structure as the params.
        self.w = self._zeros_like(self.state.params)
        # Parameters as they were at the first training step of the current task.
        # Captured lazily in ``on_before_training_step`` and consumed (then cleared) by
        # ``calculate_parameter_importance``. It is kept separate from ``old_params`` so the
        # task displacement is measured against the task start, independently of when
        # ``old_params`` is updated.
        self._task_start_params = None

    def __repr__(self) -> str:
        return f"SI(si_lambda={self.regularization_coefficient}, xi={self.xi})"

    @staticmethod
    def _zeros_like(tree):
        """Return a pytree with the same structure as ``tree`` and all-zero leaves."""
        return jax.tree_util.tree_map(jax.numpy.zeros_like, tree)

    def calculate_parameter_importance(self):
        """Fold the current task's path integral into the accumulated importance.

        Computes ``importance + w / (delta**2 + xi)`` leaf-wise. Here ``delta`` is the
        change of the parameters since the start of the current task. On the first task
        the accumulated importance starts from zeros.

        Side effects: resets ``self.w`` to zeros and clears the task-start snapshot, so
        the next task starts with fresh statistics.

        Returns:
            A pytree of importances with the same structure as ``self.state.params``.
        """
        if self.task_counter == 1:
            importance = self._zeros_like(self.state.params)
        else:
            importance = self.importance

        start = self._task_start_params
        if start is None:
            # No training step was observed for this task: w is still zero, so the
            # importance is unchanged. Use a zero displacement to avoid reading stale state.
            start = self.state.params

        delta = jax.tree_util.tree_map(lambda w_cur, w_start: w_cur - w_start, self.state.params, start)
        importance = jax.tree_util.tree_map(
            lambda i, w, dt: i + w / (dt**2 + self.xi), importance, self.w, delta
        )

        self.w = self._zeros_like(self.state.params)
        self._task_start_params = None
        return importance

    def on_before_training_step(self, *args, **kwargs):
        """Snapshot the parameters before the optimizer step (and at the task start)."""
        self.prev_params = copy.deepcopy(self.state.params)
        if self._task_start_params is None:
            # First step of a new task: remember where this task started.
            self._task_start_params = self.prev_params

    def on_after_training_step(self, *args, **kwargs):
        """Accumulate ``-grad * step_delta`` into the path integral ``w``."""
        # NOTE(review): assumes batch_outputs["grads"] holds the gradients used for this
        # step, with the same tree structure as the params — confirm against the trainer.
        grads = self.batch_outputs["grads"]
        step_delta = jax.tree_util.tree_map(
            lambda w_cur, w_old: w_cur - w_old, self.state.params, self.prev_params
        )
        self.w = jax.tree_util.tree_map(lambda w, g, d: w - g * d, self.w, grads, step_delta)

    def on_after_training_task(self, *args, **kwargs):
        """Update the importance estimate, then store the current params as ``old_params``."""
        # The importance must be computed first. The previous implementation overwrote
        # ``old_params`` with the current params before measuring the displacement, which
        # made ``delta`` identically zero.
        self.importance = self.calculate_parameter_importance()
        self.old_params = copy.deepcopy(self.state.params)