Sharpness-Aware Minimization Leads to Low-Rank Features
Maksym Andriushchenko, Dara Bahri, Hossein Mobahi, Nicolas Flammarion
In overparametrised neural networks, the sharpness of the minima found by training has been observed to correlate with the generalisation error of the model: flatter minima tend to generalise better. Sharpness-aware minimisation (SAM) is a recent algorithm that adds an explicit sharpness penalty to the optimisation objective by minimising the worst-case loss in a small neighbourhood of the current weights, and it has been shown to improve generalisation.
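To make this objective concrete, below is a minimal NumPy sketch of the standard single-ascent-step approximation used by SAM. The names sam_step, grad_fn, lr and rho, as well as the toy quadratic loss, are illustrative choices and not taken from the paper.

```python
# A minimal sketch of the SAM update rule in plain NumPy (the model and loss
# are abstracted into a generic grad_fn; lr and rho are illustrative values).
# SAM approximately minimises the worst-case loss in a small neighbourhood,
#     min_w  max_{||eps|| <= rho}  L(w + eps),
# using a single gradient-ascent step to find the perturbation eps.
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: ascend to a nearby 'sharp' point, then descend from it."""
    g = grad_fn(w)                                   # gradient at the current weights
    eps = rho * g / (np.linalg.norm(g) + 1e-12)      # normalised ascent step
    g_sam = grad_fn(w + eps)                         # gradient at the perturbed weights
    return w - lr * g_sam                            # descent using the SAM gradient

# Toy usage on a quadratic loss L(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -2.0])
for _ in range(200):
    w = sam_step(w, grad_fn=lambda v: v)
print(w)  # approaches the minimum at the origin
```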
In this paper, the authors investigate the effect that SAM has on the features learned by the model. They demonstrate that, compared to networks trained with standard minimisation algorithms, SAM reduces the feature rank at different layers, measured as the number of principal components needed to capture 99% of the variance of the feature matrix. Low-rank features are useful in their own right, for instance to reduce the dimensionality of the representations and make downstream tasks that operate on them more efficient. In contrast, the authors find that directly enforcing a low feature rank during standard training does not improve generalisation. This suggests that low-rank features are a useful side effect of SAM rather than a full explanation of its generalisation benefits.
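As an illustration of this rank measure, the following sketch computes the number of principal components needed to explain 99% of the variance of a feature matrix. The function name feature_rank and the toy data are illustrative, not from the paper.

```python
# Sketch of the feature-rank measure described above: the smallest number of
# principal components whose cumulative explained variance reaches 99%.
# `features` is assumed to be an (examples x feature-dimensions) matrix.
import numpy as np

def feature_rank(features, threshold=0.99):
    centred = features - features.mean(axis=0, keepdims=True)
    singular_values = np.linalg.svd(centred, compute_uv=False)
    variances = singular_values ** 2
    explained = np.cumsum(variances) / variances.sum()
    return int(np.searchsorted(explained, threshold) + 1)

# Toy usage: 64-dimensional features that actually lie in a 5-dimensional subspace.
rng = np.random.default_rng(0)
features = rng.normal(size=(1000, 5)) @ rng.normal(size=(5, 64))
print(feature_rank(features))  # recovers the dimensionality of the subspace (5)
```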
To further understand the mechanism behind this effect, the authors study a two-layer ReLU network. They show, both empirically and theoretically, that SAM decreases the pre-activation values within the network. This, in turn, reduces the number of non-zero activations and yields the observed low feature rank.
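A toy NumPy illustration of this mechanism, which makes no claim about the paper's actual experiments: in a randomly initialised two-layer ReLU network, pushing the pre-activations downwards (here via an explicit shift, a crude stand-in for the pressure SAM applies) deactivates units, and the rank of the post-activation features drops with them, since that rank is bounded by the number of units active on at least one input.

```python
# Toy illustration (not the paper's experiment): for features = relu(X @ W1 + b1),
# the feature rank is at most the number of hidden units that are active on at
# least one input. Pushing pre-activations down (here with an explicit shift)
# deactivates units, so the fraction of non-zero activations and the feature
# rank both drop.
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_hidden = 200, 20, 50
X = rng.normal(size=(n, d_in))
W1 = rng.normal(size=(d_in, d_hidden)) / np.sqrt(d_in)
b1 = rng.normal(size=d_hidden)

for shift in [0.0, 1.0, 2.0, 3.0]:           # larger shift -> smaller pre-activations
    pre = X @ W1 + b1 - shift
    feats = np.maximum(pre, 0.0)             # ReLU features
    active = (pre > 0).mean()                # fraction of non-zero activations
    rank = np.linalg.matrix_rank(feats)
    print(f"shift={shift:.1f}  active={active:.2f}  rank={rank}/{d_hidden}")
```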