Andrea Montanari, Pierfrancesco Urbani
It is often observed of deep, high-capacity models that, if trained for long enough or in a particular way, they end up perfectly interpolating the training data, i.e. reaching zero training error, while the resulting fitted models vary widely in how well they generalise to unseen test data.
This implies that how a model is trained matters a great deal, and that architecture and training data alone are insufficient to determine out-of-sample performance.
Here the authors focus on a wide two-layer neural network, a simple setting that nonetheless exhibits feature learning early in training and has the capacity to reach zero training error, and describe the dynamics of the whole training process within a unified framework that accounts for the empirical observations made during training. Non-monotone test error emerges as a consequence of this framework: they identify distinct timescales during training within which the network is underfitting or overfitting.
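To fix notation for what follows, a generic two-layer network of width m with an explicit output-layer initialisation scale can be written as below; this is one common parametrisation, offered for orientation only, and not necessarily the exact convention used in the paper.

```latex
% One common width-$m$ parametrisation (an assumption for orientation,
% not necessarily the paper's exact convention):
\[
  f(x; a, W) \;=\; \sum_{j=1}^{m} a_j\, \sigma\!\bigl(\langle w_j, x\rangle\bigr),
  \qquad
  a_j \sim \mathcal{N}\!\bigl(0, \kappa^2 / m\bigr) \ \text{at initialisation.}
\]
```

Here m is the network size and κ the scale of the final-layer initialisation, the two quantities whose influence on the onset of overfitting is discussed below.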
The framework also explains, in a clear way, how the point where feature learning stops (test error stops decreasing) and the point where overfitting starts (test error starts increasing) depend on the network size and on the scale at which the final layer's parameters are initialised (of particular interest given that this dependence disagrees with common practice), viewed through the prism of how the model's complexity grows over the course of training.
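As a purely illustrative sketch (not the authors' experiments or data model), one could train such a network on synthetic data with full-batch gradient descent, record train and test error at every step, and sweep the width m and the output-layer initialisation scale kappa to see how the onset of overfitting shifts; all data choices and constants below are hypothetical.

```python
# Minimal sketch: a width-m two-layer tanh network trained by full-batch
# gradient descent, with train/test error recorded at every step.
import numpy as np

rng = np.random.default_rng(0)
d = 20
w_star = rng.normal(size=d)  # hypothetical "teacher" direction for the synthetic target


def make_data(n):
    # Synthetic single-index target with additive noise (an assumption,
    # not the data model used in the paper).
    X = rng.normal(size=(n, d)) / np.sqrt(d)
    y = np.tanh(X @ w_star) + 0.1 * rng.normal(size=n)
    return X, y


def train_two_layer(X, y, X_te, y_te, m=1024, kappa=0.1, lr=0.5, steps=2000):
    """Train f(x) = sum_j a_j * tanh(<w_j, x>) by gradient descent on the
    half mean-squared error; return train/test errors recorded at each step."""
    n = X.shape[0]
    W = rng.normal(size=(m, d)) / np.sqrt(d)     # first-layer weights
    a = kappa * rng.normal(size=m) / np.sqrt(m)  # output layer, init scale kappa
    errs = []
    for _ in range(steps):
        H, H_te = np.tanh(X @ W.T), np.tanh(X_te @ W.T)  # hidden activations
        r = H @ a - y                                     # training residuals
        errs.append((np.mean(r ** 2), np.mean((H_te @ a - y_te) ** 2)))
        grad_a = H.T @ r / n                                      # d(0.5*mean(r^2))/da
        grad_W = ((r[:, None] * (1 - H ** 2)) * a).T @ X / n     # d(0.5*mean(r^2))/dW
        a -= lr * grad_a
        W -= lr * grad_W
    return np.array(errs)  # columns: train error, test error over time


X_tr, y_tr = make_data(400)
X_te, y_te = make_data(2000)
for kappa in (0.1, 1.0):  # sweep the output-layer initialisation scale
    errs = train_two_layer(X_tr, y_tr, X_te, y_te, kappa=kappa)
    print(f"kappa={kappa}: final train/test error = {errs[-1]}")
```

Plotting the two columns of errs against the step index, for different choices of m and kappa, is one way to inspect whether and when the test error turns upward while the training error keeps decreasing.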
The result is a single framework that explains both training and overfitting in a non-asymptotic way, yielding a more complete understanding of the training dynamics of the two-layer networks considered. Since many of the same empirical observations are made for more complicated models, one can hope that the intuition it provides extends to those settings.
It will be exciting to see whether future research can apply these techniques to more complex models, both confirming existing understanding and potentially yielding novel insights.
Dynamical Decoupling of Generalisation and Overfitting in Large Two-Layer Networks