NeurIPS 2022: Paper review #7
G-Research were headline sponsors at NeurIPS 2022, in New Orleans.
ML is a fast-evolving discipline; attending conferences like NeurIPS and keeping up-to-date with the latest developments is key to the success of our quantitative researchers and machine learning engineers.
Our NeurIPS 2022 paper review series gives you the opportunity to hear about the research and papers that our quants and ML engineers found most interesting from the conference.
Here, Maxime R, Senior Quantitative Researcher at G-Research, discusses two papers from NeurIPS:
- On the Parameterization and Initialization of Diagonal State Space Models
- Sharpness-Aware Training for Free
On the Parameterization and Initialization of Diagonal State Space Models
Albert Gu, Ankit Gupta, Karan Goel, Christopher Ré
State space models (SSMs) build a hidden representation x(t) of an input sequence u(t) of the form x'(t) = A x(t) + B u(t), and then predict y(t) = C x(t). To make training less prone to the exploding and vanishing gradients of standard RNNs, and to capture long-range dependencies, S4 initialises the matrix A with a HiPPO matrix (arXiv:2008.07669), an approach rooted in polynomial approximation of time series.
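At its simplest, an SSM can be run as a recurrence over a discretized version of x'(t) = A x(t) + B u(t). The sketch below uses a plain Euler discretization for brevity; S4 itself uses a more accurate (bilinear/zero-order-hold) discretization, and the matrices here are toy values, not a HiPPO initialisation.

```python
import numpy as np

def ssm_scan(A, B, C, u, dt=0.1):
    """Run a state space model as a step-by-step recurrence.

    Euler discretization: x[k+1] = x[k] + dt * (A x[k] + B u[k]),
    with readout y[k] = C x[k+1].
    """
    x = np.zeros(A.shape[0])
    ys = []
    for u_k in u:
        x = x + dt * (A @ x + B * u_k)
        ys.append(C @ x)
    return np.array(ys)

# Toy example: a 2-state SSM smoothing a step input.
A = np.array([[-1.0, 0.0], [0.0, -2.0]])  # stable (negative eigenvalues)
B = np.array([1.0, 1.0])
C = np.array([0.5, 0.5])
u = np.ones(50)                            # step input
y = ssm_scan(A, B, C, u)                   # converges towards C @ [1, 0.5]
```

With stable (negative) eigenvalues in A, the output settles towards the steady state where A x + B u = 0, which is what makes such models well-behaved over long sequences.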
To leverage GPU acceleration, the S4 authors devise a complex but efficient algorithm that reformulates the recurrence as a convolution, which lets them fit the model on sequences as long as 16,000 steps. They achieve state-of-the-art results on multiple sequential datasets, most notably the Long Range Arena benchmark, while being much faster at training and inference.
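The convolutional view comes from unrolling the recurrence: the output is a causal convolution of the input with a kernel built from powers of the (discretized) state matrix. A minimal sketch, again using the toy Euler discretization rather than S4's actual kernel algorithm:

```python
import numpy as np

def ssm_kernel(A, B, C, L, dt=0.1):
    """Unroll an Euler-discretized SSM into a convolution kernel.

    y[k] = sum_j K[k-j] u[j]  with  K[i] = C (I + dt*A)^i (dt*B),
    so the whole sequence can be computed in one shot instead of a
    step-by-step recurrence. (S4 computes this kernel with a much
    cleverer algorithm; this loop is just for illustration.)
    """
    Ab = np.eye(A.shape[0]) + dt * A
    M = dt * B
    K = []
    for _ in range(L):
        K.append(C @ M)
        M = Ab @ M
    return np.array(K)

def ssm_conv(A, B, C, u, dt=0.1):
    """Apply the SSM to a whole sequence via FFT-based convolution."""
    L = len(u)
    K = ssm_kernel(A, B, C, L, dt)
    n_fft = 2 * L  # zero-pad so circular convolution equals linear
    y = np.fft.irfft(np.fft.rfft(K, n_fft) * np.fft.rfft(u, n_fft))[:L]
    return y
```

The FFT turns the O(L^2) convolution into O(L log L), which is what makes very long sequences tractable on GPU.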
This paper, S4D, simplifies the parameterisation of A to a diagonal matrix, which makes the S4 algorithm more understandable, easier to implement and more computationally efficient. The authors prove that the diagonal formulation is asymptotically equivalent to S4, and show that it achieves similar results on the benchmarks.
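With a diagonal A the convolution kernel collapses into a simple Vandermonde product, which is why S4D is so much easier to implement than S4. A sketch under simplifying assumptions (real-valued modes and a plain exponential discretization, whereas S4D uses complex modes initialised from HiPPO):

```python
import numpy as np

def s4d_kernel(a, b, c, L, dt=0.1):
    """Convolution kernel for a *diagonal* SSM, in the spirit of S4D.

    With A = diag(a), each state evolves independently, so
    K[l] = sum_n c_n * exp(dt * a_n)^l * b_n
    is just a Vandermonde matrix-vector product -- no special kernel
    algorithm needed, unlike full S4.
    """
    lam = np.exp(dt * a)                        # per-mode decay factors
    V = lam[None, :] ** np.arange(L)[:, None]   # Vandermonde matrix (L, N)
    return V @ (b * c)                          # K[l] = sum_n c_n lam_n^l b_n
```

Because each diagonal mode decays independently, the kernel is a sum of damped exponentials, and computing it is a single matrix multiply.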
The code is available on GitHub and is usable out of the box.
Sharpness-Aware Training for Free
Jiawei Du, Daquan Zhou, Jiashi Feng, Vincent Y. F. Tan, Joey Tianyi Zhou
This paper builds on the Sharpness-Aware Minimization (SAM) strategy, which encourages the neural network to converge to a flat minimum in order to improve generalisation.
While the various SAM strategies report higher test performance, they roughly double the training time. This is because the loss gains an estimate of its own flatness around the current point, which requires computing gradients for all the weights and rerunning a forward pass.
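The cost is easy to see in SAM's update rule: it first takes a gradient to find the worst-case nearby point, then takes a second gradient there. A minimal sketch on a toy quadratic loss (the value of rho is an assumption for illustration):

```python
import numpy as np

def sam_grad(w, loss_grad, rho=0.05):
    """One SAM gradient step, sketched.

    1) Compute the gradient at w and step by rho in the (normalised)
       ascent direction to reach a worst-case nearby point.
    2) Return the gradient at that perturbed point.
    The two gradient evaluations are why SAM roughly doubles cost.
    """
    g = loss_grad(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent perturbation
    return loss_grad(w + eps)                    # gradient used for the update

# Toy loss L(w) = 0.5 * ||w||^2, whose gradient is simply w.
loss_grad = lambda w: w
w = np.array([3.0, 4.0])
g_sam = sam_grad(w, loss_grad)
```

On this quadratic, ||w|| = 5, so the perturbed point is 1.01 * w and the SAM gradient is slightly larger than the plain gradient, as expected from stepping uphill first.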
This paper aims to provide a method that keeps the convergence properties of SAM with minimal extra overhead.
The authors first propose Sharpness-Aware Training for Free (SAF), which stores the outputs of the network from E epochs earlier. It adds to the loss a KL-divergence term comparing the current outputs on a batch with the outputs recorded on the same batch E epochs before.
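The SAF objective can be sketched as the task loss plus a KL term between stored and current outputs. The trade-off weight `lam` and the direction of the KL here are illustrative assumptions, not the paper's exact formulation (which also uses a temperature on the soft labels):

```python
import numpy as np

def kl_div(p_logits, q_logits):
    """KL(softmax(p) || softmax(q)), averaged over the batch."""
    p = np.exp(p_logits - p_logits.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)
    q = np.exp(q_logits - q_logits.max(-1, keepdims=True))
    q /= q.sum(-1, keepdims=True)
    return np.mean(np.sum(p * (np.log(p) - np.log(q)), axis=-1))

def saf_loss(ce_loss, logits_now, logits_past, lam=0.3):
    """SAF-style objective (sketch): task loss plus a KL term pulling
    the current outputs toward the outputs recorded on the same batch
    E epochs earlier."""
    return ce_loss + lam * kl_div(logits_past, logits_now)
```

The key point is that `logits_past` is read from storage, so the extra term costs almost nothing at training time.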
They show that this is a good approximation of the flatness of the loss. However, it is infeasible for larger datasets, where the estimate would be poor and the memory overhead prohibitive.
They therefore propose a second strategy, Memory-Efficient Sharpness-Aware training (MESA). MESA keeps an exponentially weighted moving average (EWMA) of the network's weights, and adds a KL-divergence term comparing the batch outputs under the current weights with the outputs under the EWMA weights.
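The EWMA bookkeeping that MESA relies on is a few lines of state per weight tensor; the decay value below is an assumption for illustration:

```python
import numpy as np

class EMAWeights:
    """Exponentially weighted moving average of model weights, in the
    spirit of MESA (sketch). After each optimiser step, call update()
    with the current weights; `avg` then holds the smoothed weights
    used for the extra forward pass."""

    def __init__(self, weights, decay=0.999):
        self.decay = decay
        self.avg = [w.copy() for w in weights]

    def update(self, weights):
        # avg <- decay * avg + (1 - decay) * w, in place per tensor
        for a, w in zip(self.avg, weights):
            a *= self.decay
            a += (1.0 - self.decay) * w
```

The MESA loss then compares the batch outputs under the current weights with the outputs under `avg` via a KL divergence; only the current-weight branch needs gradients, which is why there is no duplicated backward pass.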
This method requires a duplicated forward pass but no extra backward pass, making it much cheaper than SAM while keeping its performance gains. The authors show that the model achieves results comparable to SAM on CIFAR-10/100 and on ImageNet.