Where to find Grokking in LLM Pretraining? Monitor Memorization-to-Generalization without Test
Abstract
Grokking, or continued test performance improvement after training loss convergence, is observed during pretraining of a large language model, showcasing a memorization-to-generalization process.
Grokking, i.e., test performance that keeps improving long after the training loss has converged, has recently been observed in neural network training, leaving the mechanism behind generalization and other emergent capabilities such as reasoning mysterious. While prior studies usually train small models on a few toy or highly specific tasks for thousands of epochs, we conduct the first study of grokking on checkpoints from the one-pass pretraining of a 7B large language model (LLM), OLMoE. We compute the training loss and evaluate generalization on diverse benchmark tasks, including math reasoning, code generation, and commonsense/domain-specific knowledge retrieval. Our study, for the first time, verifies that grokking still happens in the pretraining of large-scale foundation models, though different data may enter grokking stages asynchronously. We further demystify grokking's "emergence of generalization" by investigating LLM internal dynamics. Specifically, we find that training samples' pathways (i.e., expert choices across layers) evolve from random and instance-specific to more structured and shareable across samples during grokking. The complexity of a sample's pathway also decreases despite the converged loss. These findings indicate a memorization-to-generalization conversion, providing a mechanistic explanation of delayed generalization. In the study, we develop two novel metrics to quantify pathway distance and the complexity of a single pathway, and we show their ability to predict the generalization improvement on diverse downstream tasks. They are efficient, simple to compute, and depend solely on training data. Hence, they have practical value for pretraining, enabling us to monitor generalization performance without finetuning or a test set. Theoretically, we show that more structured pathways reduce model complexity and improve the generalization bound.
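The paper's exact metric definitions are not reproduced on this page. As a minimal sketch only, one can imagine representing a sample's pathway as the set of top-k experts it activates at each MoE layer and measuring the distance between two samples' pathways as a per-layer Jaccard distance; this concrete formula is our assumption for illustration, not necessarily the paper's definition.

```python
# Minimal sketch (illustrative only, not the paper's formula):
# pathway[l] = indices of the top-k experts a sample activates at MoE layer l.
from typing import List, Set

Pathway = List[Set[int]]

def pathway_distance(a: Pathway, b: Pathway) -> float:
    """Average per-layer Jaccard distance between two samples' expert choices."""
    assert len(a) == len(b), "pathways must cover the same number of MoE layers"
    dists = []
    for experts_a, experts_b in zip(a, b):
        # Top-k routing (k >= 1) guarantees non-empty expert sets per layer.
        dists.append(1.0 - len(experts_a & experts_b) / len(experts_a | experts_b))
    return sum(dists) / len(dists)

# Example: two samples routed through a 4-layer MoE with top-2 routing.
sample_1 = [{3, 7}, {1, 5}, {2, 6}, {0, 4}]
sample_2 = [{2, 3}, {1, 5}, {2, 6}, {0, 9}]
print(pathway_distance(sample_1, sample_2))  # smaller = more shared pathway
```

Under this sketch, pathways becoming "more structured and shareable" during grokking would show up as a drop in the average pairwise distance over training samples.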
Community
Main findings:
Grokking still occurs during the one-pass pretraining of practical-scale LLMs, but it is local and asynchronous across data groups/domains, unlike the global grokking over all data reported in previous works.
Grokking's memorization-to-generalization mechanism can be explained by the dynamics of LLMs' internal states such as MoE pathways, whose similarity between training samples and consistency across layers increase during grokking. These dynamics indicate a smarter form of memorization: sharing more knowledge across samples and lowering per-sample complexity leads to better generalization. This is also supported by a theoretical connection between pathway complexity and a generalization bound.
The two novel metrics we developed to measure pathway distance and complexity are computed on training data only, without requiring any test/validation set or model finetuning, providing a practical and efficient tool to monitor generalization during LLM pretraining (see the sketch below).
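To illustrate how such training-only monitoring could work, here is a hedged sketch of a checkpoint loop. `load_checkpoint` and `get_pathways` are hypothetical placeholders (no real API is implied), the complexity measure (entropy of a sample's expert-usage distribution) is our assumption rather than the paper's definition, and `pathway_distance` refers to the sketch above.

```python
# Hypothetical monitoring loop: at each pretraining checkpoint, route a fixed
# batch of *training* samples through the MoE model, record their pathways,
# and track (i) mean pairwise pathway distance and (ii) a simple per-sample
# pathway complexity. No test set or finetuning is involved.
import math
from collections import Counter
from itertools import combinations

def pathway_complexity(pathway) -> float:
    """Shannon entropy (bits) of how often each expert appears along the pathway."""
    counts = Counter(e for layer in pathway for e in layer)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

def monitor(checkpoints, training_batch):
    for step in checkpoints:
        model = load_checkpoint(step)                    # placeholder, not a real API
        pathways = get_pathways(model, training_batch)   # placeholder: per-sample expert choices
        mean_dist = sum(pathway_distance(a, b) for a, b in combinations(pathways, 2))
        mean_dist /= math.comb(len(pathways), 2)
        mean_cplx = sum(pathway_complexity(p) for p in pathways) / len(pathways)
        # Decreasing distance/complexity on training data is the kind of signal the
        # paper associates with improving generalization during grokking.
        print(f"step={step}  mean_pathway_distance={mean_dist:.3f}  mean_complexity={mean_cplx:.3f}")
```

The design point is that both quantities are cheap to compute from routing decisions on training data alone, which is what makes them usable as a live pretraining monitor.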