To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets
ORAL
Abstract
Robust generalization is a major challenge in deep learning: it is often very difficult to know whether a network has memorized a particular set of examples or understood the underlying rule (or both). Motivated by this challenge, we study a simple and interpretable model in which the generalizing representations are understood analytically and are easily distinguishable from the memorizing ones. Specifically, we consider two-layer neural networks trained on modular arithmetic tasks, where ξ·100% of the training labels are corrupted (i.e., some of the results of the modular operation in the training set are incorrect). This setup exhibits “grokking”, the phenomenon of delayed and sudden generalization.
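The corrupted-label setup can be made concrete with a short data-generation sketch. The following is a minimal illustration, not the paper's exact pipeline; the modulus `p`, corruption fraction `xi`, train fraction, and function/variable names are all illustrative assumptions.

```python
import numpy as np

def make_corrupted_modular_dataset(p=97, xi=0.2, train_frac=0.5, seed=0):
    """Modular addition (a + b) mod p with a fraction xi of training labels corrupted."""
    rng = np.random.default_rng(seed)
    a, b = np.meshgrid(np.arange(p), np.arange(p))
    pairs = np.stack([a.ravel(), b.ravel()], axis=1)   # all (a, b) with 0 <= a, b < p
    labels = (pairs[:, 0] + pairs[:, 1]) % p           # true labels

    perm = rng.permutation(len(pairs))
    n_train = int(train_frac * len(pairs))
    train_idx, test_idx = perm[:n_train], perm[n_train:]

    # Corrupt xi * 100% of the *training* labels with a random wrong residue.
    train_labels = labels[train_idx].copy()
    n_corrupt = int(xi * n_train)
    corrupt_idx = rng.choice(n_train, size=n_corrupt, replace=False)
    offsets = rng.integers(1, p, size=n_corrupt)       # nonzero shift guarantees a wrong label
    train_labels[corrupt_idx] = (train_labels[corrupt_idx] + offsets) % p

    return (pairs[train_idx], train_labels, corrupt_idx), (pairs[test_idx], labels[test_idx])
```

Note that only the training split is corrupted; the held-out test labels stay clean, so test accuracy directly measures whether the network has learned the underlying modular rule.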
In this setup, we show that (i) it is possible for the network to memorize the corrupted labels and simultaneously achieve 100% generalization; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on the corrupted data and improving the accuracy on the uncorrupted data; (iii) regularization methods such as weight decay, dropout, and BatchNorm force the network to ignore the corrupted data during optimization and achieve 100% accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is (“mechanistically”) interpretable via the Inverse Participation Ratio: weight decay and dropout force all neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones.
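A minimal sketch of how the Inverse Participation Ratio can separate the two kinds of neurons: generalizing neurons are periodic in the input token, so their Fourier power concentrates in a few frequencies (high IPR), while memorizing neurons have a spread-out spectrum (low IPR). The weight shapes, the exponent `r`, the threshold `tau`, and the names `W_in`/`W_out` are assumptions for illustration, not the paper's exact procedure.

```python
import numpy as np

def fourier_ipr(W_in, r=2):
    """IPR of each hidden neuron's input weights in the Fourier basis.

    W_in: array of shape (p, n_hidden), one column per neuron, indexed by the p tokens.
    Returns one IPR value per neuron; high values indicate a spectrally sparse
    (periodic, generalizing) neuron.
    """
    F = np.fft.rfft(W_in, axis=0)                # spectrum over the token axis
    power = np.abs(F) ** 2
    power /= power.sum(axis=0, keepdims=True)    # normalized power per neuron
    return (power ** r).sum(axis=0)              # IPR_r

def prune_memorizing_neurons(W_in, W_out, tau=0.1):
    """Zero out neurons whose Fourier IPR falls below the (illustrative) threshold tau.

    W_out is assumed to have shape (n_hidden, p).
    """
    keep = fourier_ipr(W_in) >= tau
    return W_in * keep[None, :], W_out * keep[:, None]
```

Under this reading, pruning the low-IPR (memorizing) neurons removes the circuit that fits the corrupted labels while leaving the periodic, generalizing circuit intact, which is consistent with point (ii) above.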
* Our work at the University of Maryland was supported in part by NSF CAREER Award DMR-2045181, the Sloan Foundation, and the Laboratory for Physical Sciences through the Condensed Matter Theory Center.
Presenters
-
Darshil H Doshi
University of Maryland, College Park
Authors
-
Darshil H Doshi
University of Maryland, College Park
-
Aritra Das
University of Maryland, College Park
-
Tianyu He
University of Maryland, College Park
-
Andrey Gromov
University of Maryland, College Park