A Universal Class of Sharpness-Aware Minimization Algorithms
Behrooz Tahmasebi, Ashkan Soleymani, Dara Bahri, Stefanie Jegelka, Patrick Jaillet
Sharpness-Aware Minimization (SAM) (Foret et al., 2021) [1] is an optimisation procedure that aims to improve the generalisation of trained models by biasing training towards flatter minima in the loss landscape.
Specifically, SAM minimises the maximum of the loss L over a small neighbourhood of the current parameters, i.e. min_x max_{||ε|| ≤ ρ} L(x + ε). This has been shown to be equivalent to minimising the largest eigenvalue of the Hessian matrix H(x) on the zero-loss manifold. A variant uses the average loss over the neighbourhood instead of the maximum; this corresponds to minimising the trace of H(x) (and thus the average eigenvalue). These functions of the Hessian can be considered measures of sharpness.
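In practice, the inner maximisation is approximated by a single gradient ascent step of radius ρ, followed by a descent step using the gradient at the perturbed point. A minimal PyTorch-style sketch of one such update (function and argument names here are my own, not the paper's):

```python
import torch

def sam_step(model, loss_fn, data, target, base_opt, rho=0.05):
    """One SAM update: ascend to an approximate worst-case point, descend from there."""
    # First pass: gradient of the loss at the current parameters.
    loss_fn(model(data), target).backward()
    params = [p for p in model.parameters() if p.grad is not None]
    with torch.no_grad():
        grad_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
        eps = [rho * p.grad / (grad_norm + 1e-12) for p in params]
        for p, e in zip(params, eps):
            p.add_(e)  # single ascent step to the edge of the rho-ball
    model.zero_grad()

    # Second pass: gradient at the perturbed point drives the actual update.
    loss_fn(model(data), target).backward()
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)  # restore the original parameters
    base_opt.step()       # apply the perturbed-point gradient at the original point
    base_opt.zero_grad()
```

The average-based variant would instead sample perturbations from the neighbourhood rather than ascending to its worst point.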
However, these particular sharpness measures have shortcomings. For instance, the saddle L(a, b) = a² − b² has tr(H) = 0 everywhere, so the average-based sharpness objective cannot distinguish the saddle point from a genuinely flat minimum. The authors also provide examples where L(x) is scale-invariant but the sharpness measures are not.
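A quick numerical check of this example (my own illustration): the Hessian of L(a, b) = a² − b² is the constant matrix diag(2, −2), whose trace vanishes even though the spectrum is far from flat.

```python
import numpy as np

# Hessian of L(a, b) = a^2 - b^2 is diag(2, -2) at every point.
H = np.diag([2.0, -2.0])

print(np.trace(H))                    # 0.0 -> average-based sharpness sees nothing
print(np.linalg.eigvalsh(H).max())    # 2.0 -> max-based sharpness still detects curvature
print(np.linalg.norm(H, "fro") ** 2)  # 8.0 -> squared Frobenius norm is also nonzero
```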
Using the average-based objective as a starting point, the authors define a generalised class of sharpness measures S(x) that are functions of the Hessian. They then prove that this class of sharpness measures is universal for functions of Hessian eigenvalues, as well as for arbitrary functions of the Hessian.
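Schematically, any symmetric function φ of the Hessian eigenvalues yields a candidate sharpness measure in this spirit. The paper's formal construction goes through the optimisation objective rather than direct spectral evaluation, but the following toy snippet (my own illustration) shows the measures discussed here on the Hessian from the previous example:

```python
import numpy as np

def sharpness(H, phi):
    """Apply a symmetric function phi to the Hessian spectrum."""
    return phi(np.linalg.eigvalsh(H))

H = np.diag([2.0, -2.0])
print(sharpness(H, np.max))                      # largest eigenvalue (max-loss SAM)
print(sharpness(H, np.mean))                     # average eigenvalue (trace-based variant)
print(sharpness(H, lambda lam: np.sum(lam**2)))  # squared Frobenius norm
print(sharpness(H, np.prod))                     # determinant
```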
They then provide an objective function that relies only on zeroth-order information about the training loss (i.e. loss evaluations rather than Hessian computations), with an explicit bias towards minimising S(x), as well as the full generalised SAM algorithm. They recover the original maximum- and average-based measures under this parameterisation, and introduce new measures based on the Frobenius norm and the determinant of the Hessian to address the aforementioned problems with saddle points and scale-invariance.
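The authors' exact estimator is more involved, but the flavour of "zeroth-order only" can be illustrated as follows: a central finite difference along a random Gaussian direction ε costs three loss evaluations and yields a sample of εᵀH(x)ε, whose mean over directions is tr(H) and whose variance is 2||H||_F², so a Frobenius-style sharpness is recoverable from loss queries alone. A rough sketch (names are mine):

```python
import numpy as np

def quad_form_samples(loss, x, rho=1e-3, n=20000, seed=0):
    """Zeroth-order samples of eps^T H(x) eps via central differences:
    three loss evaluations per random Gaussian direction eps."""
    rng = np.random.default_rng(seed)
    L0 = loss(x)
    samples = np.empty(n)
    for i in range(n):
        eps = rng.standard_normal(len(x))
        samples[i] = (loss(x + rho * eps) + loss(x - rho * eps) - 2 * L0) / rho**2
    return samples

# Same toy loss as before: L(a, b) = a^2 - b^2, so tr(H) = 0 and ||H||_F^2 = 8.
loss = lambda x: x[0]**2 - x[1]**2
q = quad_form_samples(loss, np.zeros(2))
print(q.mean())       # ~ 0.0  (estimates tr(H))
print(0.5 * q.var())  # ~ 8.0  (estimates ||H||_F^2)
```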
The authors then demonstrate that these SAM variants are competitive with the original ones on various vision datasets, and can outperform them in certain scenarios, e.g. when training data is limited or label noise is present. It would be interesting to see how these variants compare across a broader range of datasets and model architectures.
[1] Foret, P., Kleiner, A., Mobahi, H., Neyshabur, B. Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR 2021.