MemoryBank: Enhancing Large Language Models with Long-Term Memory
Revolutionary advancements in Large Language Models have drastically reshaped our interactions with artificial intelligence systems. Despite this, a notable hindrance remains: the lack of a long-term memory mechanism within these models. This shortfall becomes increasingly evident in situations demanding sustained interaction, such as personal companion systems and psychological counseling. Therefore, we propose MemoryBank, a novel memory mechanism tailored for LLMs. MemoryBank enables the models to recall relevant memories, evolve through continuous memory updates, and understand and adapt to a user's personality by synthesizing information from past interactions. To mimic anthropomorphic behaviors and selectively preserve memory, MemoryBank incorporates a memory updating mechanism, inspired by the Ebbinghaus Forgetting Curve theory, which permits the AI to forget and reinforce memory based on time elapsed and the relative significance of the memory, thereby offering a human-like memory mechanism. MemoryBank is versatile in accommodating both closed-source models like ChatGPT and open-source models like ChatGLM. We exemplify the application of MemoryBank through the creation of an LLM-based chatbot named SiliconFriend in a long-term AI Companion scenario. Further tuned with psychological dialogs, SiliconFriend displays heightened empathy in its interactions. Experiments involve both qualitative analysis with real-world user dialogs and quantitative analysis with simulated dialogs. In the latter, ChatGPT acts as users with diverse characteristics and generates long-term dialog contexts covering a wide array of topics. The results of our analysis reveal that SiliconFriend, equipped with MemoryBank, exhibits a strong capability for long-term companionship, as it can provide empathetic responses, recall relevant memories, and understand user personality.
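The forget-and-reinforce idea described above can be sketched in a few lines. This is a minimal illustration, assuming the common exponential form of the Ebbinghaus curve, R = exp(-t / S), where t is the time since a memory was last recalled and S is a strength that grows with each recall. The `MemoryItem` class, the strength increment of 1.0, and the pruning threshold are hypothetical choices made for the example and are not taken from the abstract.

```python
import math
from dataclasses import dataclass


@dataclass
class MemoryItem:
    """A stored memory (hypothetical structure for illustration)."""
    text: str
    strength: float = 1.0       # S: larger means slower forgetting (assumed unit: days)
    last_recalled: float = 0.0  # timestamp in days


def retention(item: MemoryItem, now: float) -> float:
    """Exponential forgetting curve R = exp(-t / S), t = time since last recall."""
    elapsed = max(0.0, now - item.last_recalled)
    return math.exp(-elapsed / item.strength)


def recall(item: MemoryItem, now: float) -> None:
    """Reinforce a memory: raise its strength and reset the clock (assumed rule)."""
    item.strength += 1.0
    item.last_recalled = now


def prune(items, now, threshold=0.1):
    """Keep only memories whose retention is still at or above the threshold."""
    return [m for m in items if retention(m, now) >= threshold]
```

Under these assumptions, a memory that is recalled often decays more slowly, while one that is never revisited eventually falls below the threshold and is dropped.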
Do Your Best and Get Enough Rest for Continual Learning
According to the forgetting curve theory, we can enhance memory retention by learning extensive data and taking adequate rest. This means that in order to effectively retain new knowledge, it is essential to learn it thoroughly and ensure sufficient rest so that our brain can memorize without forgetting. The main takeaway from this theory is that learning extensive data at once necessitates sufficient rest before learning the same data again. This aspect of human long-term memory retention can be effectively utilized to address the continual learning of neural networks. Retaining new knowledge for a long period of time without catastrophic forgetting is the critical problem of continual learning. Therefore, based on Ebbinghaus' theory, we introduce the view-batch model that adjusts the learning schedules to optimize the recall interval between retraining the same samples. The proposed view-batch model allows the network to get enough rest to learn extensive knowledge from the same samples with a recall interval of sufficient length. To this end, we specifically present two approaches: 1) a replay method that guarantees the optimal recall interval, and 2) a self-supervised learning that acquires extensive knowledge from a single training sample at a time. We empirically show that these approaches of our method are aligned with the forgetting curve theory, which can enhance long-term memory. In our experiments, we also demonstrate that our method significantly improves many state-of-the-art continual learning methods in various protocols and scenarios. We open-source this project at https://github.com/hankyul2/ViewBatchModel.
Causal Estimation of Memorisation Profiles
Understanding memorisation in language models has practical and societal implications, e.g., studying models' training dynamics or preventing copyright infringements. Prior work defines memorisation as the causal effect of training with an instance on the model's ability to predict that instance. This definition relies on a counterfactual: the ability to observe what would have happened had the model not seen that instance. Existing methods struggle to provide computationally efficient and accurate estimates of this counterfactual. Further, they often estimate memorisation for a model architecture rather than for a specific model instance. This paper fills an important gap in the literature, proposing a new, principled, and efficient method to estimate memorisation based on the difference-in-differences design from econometrics. Using this method, we characterise a model's memorisation profile--its memorisation trends across training--by only observing its behaviour on a small set of instances throughout training. In experiments with the Pythia model suite, we find that memorisation (i) is stronger and more persistent in larger models, (ii) is determined by data order and learning rate, and (iii) has stable trends across model sizes, thus making memorisation in larger models predictable from smaller ones.
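For readers unfamiliar with the difference-in-differences design, the sketch below shows the classic two-group, two-period estimate. It is a generic illustration and not the paper's estimator: the function name, the choice of per-instance scores (for example log-likelihoods), and the grouping into "treated" instances (trained on between two checkpoints) and "control" instances (not yet seen) are assumptions made for the example.

```python
from statistics import mean


def diff_in_diff(treated_pre, treated_post, control_pre, control_post):
    """Classic 2x2 difference-in-differences estimate.

    Each argument is a list of per-instance scores measured at a checkpoint
    before ("pre") or after ("post") the treated instances were trained on.
    """
    treated_change = mean(treated_post) - mean(treated_pre)
    control_change = mean(control_post) - mean(control_pre)
    # Subtracting the control change removes the general training trend,
    # leaving the part of the improvement attributable to seeing the instances.
    return treated_change - control_change
```

The estimate is only meaningful under the usual parallel-trends assumption, namely that both groups would have changed by the same amount had the treated instances not been seen.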
The Joint Effect of Task Similarity and Overparameterization on Catastrophic Forgetting -- An Analytical Model
In continual learning, catastrophic forgetting is affected by multiple aspects of the tasks. Previous works have analyzed separately how forgetting is affected by either task similarity or overparameterization. In contrast, our paper examines how task similarity and overparameterization jointly affect forgetting in an analyzable model. Specifically, we focus on two-task continual linear regression, where the second task is a random orthogonal transformation of an arbitrary first task (an abstraction of random permutation tasks). We derive an exact analytical expression for the expected forgetting and uncover a nuanced pattern. In highly overparameterized models, intermediate task similarity causes the most forgetting. However, near the interpolation threshold, forgetting decreases monotonically with the expected task similarity. We validate our findings with linear regression on synthetic data, and with neural networks on established permutation task benchmarks.
Emergent and Predictable Memorization in Large Language Models
Memorization, or the tendency of large language models (LLMs) to output entire sequences from their training data verbatim, is a key concern for safely deploying language models. In particular, it is vital to minimize a model's memorization of sensitive datapoints such as those containing personally identifiable information (PII). The prevalence of such undesirable memorization can pose issues for model trainers, and may even require discarding an otherwise functional model. We therefore seek to predict which sequences will be memorized before a large model's full train-time by extrapolating the memorization behavior of lower-compute trial runs. We measure memorization of the Pythia model suite and plot scaling laws for forecasting memorization, allowing us to provide equi-compute recommendations to maximize the reliability (recall) of such predictions. We additionally provide further novel discoveries on the distribution of memorization scores across models and data. We release all code and data necessary to reproduce the results in this paper at https://github.com/EleutherAI/pythia
The Ideal Continual Learner: An Agent That Never Forgets
The goal of continual learning is to find a model that solves multiple learning tasks which are presented sequentially to the learner. A key challenge in this setting is that the learner may forget how to solve a previous task when learning a new task, a phenomenon known as catastrophic forgetting. To address this challenge, many practical methods have been proposed, including memory-based, regularization-based, and expansion-based methods. However, a rigorous theoretical understanding of these methods remains elusive. This paper aims to bridge this gap between theory and practice by proposing a new continual learning framework called Ideal Continual Learner (ICL), which is guaranteed to avoid catastrophic forgetting by construction. We show that ICL unifies multiple well-established continual learning methods and gives new theoretical insights into the strengths and weaknesses of these methods. We also derive generalization bounds for ICL which allow us to theoretically quantify how rehearsal affects generalization. Finally, we connect ICL to several classic subjects and research topics of modern interest, which allows us to make historical remarks and inspire future directions.
How Much Can We Forget about Data Contamination?
The leakage of benchmark data into the training data has emerged as a significant challenge for evaluating the capabilities of large language models (LLMs). In this work, we challenge the common assumption that small-scale contamination renders benchmark evaluations invalid. First, we experimentally quantify the magnitude of benchmark overfitting based on scaling along three dimensions: the number of model parameters (up to 1.6B), the number of times an example is seen (up to 144), and the number of training tokens (up to 40B). If model and data follow the Chinchilla scaling laws, minor contamination indeed leads to overfitting. At the same time, even contamination repeated 144 times can be forgotten if the training data is scaled beyond five times Chinchilla, a regime characteristic of many modern LLMs. Continual pre-training of OLMo-7B corroborates these results. Next, we study the impact of the weight decay parameter on example forgetting, showing that empirical forgetting occurs faster than the cumulative weight decay. This allows us to gauge the degree of example forgetting in large-scale training runs, indicating that many LLMs, including Llama 3 405B, have forgotten the data seen at the beginning of training.
UER: A Heuristic Bias Addressing Approach for Online Continual Learning
Online continual learning aims to continuously train neural networks from a continuous data stream with a single pass through the data. As the most effective approach, the rehearsal-based methods replay part of previous data. Commonly used predictors in existing methods tend to generate biased dot-product logits that favor the classes of current data, which is known as a bias issue and a phenomenon of forgetting. Many approaches have been proposed to overcome the forgetting problem by correcting the bias; however, they still need to be improved in the online setting. In this paper, we try to address the bias issue with a more straightforward and more efficient method. By decomposing the dot-product logits into an angle factor and a norm factor, we empirically find that the bias problem mainly occurs in the angle factor, which can be used to learn novel knowledge as cosine logits. In contrast, the norm factor abandoned by existing methods helps remember historical knowledge. Based on this observation, we propose to leverage the norm factor to balance the new and old knowledge for addressing the bias. To this end, we develop a heuristic approach called unbias experience replay (UER). UER learns current samples only by the angle factor and further replays previous samples by both the norm and angle factors. Extensive experiments on three datasets show that UER achieves superior performance over various state-of-the-art methods. The code is available at https://github.com/FelixHuiweiLin/UER.
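The decomposition mentioned above rests on the identity w·x = ‖w‖‖x‖cos θ. The sketch below illustrates that identity only. It assumes the norm factor is the product of the two vector norms and the angle factor is the cosine between them; the function name `decompose_logit` is hypothetical, and the code is not taken from the UER repository.

```python
import math


def decompose_logit(w, x):
    """Split the dot-product logit w.x into (norm_factor, angle_factor).

    norm_factor  = ||w|| * ||x||
    angle_factor = cos(theta) between w and x
    so that logit = norm_factor * angle_factor.
    """
    dot = sum(wi * xi for wi, xi in zip(w, x))
    norm = math.sqrt(sum(wi * wi for wi in w)) * math.sqrt(sum(xi * xi for xi in x))
    if norm == 0.0:
        return 0.0, 0.0  # degenerate case: a zero vector has no direction
    return norm, dot / norm
```

Using only the angle factor corresponds to cosine logits, while multiplying the two factors back together recovers the ordinary dot-product logit.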
Theory on Forgetting and Generalization of Continual Learning
Continual learning (CL), which aims to learn a sequence of tasks, has attracted significant recent attention. However, most work has focused on the experimental performance of CL, and theoretical studies of CL are still limited. In particular, there is a lack of understanding on what factors are important and how they affect "catastrophic forgetting" and generalization performance. To fill this gap, our theoretical analysis, under overparameterized linear models, provides the first-known explicit form of the expected forgetting and generalization error. Further analysis of such a key result yields a number of theoretical explanations about how overparameterization, task similarity, and task ordering affect both forgetting and generalization error of CL. More interestingly, by conducting experiments on real datasets using deep neural networks (DNNs), we show that some of these insights even go beyond the linear models and can be carried over to practical setups. In particular, we use concrete examples to show that our results not only explain some interesting empirical observations in recent studies, but also motivate better practical algorithm designs of CL.
Recognition, recall, and retention of few-shot memories in large language models
The training of modern large language models (LLMs) takes place in a regime where most training examples are seen only a few times by the model during the course of training. What does a model remember about such examples seen only a few times during training and how long does that memory persist in the face of continuous training with new examples? Here, we investigate these questions through simple recognition, recall, and retention experiments with LLMs. In recognition experiments, we ask if the model can distinguish the seen example from a novel example; in recall experiments, we ask if the model can correctly recall the seen example when cued by a part of it; and in retention experiments, we periodically probe the model's memory for the original examples as the model is trained continuously with new examples. We find that a single exposure is generally sufficient for a model to achieve near perfect accuracy even in very challenging recognition experiments. We estimate that the recognition performance of even small language models easily exceeds human recognition performance reported in similar experiments with humans (Shepard, 1967). Achieving near perfect recall takes more exposures, but most models can do it in just 3 exposures. The flip side of this remarkable capacity for fast learning is that precise memories are quickly overwritten: recall performance for the original examples drops steeply over the first 10 training updates with new examples, followed by a more gradual decline. Even after 100K updates, however, some of the original examples are still recalled near perfectly. A qualitatively similar retention pattern has been observed in human long-term memory retention studies before (Bahrick, 1984). Finally, recognition is much more robust to interference than recall and memory for natural language sentences is generally superior to memory for stimuli without structure.
Heterogeneous Forgetting Compensation for Class-Incremental Learning
Class-incremental learning (CIL) has achieved remarkable successes in learning new classes consecutively while overcoming catastrophic forgetting on old categories. However, most existing CIL methods unreasonably assume that all old categories have the same forgetting pace, and neglect negative influence of forgetting heterogeneity among different old classes on forgetting compensation. To surmount the above challenges, we develop a novel Heterogeneous Forgetting Compensation (HFC) model, which can resolve heterogeneous forgetting of easy-to-forget and hard-to-forget old categories from both representation and gradient aspects. Specifically, we design a task-semantic aggregation block to alleviate heterogeneous forgetting from representation aspect. It aggregates local category information within each task to learn task-shared global representations. Moreover, we develop two novel plug-and-play losses: a gradient-balanced forgetting compensation loss and a gradient-balanced relation distillation loss to alleviate forgetting from gradient aspect. They consider gradient-balanced compensation to rectify forgetting heterogeneity of old categories and heterogeneous relation consistency. Experiments on several representative datasets illustrate effectiveness of our HFC model. The code is available at https://github.com/JiahuaDong/HFC.
Center Loss Regularization for Continual Learning
The ability to learn different tasks sequentially is essential to the development of artificial intelligence. In general, neural networks lack this capability, the major obstacle being catastrophic forgetting. It occurs when the incrementally available information from non-stationary data distributions is continually acquired, disrupting what the model has already learned. Our approach remembers old tasks by projecting the representations of new tasks close to that of old tasks while keeping the decision boundaries unchanged. We employ the center loss as a regularization penalty that enforces new tasks' features to have the same class centers as old tasks and makes the features highly discriminative. This, in turn, leads to the least forgetting of already learned information. This method is easy to implement, requires minimal computational and memory overhead, and allows the neural network to maintain high performance across many sequentially encountered tasks. We also demonstrate that using the center loss in conjunction with the memory replay outperforms other replay-based strategies. Along with standard MNIST variants for continual learning, we apply our method to continual domain adaptation scenarios with the Digits and PACS datasets. We demonstrate that our approach is scalable, effective, and gives competitive performance compared to state-of-the-art continual learning methods.
Model Editing at Scale leads to Gradual and Catastrophic Forgetting
Editing knowledge in large language models is an attractive capability to have which allows us to correct incorrectly learnt facts during pre-training, as well as update the model with an ever-growing list of new facts. While existing model editing techniques have shown promise, they are usually evaluated using metrics for reliability, specificity and generalization over one or few edits. We argue that for model editing to have practical utility, we must be able to make multiple edits to the same model. With this in mind, we evaluate the current model editing methods at scale, focusing on two state of the art methods: ROME and MEMIT. We find that as the model is edited sequentially with multiple facts, it continually forgets previously edited facts and the ability to perform downstream tasks. This forgetting happens in two phases -- an initial gradual but progressive forgetting phase followed by abrupt or catastrophic forgetting phase. Both gradual and catastrophic forgetting limit the usefulness of model editing methods at scale -- the former making model editing less effective as multiple edits are made to the model while the latter caps the scalability of such model editing methods. Our analysis also highlights other key limitations of ROME and MEMIT at scale. With our work, we push for the development and evaluation of model editing methods keeping scalability in mind.
Spurious Forgetting in Continual Learning of Language Models
Recent advancements in large language models (LLMs) reveal a perplexing phenomenon in continual learning: despite extensive training, models experience significant performance declines, raising questions about task alignment and underlying knowledge retention. This study first explores the concept of "spurious forgetting", proposing that such performance drops often reflect a decline in task alignment rather than true knowledge loss. Through controlled experiments with a synthesized dataset, we investigate the dynamics of model performance during the initial training phases of new tasks, discovering that early optimization steps can disrupt previously established task alignments. Our theoretical analysis connects these shifts to orthogonal updates in model weights, providing a robust framework for understanding this behavior. Ultimately, we introduce a Freezing strategy that fixes the bottom layers of the model, leading to substantial improvements in four continual learning scenarios. Our findings underscore the critical distinction between task alignment and knowledge retention, paving the way for more effective strategies in continual learning.
Large Continual Instruction Assistant
Continual Instruction Tuning (CIT) is adopted to continually instruct Large Models to follow human intent, dataset by dataset. Existing gradient updates are observed to heavily degrade performance on previous datasets during the CIT process. In contrast, Exponential Moving Average (EMA) can track previous parameters, which helps reduce forgetting. Nonetheless, its fixed balance weight fails to deal with the ever-changing datasets, leading to an imbalance between plasticity and stability. In this paper, we propose a general continual instruction tuning framework to address the challenge. Starting from the trade-off prerequisite and EMA update, we propose the plasticity and stability ideal condition. Based on a Taylor expansion of the loss function, we find that the optimal balance weight can be automatically determined by the gradients and learned parameters. Therefore, we propose a stable-plasticity balanced coefficient to avoid knowledge interference. Based on the semantic similarity of the instructions, we can determine whether to retrain or expand the training parameters and allocate the most suitable parameters for the testing instances. Extensive experiments across multiple continual instruction tuning benchmarks demonstrate that our approach not only enhances anti-forgetting capabilities but also significantly improves overall continual tuning performance. Our code is available at https://github.com/JingyangQiao/CoIN.
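As background for the EMA update discussed above, the sketch below shows the standard exponential moving average of parameters with a fixed balance weight. It illustrates only the baseline that the abstract describes as insufficient; the adaptive coefficient derived in the paper is not reproduced here, and the function name and the default `beta` value are assumptions made for the example.

```python
def ema_update(ema_params, new_params, beta=0.99):
    """Return beta * ema + (1 - beta) * new, element-wise over parameter lists.

    A beta close to 1 favors stability (old parameters); a beta close to 0
    favors plasticity (newly learned parameters).
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, new_params)]
```

Because `beta` stays constant here, the same stability and plasticity trade-off is applied to every dataset, which is the limitation the abstract points to.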
An Empirical Study of Example Forgetting during Deep Neural Network Learning
Inspired by the phenomenon of catastrophic forgetting, we investigate the learning dynamics of neural networks as they train on single classification tasks. Our goal is to understand whether a related phenomenon occurs when data does not undergo a clear distributional shift. We define a `forgetting event' to have occurred when an individual training example transitions from being classified correctly to incorrectly over the course of learning. Across several benchmark data sets, we find that: (i) certain examples are forgotten with high frequency, and some not at all; (ii) a data set's (un)forgettable examples generalize across neural architectures; and (iii) based on forgetting dynamics, a significant fraction of examples can be omitted from the training data set while still maintaining state-of-the-art generalization performance.
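The definition of a forgetting event above translates directly into a count over an example's per-epoch correctness record. The sketch below is a minimal illustration: the function names are hypothetical, and treating an example as "unforgettable" when it is learned at some point and never forgotten afterwards is an assumption, since the abstract does not spell out that definition.

```python
def count_forgetting_events(correct_history):
    """Count transitions from correct (True) to incorrect (False) for one example.

    correct_history is the sequence of per-evaluation outcomes for a single
    training example over the course of learning.
    """
    return sum(
        1
        for prev, curr in zip(correct_history, correct_history[1:])
        if prev and not curr
    )


def unforgettable(histories):
    """Indices of examples that were learned at some point and never forgotten."""
    return [
        i
        for i, h in enumerate(histories)
        if any(h) and count_forgetting_events(h) == 0
    ]
```

Sorting examples by their forgetting counts gives one simple way to rank candidates for removal from the training set, in the spirit of finding (iii).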
ACU: Analytic Continual Unlearning for Efficient and Exact Forgetting with Privacy Preservation
The development of artificial intelligence demands that models incrementally update knowledge by Continual Learning (CL) to adapt to open-world environments. To meet privacy and security requirements, Continual Unlearning (CU) emerges as an important problem, aiming to sequentially forget particular knowledge acquired during the CL phase. However, existing unlearning methods primarily focus on single-shot joint forgetting and face significant limitations when applied to CU. First, most existing methods require access to the retained dataset for re-training or fine-tuning, violating the inherent constraint in CL that historical data cannot be revisited. Second, these methods often suffer from a poor trade-off between system efficiency and model fidelity, making them vulnerable to being overwhelmed or degraded by adversaries through deliberately frequent requests. In this paper, we identify that the limitations of existing unlearning methods stem fundamentally from their reliance on gradient-based updates. To bridge the research gap at its root, we propose a novel gradient-free method for CU, named Analytic Continual Unlearning (ACU), for efficient and exact forgetting with historical data privacy preservation. In response to each unlearning request, our ACU recursively derives an analytical (i.e., closed-form) solution in an interpretable manner using the least squares method. Theoretical and experimental evaluations validate the superiority of our ACU on unlearning effectiveness, model fidelity, and system efficiency.
A Multi-Power Law for Loss Curve Prediction Across Learning Rate Schedules
Training large models is both resource-intensive and time-consuming, making it crucial to understand the quantitative relationship between model performance and hyperparameters. In this paper, we present an empirical law that describes how the pretraining loss of large language models evolves under different learning rate schedules, such as constant, cosine, and step decay schedules. Our proposed law takes a multi-power form, combining a power law based on the sum of learning rates and additional power laws to account for a loss reduction effect induced by learning rate decay. We extensively validate this law on various model sizes and architectures, and demonstrate that after fitting on a few learning rate schedules, the law accurately predicts the loss curves for unseen schedules of different shapes and horizons. Moreover, by minimizing the predicted final pretraining loss across learning rate schedules, we are able to find a schedule that outperforms the widely used cosine learning rate schedule. Interestingly, this automatically discovered schedule bears some resemblance to the recently proposed Warmup-Stable-Decay (WSD) schedule (Hu et al, 2024) but achieves a slightly lower final loss. We believe these results could offer valuable insights for understanding the dynamics of pretraining and designing learning rate schedules to improve efficiency.
Continual evaluation for lifelong learning: Identifying the stability gap
Time-dependent data-generating distributions have proven to be difficult for gradient-based training of neural networks, as the greedy updates result in catastrophic forgetting of previously learned knowledge. Despite the progress in the field of continual learning to overcome this forgetting, we show that a set of common state-of-the-art methods still suffers from substantial forgetting upon starting to learn new tasks, except that this forgetting is temporary and followed by a phase of performance recovery. We refer to this intriguing but potentially problematic phenomenon as the stability gap. The stability gap had likely remained under the radar due to standard practice in the field of evaluating continual learning models only after each task. Instead, we establish a framework for continual evaluation that uses per-iteration evaluation and we define a new set of metrics to quantify worst-case performance. Empirically we show that experience replay, constraint-based replay, knowledge-distillation, and parameter regularization methods are all prone to the stability gap; and that the stability gap can be observed in class-, task-, and domain-incremental learning benchmarks. Additionally, a controlled experiment shows that the stability gap increases when tasks are more dissimilar. Finally, by disentangling gradients into plasticity and stability components, we propose a conceptual explanation for the stability gap.
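To see why evaluating only after each task can hide the stability gap, consider the sketch below. It assumes accuracy on a previously learned task is logged at every training iteration of the new task. The two helper functions are hypothetical illustrations of a worst-case measurement and are not the metrics defined in the paper.

```python
def min_accuracy(acc_per_iteration):
    """Worst accuracy on an old task observed at any iteration of the new task."""
    if not acc_per_iteration:
        raise ValueError("need at least one evaluation")
    return min(acc_per_iteration)


def gap_depth(acc_per_iteration):
    """Drop from the accuracy at the first logged iteration to the worst point."""
    return acc_per_iteration[0] - min_accuracy(acc_per_iteration)
```

For a hypothetical trace such as `[0.90, 0.40, 0.70, 0.85]`, an end-of-task evaluation would report only the final 0.85, while per-iteration evaluation exposes the transient collapse to 0.40.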
Scaling Laws for Forgetting When Fine-Tuning Large Language Models
We study and quantify the problem of forgetting when fine-tuning pre-trained large language models (LLMs) on a downstream task. We find that parameter-efficient fine-tuning (PEFT) strategies, such as Low-Rank Adapters (LoRA), still suffer from catastrophic forgetting. In particular, we identify a strong inverse linear relationship between the fine-tuning performance and the amount of forgetting when fine-tuning LLMs with LoRA. We further obtain precise scaling laws that show forgetting increases as a shifted power law in the number of parameters fine-tuned and the number of update steps. We also examine the impact of forgetting on knowledge, reasoning, and the safety guardrails trained into Llama 2 7B chat. Our study suggests that forgetting cannot be avoided through early stopping or by varying the number of parameters fine-tuned. We believe this opens up an important safety-critical direction for future research to evaluate and develop fine-tuning schemes which mitigate forgetting.
SEFE: Superficial and Essential Forgetting Eliminator for Multimodal Continual Instruction Tuning
Multimodal Continual Instruction Tuning (MCIT) aims to enable Multimodal Large Language Models (MLLMs) to incrementally learn new tasks without catastrophic forgetting. In this paper, we explore forgetting in this context, categorizing it into superficial forgetting and essential forgetting. Superficial forgetting refers to cases where the model's knowledge may not be genuinely lost, but its responses to previous tasks deviate from expected formats due to the influence of subsequent tasks' answer styles, making the results unusable. By contrast, essential forgetting refers to situations where the model provides correctly formatted but factually inaccurate answers, indicating a true loss of knowledge. Assessing essential forgetting necessitates addressing superficial forgetting first, as severe superficial forgetting can obscure the model's knowledge state. Hence, we first introduce the Answer Style Diversification (ASD) paradigm, which defines a standardized process for transforming data styles across different tasks, unifying their training sets into similarly diversified styles to prevent superficial forgetting caused by style shifts. Building on this, we propose RegLoRA to mitigate essential forgetting. RegLoRA stabilizes key parameters where prior knowledge is primarily stored by applying regularization, enabling the model to retain existing competencies. Experimental results demonstrate that our overall method, SEFE, achieves state-of-the-art performance.
DualHSIC: HSIC-Bottleneck and Alignment for Continual Learning
Rehearsal-based approaches are a mainstay of continual learning (CL). They mitigate the catastrophic forgetting problem by maintaining a small fixed-size buffer with a subset of data from past tasks. While most rehearsal-based approaches study how to effectively exploit the knowledge from the buffered past data, little attention is paid to the inter-task relationships with the critical task-specific and task-invariant knowledge. By appropriately leveraging inter-task relationships, we propose a novel CL method named DualHSIC to boost the performance of existing rehearsal-based methods in a simple yet effective way. DualHSIC consists of two complementary components that stem from the so-called Hilbert Schmidt independence criterion (HSIC): HSIC-Bottleneck for Rehearsal (HBR) lessens the inter-task interference and HSIC Alignment (HA) promotes task-invariant knowledge sharing. Extensive experiments show that DualHSIC can be seamlessly plugged into existing rehearsal-based methods for consistent performance improvements, and also outperforms recent state-of-the-art regularization-enhanced rehearsal methods. Source code will be released.
Challenging Common Assumptions about Catastrophic Forgetting
Building learning agents that can progressively learn and accumulate knowledge is the core goal of the continual learning (CL) research field. Unfortunately, training a model on new data usually compromises the performance on past data. In the CL literature, this effect is referred to as catastrophic forgetting (CF). CF has been largely studied, and a plethora of methods have been proposed to address it on short sequences of non-overlapping tasks. In such setups, CF always leads to a quick and significant drop in performance in past tasks. Nevertheless, despite CF, recent work showed that SGD training on linear models accumulates knowledge in a CL regression setup. This phenomenon becomes especially visible when tasks reoccur. We might then wonder if DNNs trained with SGD or any standard gradient-based optimization accumulate knowledge in such a way. Such phenomena would have interesting consequences for applying DNNs to real continual scenarios. Indeed, standard gradient-based optimization methods are significantly less computationally expensive than existing CL algorithms. In this paper, we study the progressive knowledge accumulation (KA) in DNNs trained with gradient-based algorithms in long sequences of tasks with data re-occurrence. We propose a new framework, SCoLe (Scaling Continual Learning), to investigate KA and discover that catastrophic forgetting has a limited effect on DNNs trained with SGD. When trained on long sequences with data sparsely re-occurring, the overall accuracy improves, which might be counter-intuitive given the CF phenomenon. We empirically investigate KA in DNNs under various data occurrence frequencies and propose simple and scalable strategies to increase knowledge accumulation in DNNs.
A Theoretical Analysis of Catastrophic Forgetting through the NTK Overlap Matrix
Continual learning (CL) is a setting in which an agent has to learn from an incoming stream of data during its entire lifetime. Although major advances have been made in the field, one recurring problem which remains unsolved is that of Catastrophic Forgetting (CF). While the issue has been extensively studied empirically, little attention has been paid from a theoretical angle. In this paper, we show that the impact of CF increases as two tasks increasingly align. We introduce a measure of task similarity called the NTK overlap matrix which is at the core of CF. We analyze common projected gradient algorithms and demonstrate how they mitigate forgetting. Then, we propose a variant of Orthogonal Gradient Descent (OGD) which leverages structure of the data through Principal Component Analysis (PCA). Experiments support our theoretical findings and show how our method can help reduce CF on classical CL datasets.
Gradient Episodic Memory for Continual Learning
One major obstacle towards AI is the poor ability of models to solve new problems quicker, and without forgetting previously acquired knowledge. To better understand this issue, we study the problem of continual learning, where the model observes, once and one by one, examples concerning a sequence of tasks. First, we propose a set of metrics to evaluate models learning over a continuum of data. These metrics characterize models not only by their test accuracy, but also in terms of their ability to transfer knowledge across tasks. Second, we propose a model for continual learning, called Gradient Episodic Memory (GEM) that alleviates forgetting, while allowing beneficial transfer of knowledge to previous tasks. Our experiments on variants of the MNIST and CIFAR-100 datasets demonstrate the strong performance of GEM when compared to the state-of-the-art.
Interpretable Catastrophic Forgetting of Large Language Model Fine-tuning via Instruction Vector
Fine-tuning large language models (LLMs) can cause them to lose their general capabilities. However, the intrinsic mechanisms behind such forgetting remain unexplored. In this paper, we first examine this phenomenon, focusing on knowledge understanding and instruction following, with the latter identified as the main contributor to forgetting during fine-tuning. Consequently, we propose the Instruction Vector (IV) framework to capture model representations highly related to specific instruction-following capabilities, thereby making it possible to understand model-intrinsic forgetting. Through the analysis of IV dynamics before and after training, we suggest that fine-tuning mostly adds specialized reasoning patterns instead of erasing previous skills, which may appear as forgetting. Building on this insight, we develop IV-guided training, which aims to preserve the original computation graph, thereby mitigating catastrophic forgetting. Empirical tests on three benchmarks confirm the efficacy of this new approach, supporting the relationship between IVs and forgetting. Our code will be made available soon.
Continual Learning via Neural Pruning
We introduce Continual Learning via Neural Pruning (CLNP), a new method aimed at lifelong learning in fixed capacity models based on neuronal model sparsification. In this method, subsequent tasks are trained using the inactive neurons and filters of the sparsified network and cause zero deterioration to the performance of previous tasks. In order to deal with the possible compromise between model sparsity and performance, we formalize and incorporate the concept of graceful forgetting: the idea that it is preferable to suffer a small amount of forgetting in a controlled manner if it helps regain network capacity and prevents uncontrolled loss of performance during the training of future tasks. CLNP also provides simple continual learning diagnostic tools in terms of the number of free neurons left for the training of future tasks as well as the number of neurons that are being reused. In particular, we see in experiments that CLNP verifies and automatically takes advantage of the fact that the features of earlier layers are more transferable. We show empirically that CLNP leads to significantly improved results over current weight elasticity based methods.
An Investigation of the Combination of Rehearsal and Knowledge Distillation in Continual Learning for Spoken Language Understanding
Continual learning refers to a dynamical framework in which a model receives a stream of non-stationary data over time and must adapt to new data while preserving previously acquired knowledge. Unfortunately, neural networks fail to meet these two desiderata, incurring the so-called catastrophic forgetting phenomenon. Whereas a vast array of strategies have been proposed to attenuate forgetting in the computer vision domain, there is a dearth of work on speech-related tasks. In this paper, we consider the joint use of rehearsal and knowledge distillation (KD) approaches for spoken language understanding under a class-incremental learning scenario. We report on multiple KD combinations at different levels in the network, showing that combining feature-level and prediction-level KD leads to the best results. Finally, we provide an ablation study on the effect of the size of the rehearsal memory that corroborates the efficacy of our approach for low-resource devices.
Does Continual Learning Equally Forget All Parameters?
Distribution shift (e.g., task or domain shift) in continual learning (CL) usually results in catastrophic forgetting of neural networks. Although it can be alleviated by repeatedly replaying buffered data, the every-step replay is time-consuming. In this paper, we study which modules in neural networks are more prone to forgetting by investigating their training dynamics during CL. Our proposed metrics show that only a few modules are more task-specific and sensitively alter between tasks, while others can be shared across tasks as common knowledge. Hence, we attribute forgetting mainly to the former and find that finetuning them only on a small buffer at the end of any CL method can bring non-trivial improvement. Due to the small number of finetuned parameters, such "Forgetting Prioritized Finetuning (FPF)" is efficient in computation. We further propose a more efficient and simpler method that entirely removes the every-step replay and replaces it with only k rounds of FPF, triggered periodically during CL. Surprisingly, this "k-FPF" performs comparably to FPF and outperforms the SOTA CL methods while significantly reducing their computational overhead and cost. In experiments on several benchmarks of class- and domain-incremental CL, FPF consistently improves existing CL methods by a large margin, and k-FPF further excels in efficiency without degrading the accuracy. We also empirically study the impact of buffer size, epochs per task, and finetuning modules on the cost and accuracy of our methods.
A Closer Look at Rehearsal-Free Continual Learning
Continual learning is a setting where machine learning models learn novel concepts from continuously shifting training data, while simultaneously avoiding degradation of knowledge on previously seen classes which may disappear from the training data for extended periods of time (a phenomenon known as the catastrophic forgetting problem). Current approaches for continual learning of a single expanding task (aka class-incremental continual learning) require extensive rehearsal of previously seen data to avoid this degradation of knowledge. Unfortunately, rehearsal comes at a cost to memory, and it may also violate data-privacy. Instead, we explore combining knowledge distillation and parameter regularization in new ways to achieve strong continual learning performance without rehearsal. Specifically, we take a deep dive into common continual learning techniques: prediction distillation, feature distillation, L2 parameter regularization, and EWC parameter regularization. We first disprove the common assumption that parameter regularization techniques fail for rehearsal-free continual learning of a single, expanding task. Next, we explore how to leverage knowledge from a pre-trained model in rehearsal-free continual learning and find that vanilla L2 parameter regularization outperforms EWC parameter regularization and feature distillation. Finally, we explore the recently popular ImageNet-R benchmark, and show that L2 parameter regularization implemented in self-attention blocks of a ViT transformer outperforms recent popular prompting for continual learning methods.
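The L2 parameter regularization mentioned in the abstract above can be illustrated with a minimal sketch. It assumes the common formulation in which the new-task loss is augmented with a squared-distance penalty toward the parameters saved before training on the new task; the function names, the flat list-of-floats parameter layout, and the `strength` coefficient are hypothetical and not taken from the paper.

```python
def l2_regularized_loss(task_loss, params, anchor_params, strength):
    """Return the task loss plus an L2 penalty toward the anchor parameters.

    params        -- current parameters (flat list of floats, assumed layout)
    anchor_params -- parameters saved before the new task (same length)
    strength      -- penalty coefficient (hypothetical hyperparameter)
    """
    penalty = sum((p - a) ** 2 for p, a in zip(params, anchor_params))
    return task_loss + strength * penalty


def l2_penalty_gradient(params, anchor_params, strength):
    """Gradient of the penalty term alone with respect to each parameter."""
    return [2.0 * strength * (p - a) for p, a in zip(params, anchor_params)]
```

The penalty is zero when the parameters have not moved and grows quadratically with drift, so a larger `strength` trades plasticity on the new task for stability on old ones.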
Elephant Neural Networks: Born to Be a Continual Learner
Catastrophic forgetting has remained a significant challenge to continual learning for decades. While recent works have proposed effective methods to mitigate this problem, they mainly focus on the algorithmic side. Meanwhile, we do not fully understand what architectural properties of neural networks lead to catastrophic forgetting. This study aims to fill this gap by studying the role of activation functions in the training dynamics of neural networks and their impact on catastrophic forgetting. Our study reveals that, besides sparse representations, the gradient sparsity of activation functions also plays an important role in reducing forgetting. Based on this insight, we propose a new class of activation functions, elephant activation functions, that can generate both sparse representations and sparse gradients. We show that by simply replacing classical activation functions with elephant activation functions, we can significantly improve the resilience of neural networks to catastrophic forgetting. Our method has broad applicability and benefits for continual learning in regression, class incremental learning, and reinforcement learning tasks. Specifically, we achieve excellent performance on the Split MNIST dataset in a single pass, without using a replay buffer, task boundary information, or pre-training.
A Three-regime Model of Network Pruning
Recent work has highlighted the complex influence training hyperparameters, e.g., the number of training epochs, can have on the prunability of machine learning models. Perhaps surprisingly, a systematic approach to predict precisely how adjusting a specific hyperparameter will affect prunability remains elusive. To address this gap, we introduce a phenomenological model grounded in the statistical mechanics of learning. Our approach uses temperature-like and load-like parameters to model the impact of neural network (NN) training hyperparameters on pruning performance. A key empirical result we identify is a sharp transition phenomenon: depending on the value of a load-like parameter in the pruned model, increasing the value of a temperature-like parameter in the pre-pruned model may either enhance or impair subsequent pruning performance. Based on this transition, we build a three-regime model by taxonomizing the global structure of the pruned NN loss landscape. Our model reveals that the dichotomous effect of high temperature is associated with transitions between distinct types of global structures in the post-pruned model. Based on our results, we present three case-studies: 1) determining whether to increase or decrease a hyperparameter for improved pruning; 2) selecting the best model to prune from a family of models; and 3) tuning the hyperparameter of the Sharpness Aware Minimization method for better pruning performance.
Score Forgetting Distillation: A Swift, Data-Free Method for Machine Unlearning in Diffusion Models
The machine learning community is increasingly recognizing the importance of fostering trust and safety in modern generative AI (GenAI) models. We posit machine unlearning (MU) as a crucial foundation for developing safe, secure, and trustworthy GenAI models. Traditional MU methods often rely on stringent assumptions and require access to real data. This paper introduces Score Forgetting Distillation (SFD), an innovative MU approach that promotes the forgetting of undesirable information in diffusion models by aligning the conditional scores of "unsafe" classes or concepts with those of "safe" ones. To eliminate the need for real data, our SFD framework incorporates a score-based MU loss into the score distillation objective of a pretrained diffusion model. This serves as a regularization term that preserves desired generation capabilities while enabling the production of synthetic data through a one-step generator. Our experiments on pretrained label-conditional and text-to-image diffusion models demonstrate that our method effectively accelerates the forgetting of target classes or concepts during generation, while preserving the quality of other classes or concepts. This unlearned and distilled diffusion not only pioneers a novel concept in MU but also accelerates the generation speed of diffusion models. Our experiments and studies on a range of diffusion models and datasets confirm that our approach is generalizable, effective, and advantageous for MU in diffusion models. (Warning: This paper contains sexually explicit imagery, discussions of pornography, racially-charged terminology, and other content that some readers may find disturbing, distressing, and/or offensive.)
On the Markov Property of Neural Algorithmic Reasoning: Analyses and Methods
Neural algorithmic reasoning is an emerging research direction that endows neural networks with the ability to mimic algorithmic executions step-by-step. A common paradigm in existing designs involves the use of historical embeddings in predicting the results of future execution steps. Our observation in this work is that such historical dependence intrinsically contradicts the Markov nature of algorithmic reasoning tasks. Based on this motivation, we present our ForgetNet, which does not use historical embeddings and thus is consistent with the Markov nature of the tasks. To address challenges in training ForgetNet at early stages, we further introduce G-ForgetNet, which uses a gating mechanism to allow for the selective integration of historical embeddings. Such an enhanced capability provides valuable computational pathways during the model's early training phase. Our extensive experiments, based on the CLRS-30 algorithmic reasoning benchmark, demonstrate that both ForgetNet and G-ForgetNet achieve better generalization capability than existing methods. Furthermore, we investigate the behavior of the gating mechanism, highlighting its degree of alignment with our intuitions and its effectiveness for robust performance.
Maintaining Discrimination and Fairness in Class Incremental Learning
Deep neural networks (DNNs) have been applied in class incremental learning, which aims to solve common real-world problems of learning new classes continually. One drawback of standard DNNs is that they are prone to catastrophic forgetting. Knowledge distillation (KD) is a commonly used technique to alleviate this problem. In this paper, we demonstrate it can indeed help the model to output more discriminative results within old classes. However, it cannot alleviate the problem that the model tends to classify objects into new classes, causing the positive effect of KD to be hidden and limited. We observed that an important factor causing catastrophic forgetting is that the weights in the last fully connected (FC) layer are highly biased in class incremental learning. In this paper, we propose a simple and effective solution motivated by the aforementioned observations to address catastrophic forgetting. First, we utilize KD to maintain the discrimination within old classes. Then, to further maintain the fairness between old classes and new classes, we propose Weight Aligning (WA), which corrects the biased weights in the FC layer after the normal training process. Unlike previous work, WA does not require any extra parameters or a validation set in advance, as it utilizes the information provided by the biased weights themselves. The proposed method is evaluated on ImageNet-1000, ImageNet-100, and CIFAR-100 under various settings. Experimental results show that the proposed method can effectively alleviate catastrophic forgetting and significantly outperform state-of-the-art methods.
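One way to picture the weight correction described above is the following sketch. The abstract does not spell out the rule, so this assumes a norm-matching form: new-class weight vectors in the FC layer are rescaled so that their mean norm equals the mean norm of the old-class weight vectors. The function name and the list-of-rows weight layout are hypothetical.

```python
import math


def weight_align(old_class_weights, new_class_weights):
    """Rescale new-class FC weight rows to match the old-class mean norm.

    Each argument is a list of weight vectors, one row per class (assumed
    layout). Returns the rescaled new-class rows and the scaling factor.
    """
    def mean_norm(rows):
        return sum(math.sqrt(sum(w * w for w in row)) for row in rows) / len(rows)

    # Ratio below 1 means new-class weights were larger on average.
    gamma = mean_norm(old_class_weights) / mean_norm(new_class_weights)
    aligned = [[gamma * w for w in row] for row in new_class_weights]
    return aligned, gamma
```

Because the factor is computed from the weights themselves, this step needs no extra parameters or held-out data, which matches the property the abstract emphasizes.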
Hallucination Detox: Sensitive Neuron Dropout (SeND) for Large Language Model Training
As large language models (LLMs) become increasingly deployed across various industries, concerns regarding their reliability have grown, particularly due to hallucinations (outputs that are factually inaccurate or irrelevant to user input). Our research investigates the relationship between the training process and the emergence of hallucinations to address a key gap in existing research that focuses primarily on post hoc detection and mitigation strategies. Using models from the Pythia suite (70M-12B parameters) and several hallucination detection metrics, we analyze hallucination trends throughout training and explore LLM internal dynamics. We introduce SEnsitive Neuron Dropout (SeND), a novel training protocol designed to mitigate hallucinations by reducing variance during training. SeND achieves this by deterministically dropping neurons with significant variability on a dataset, referred to as Sensitive Neurons. In addition, we develop an unsupervised hallucination detection metric, Efficient EigenScore (EES), which approximates the traditional EigenScore at 2x speed. This efficient metric is integrated into our protocol, allowing SeND to be both computationally scalable and effective at reducing hallucinations. Our empirical evaluation demonstrates that our approach improves LLM reliability at test time by up to 40% compared to normal training while also providing an efficient method to improve factual accuracy when adapting LLMs to domains such as Wikipedia and Medical datasets.
Catastrophic Interference is Mitigated in Naturalistic Power-Law Learning Environments
Neural networks often suffer from catastrophic interference (CI): performance on previously learned tasks drops off significantly when learning a new task. This contrasts strongly with humans, who can sequentially learn new tasks without appreciably forgetting previous tasks. Prior work has explored various techniques for mitigating CI such as regularization, rehearsal, generative replay, and distillation methods. The current work takes a different approach, one guided by cognitive science research showing that in naturalistic environments, the probability of encountering a task decreases as a power-law of the time since it was last performed. We argue that a realistic evaluation of techniques for the mitigation of CI should be performed in simulated naturalistic learning environments. Thus, we evaluate the extent of mitigation of CI when training simple rehearsal-based methods in power-law environments similar to the ones humans face. Our work explores this novel rehearsal-based approach for a domain-incremental task: learning permutations in the MNIST task. We compare our rehearsal environment with other baselines to show its efficacy in promoting continual learning. Additionally, we investigate whether this environment shows forward facilitation, i.e., faster learning of later tasks. Next, we explore the robustness of our learning environment to the number of tasks, model size, and amount of data rehearsed after each task. Notably, our results show that the performance is comparable or superior to that of models trained using popular regularization methods and also to rehearsals in non-power-law environments. The benefits of this training paradigm include simplicity and the lack of a need for extra neural circuitry. In addition, because our method is orthogonal to other methods, future research can combine training in power-law environments with other continual learning mechanisms.
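The power-law environment described above can be sketched as a rehearsal sampler. This is an illustration under stated assumptions, not the paper's procedure: each past task is weighted by its elapsed time raised to a negative exponent, and the weights are normalized into probabilities. The function names, the `exponent` default, and the dictionary-based bookkeeping are hypothetical.

```python
import random


def power_law_rehearsal_probs(last_seen, now, exponent=1.0):
    """Probability of rehearsing each task, decaying as a power law of the
    time elapsed since the task was last performed.

    last_seen -- dict mapping task id to the step at which it was last seen
    now       -- current step (must be later than every entry in last_seen)
    """
    weights = {}
    for task, step in last_seen.items():
        elapsed = now - step
        if elapsed <= 0:
            raise ValueError("every task must have been seen before 'now'")
        weights[task] = elapsed ** (-exponent)
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}


def sample_rehearsal_task(probs, rng):
    """Draw one task id according to the given probabilities."""
    tasks = list(probs)
    return rng.choices(tasks, weights=[probs[t] for t in tasks], k=1)[0]
```

For example, `sample_rehearsal_task(power_law_rehearsal_probs({"a": 9, "b": 6}, now=10), random.Random(0))` favors the recently seen task `"a"` while still revisiting `"b"` occasionally.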
Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective
Training language models currently requires pre-determining a fixed compute budget because the typical cosine learning rate schedule depends on the total number of steps. In contrast, the Warmup-Stable-Decay (WSD) schedule uses a constant learning rate to produce a main branch of iterates that can in principle continue indefinitely without a pre-specified compute budget. Then, given any compute budget, one can branch out from the main branch at a proper time with a rapidly decaying learning rate to produce a strong model. Empirically, WSD generates a non-traditional loss curve: the loss remains elevated during the stable phase but sharply declines during the decay phase. Towards explaining this phenomenon, we conjecture that pretraining loss exhibits a river valley landscape, which resembles a deep valley with a river at its bottom. Under this assumption, we show that during the stable phase, the iterate undergoes large oscillations due to the high learning rate, yet it progresses swiftly along the river. During the decay phase, the rapidly dropping learning rate minimizes the iterate's oscillations, moving it closer to the river and revealing true optimization progress. Therefore, the sustained high learning rate phase and fast decaying phase are responsible for progress in the river and the mountain directions respectively, and are both critical. Our analysis predicts phenomena consistent with empirical observations and shows that this landscape can emerge from pretraining on a simple bi-gram dataset. Inspired by the theory, we introduce WSD-S, a variant of WSD that reuses previous checkpoints' decay phases and keeps only one main branch, where we resume from a decayed checkpoint. WSD-S empirically outperforms WSD and Cyclic-Cosine in obtaining multiple language model checkpoints across various compute budgets in a single run for parameters scaling from 0.1B to 1.2B.
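The three phases of the WSD schedule described above can be sketched as a single function of the step index. This is a minimal illustration: the linear warmup and linear decay shapes, the function name, and all parameter names are assumptions, since the abstract specifies only a constant stable phase followed by a rapid decay.

```python
def wsd_learning_rate(step, peak_lr, warmup_steps, decay_start, decay_steps,
                      min_lr=0.0):
    """Learning rate at a given step under a Warmup-Stable-Decay schedule.

    warmup_steps -- number of linear warmup steps (assumed shape)
    decay_start  -- step at which a branch leaves the stable phase
    decay_steps  -- length of the decay phase (linear decay assumed)
    """
    if step < warmup_steps:
        # Warmup: ramp linearly up to the peak learning rate.
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        # Stable: constant rate, independent of any total step budget.
        return peak_lr
    # Decay: drop linearly from peak_lr to min_lr, then hold at min_lr.
    progress = min(1.0, (step - decay_start) / decay_steps)
    return peak_lr + (min_lr - peak_lr) * progress
```

Unlike a cosine schedule, the stable phase here never references a total step count, so `decay_start` can be chosen after training has begun, once the compute budget is known.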
ZeroFlow: Overcoming Catastrophic Forgetting is Easier than You Think
Backpropagation provides a generalized configuration for overcoming catastrophic forgetting. For example, SGD and Adam are commonly used for weight updates in continual learning and continual pre-training. In practice, however, access to gradient information is not always granted (the gradient ban), as with black-box APIs, hardware limitations, and non-differentiable systems. To bridge this gap, we introduce the first benchmark, ZeroFlow, to evaluate gradient-free optimization algorithms for overcoming forgetting. This benchmark examines a suite of forward pass methods across multiple methods, forgetting scenarios, and datasets. We find that forward passes alone are enough to overcome forgetting. Our findings reveal new optimization principles that highlight the potential of forward-pass methods in mitigating forgetting, managing task conflicts, and reducing memory demands, alongside novel enhancements that further mitigate forgetting with just one forward pass. This work provides essential insights and tools for advancing forward pass methods to overcome forgetting.
Reduce Catastrophic Forgetting of Dense Retrieval Training with Teleportation Negatives
In this paper, we investigate the instability in the standard dense retrieval training, which iterates between model training and hard negative selection using the being-trained model. We show the catastrophic forgetting phenomena behind the training instability, where models learn and forget different negative groups during training iterations. We then propose ANCE-Tele, which accumulates momentum negatives from past iterations and approximates future iterations using lookahead negatives, as "teleportations" along the time axis to smooth the learning process. On web search and OpenQA, ANCE-Tele outperforms previous state-of-the-art systems of similar size, eliminates the dependency on sparse retrieval negatives, and is competitive among systems using significantly more (50x) parameters. Our analysis demonstrates that teleportation negatives reduce catastrophic forgetting and improve convergence speed for dense retrieval training. Our code is available at https://github.com/OpenMatch/ANCE-Tele.
DUCK: Distance-based Unlearning via Centroid Kinematics
Machine Unlearning is rising as a new field, driven by the pressing necessity of ensuring privacy in modern artificial intelligence models. This technique primarily aims to eradicate any residual influence of a specific subset of data from the knowledge acquired by a neural model during its training. This work introduces a novel unlearning algorithm, denoted as Distance-based Unlearning via Centroid Kinematics (DUCK), which employs metric learning to guide the removal of samples matching the nearest incorrect centroid in the embedding space. Evaluation of the algorithm's performance is conducted across various benchmark datasets in two distinct scenarios, class removal, and homogeneous sampling removal, obtaining state-of-the-art performance. We also introduce a novel metric, called Adaptive Unlearning Score (AUS), encompassing not only the efficacy of the unlearning process in forgetting target data but also quantifying the performance loss relative to the original model. Additionally, we conducted a thorough investigation of the unlearning mechanism in DUCK, examining its impact on the organization of the feature space and employing explainable AI techniques for deeper insights.
Prototype-Sample Relation Distillation: Towards Replay-Free Continual Learning
In continual learning (CL), balancing effective adaptation while combating catastrophic forgetting is a central challenge. Many of the recent best-performing methods utilize various forms of prior task data, e.g. a replay buffer, to tackle the catastrophic forgetting problem. Having access to previous task data can be restrictive in many real-world scenarios, for example when task data is sensitive or proprietary. To overcome the necessity of using previous tasks' data, in this work, we start with strong representation learning methods that have been shown to be less prone to forgetting. We propose a holistic approach to jointly learn the representation and class prototypes while maintaining the relevance of old class prototypes and their embedded similarities. Specifically, samples are mapped to an embedding space where the representations are learned using a supervised contrastive loss. Class prototypes are evolved continually in the same latent space, enabling learning and prediction at any point. To continually adapt the prototypes without keeping any prior task data, we propose a novel distillation loss that constrains class prototypes to maintain relative similarities as compared to new task data. This method yields state-of-the-art performance in the task-incremental setting, outperforming methods relying on large amounts of data, and provides strong performance in the class-incremental setting without using any stored data points.
Continual Learning with Pretrained Backbones by Tuning in the Input Space
The intrinsic difficulty in adapting deep learning models to non-stationary environments limits the applicability of neural networks to real-world tasks. This issue is critical in practical supervised learning settings, such as the ones in which a pre-trained model computes projections toward a latent space where different task predictors are sequentially learned over time. In fact, incrementally fine-tuning the whole model to better adapt to new tasks usually results in catastrophic forgetting, with decreasing performance on past experiences and loss of valuable knowledge from the pre-training stage. In this paper, we propose a novel strategy to make the fine-tuning procedure more effective, by not updating the pre-trained part of the network and learning not only the usual classification head, but also a set of newly introduced learnable parameters that are responsible for transforming the input data. This process allows the network to effectively leverage the pre-training knowledge and find a good trade-off between plasticity and stability with modest computational effort, making it especially suitable for on-the-edge settings. Our experiments on four image classification problems in a continual learning setting confirm the quality of the proposed approach when compared to several fine-tuning procedures and to popular continual learning methods.
Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking
Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and after training for sufficiently longer, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.
Unforgettable Generalization in Language Models
When language models (LMs) are trained to forget (or "unlearn") a skill, how precisely does their behavior change? We study the behavior of transformer LMs in which tasks have been forgotten via fine-tuning on randomized labels. Such LMs learn to generate near-random predictions for individual examples in the "training" set used for forgetting. Across tasks, however, LMs exhibit extreme variability in whether LM predictions change on examples outside the training set. In some tasks (like entailment classification), forgetting generalizes robustly, and causes models to produce uninformative predictions on new task instances; in other tasks (like physical commonsense reasoning and scientific question answering) forgetting affects only the training examples, and models continue to perform the "forgotten" task accurately even for examples very similar to those that appeared in the training set. Dataset difficulty is not predictive of whether a behavior can be forgotten; instead, generalization in forgetting is (weakly) predicted by the confidence of LMs' initial task predictions and the variability of LM representations of training data, with low confidence and low variability both associated with greater generalization. Perhaps most surprisingly, random-label forgetting appears to be somewhat insensitive to the contents of the training set: for example, models trained on science questions with random labels continue to answer other science questions accurately, but begin to produce random labels on entailment classification tasks. Finally, we show that even generalizable forgetting is shallow: linear probes trained on LMs' representations can still perform tasks reliably after forgetting. Our results highlight the difficulty and unpredictability of performing targeted skill removal from models via fine-tuning.
Learning to Prompt for Continual Learning
The mainstream paradigm behind continual learning has been to adapt the model parameters to non-stationary data distributions, where catastrophic forgetting is the central challenge. Typical methods rely on a rehearsal buffer or known task identity at test time to retrieve learned knowledge and address forgetting, while this work presents a new paradigm for continual learning that aims to train a more succinct memory system without accessing task identity at test time. Our method learns to dynamically prompt (L2P) a pre-trained model to learn tasks sequentially under different task transitions. In our proposed framework, prompts are small learnable parameters, which are maintained in a memory space. The objective is to optimize prompts to instruct the model prediction and explicitly manage task-invariant and task-specific knowledge while maintaining model plasticity. We conduct comprehensive experiments under popular image classification benchmarks with different challenging continual learning settings, where L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer and is directly applicable to challenging task-agnostic continual learning. Source code is available at https://github.com/google-research/l2p.
Where to find Grokking in LLM Pretraining? Monitor Memorization-to-Generalization without Test
Grokking, i.e., the phenomenon in which test performance keeps improving long after the training loss has converged, has recently been observed in neural network training, making the mechanism of generalization and other emerging capabilities such as reasoning mysterious. While prior studies usually train small models on a few toy or highly specific tasks for thousands of epochs, we conduct the first study of grokking on checkpoints during one-pass pretraining of a 7B large language model (LLM), namely OLMoE. We compute the training loss and evaluate generalization on diverse benchmark tasks, including math reasoning, code generation, and commonsense/domain-specific knowledge retrieval tasks. Our study, for the first time, verifies that grokking still happens in the pretraining of large-scale foundation models, though different data may enter grokking stages asynchronously. We further demystify grokking's "emergence of generalization" by investigating LLM internal dynamics. Specifically, we find that training samples' pathways (i.e., expert choices across layers) evolve from random and instance-specific to more structured and shareable across samples during grokking. Also, the complexity of a sample's pathway decreases even though the loss has converged. These findings indicate a memorization-to-generalization conversion, providing a mechanistic explanation of delayed generalization. In the study, we develop two novel metrics to quantify pathway distance and the complexity of a single pathway. We show their ability to predict the generalization improvement on diverse downstream tasks. They are efficient, simple to compute, and solely dependent on training data. Hence, they have practical value for pretraining, enabling us to monitor the generalization performance without fine-tuning or testing. Theoretically, we show that more structured pathways reduce model complexity and improve the generalization bound.
Learning Continually by Spectral Regularization
Loss of plasticity is a phenomenon where neural networks become more difficult to train during the course of learning. Continual learning algorithms seek to mitigate this effect by sustaining good predictive performance while maintaining network trainability. We develop new techniques for improving continual learning by first reconsidering how initialization can ensure trainability during early phases of learning. From this perspective, we derive new regularization strategies for continual learning that ensure beneficial initialization properties are better maintained throughout training. In particular, we investigate two new regularization techniques for continual learning: (i) Wasserstein regularization toward the initial weight distribution, which is less restrictive than regularizing toward initial weights; and (ii) regularizing weight matrix singular values, which directly ensures gradient diversity is maintained throughout training. We present an experimental analysis that shows these alternative regularizers can improve continual learning performance across a range of supervised learning tasks and model architectures. The alternative regularizers prove to be less sensitive to hyperparameters while demonstrating better training in individual tasks, sustaining trainability as new tasks arrive, and achieving better generalization performance.
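To make "regularizing weight matrix singular values" concrete, here is a minimal sketch of one possible penalty: the squared deviation of a weight matrix's largest singular value from a target value. The abstract does not state the penalty's exact form, so the squared-deviation shape, the target of 1.0, and the names `largest_singular_value` and `spectral_penalty` are all assumptions made for illustration. The largest singular value is estimated by power iteration on W^T W, using only the standard library.

```python
import math
import random


def largest_singular_value(W, iters=200, seed=0):
    """Estimate the largest singular value of W (a list of rows)."""
    rng = random.Random(seed)
    n_rows, n_cols = len(W), len(W[0])
    # Random start vector, so it is unlikely to be orthogonal to the
    # top right-singular vector.
    v = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    for _ in range(iters):
        u = [sum(W[i][j] * v[j] for j in range(n_cols)) for i in range(n_rows)]  # u = W v
        v_new = [sum(W[i][j] * u[i] for i in range(n_rows)) for j in range(n_cols)]  # W^T u
        norm = math.sqrt(sum(x * x for x in v_new))
        if norm == 0.0:
            return 0.0  # W v vanished, e.g. W is the zero matrix
        v = [x / norm for x in v_new]
    u = [sum(W[i][j] * v[j] for j in range(n_cols)) for i in range(n_rows)]
    return math.sqrt(sum(x * x for x in u))  # ||W v|| for unit-norm v


def spectral_penalty(W, target=1.0):
    """Assumed penalty form: squared gap between sigma_max(W) and `target`."""
    return (largest_singular_value(W) - target) ** 2


penalty_scaled = spectral_penalty([[3.0, 0.0], [0.0, 1.0]])    # sigma_max = 3
penalty_identity = spectral_penalty([[1.0, 0.0], [0.0, 1.0]])  # sigma_max = 1
```

In a training loop, a term like this would be summed over layers and added to the task loss with a small coefficient; a deep learning framework's autodiff would replace the hand-written power iteration.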
Towards General Purpose Medical AI: Continual Learning Medical Foundation Model
Inevitable domain and task discrepancies in real-world scenarios can impair the generalization performance of pre-trained deep models for medical data. Therefore, we audaciously propose that we should build a general-purpose medical AI system that can be seamlessly adapted to downstream domains/tasks. Since domain/task adaptation procedures usually involve additional labeling work for the target data, a data-efficient adaptation algorithm is desirable to reduce the cost of transferring the learned knowledge. Our recent work found that vision-language models (VLMs) are efficient learners with extraordinary cross-domain ability. Therefore, in this work, we further explore the possibility of leveraging pre-trained VLMs as medical foundation models for building general-purpose medical AI. We thoroughly investigate three machine-learning paradigms, i.e., domain/task-specialized learning, joint learning, and continual learning, for training the VLMs, and evaluate their generalization performance on cross-domain and cross-task test sets. To alleviate catastrophic forgetting during sequential training, we employ rehearsal learning and obtain a sharp boost in generalization capability. In a nutshell, our empirical evidence suggests that continual learning may be a practical and efficient learning paradigm for the medical foundation model. We hope researchers can use our empirical evidence as a basis to further explore the path toward medical foundation models.
When, Why and How Much? Adaptive Learning Rate Scheduling by Refinement
Learning rate schedules used in practice bear little resemblance to those recommended by theory. We close much of this theory/practice gap, and as a consequence are able to derive new problem-adaptive learning rate schedules. Our key technical contribution is a refined analysis of learning rate schedules for a wide class of optimization algorithms (including SGD). In contrast to most prior works that study the convergence of the average iterate, we study the last iterate, which is what most people use in practice. When considering only worst-case analysis, our theory predicts that the best choice is the linear decay schedule: a popular choice in practice that sets the stepsize proportionally to 1 - t/T, where t is the current iteration and T is the total number of steps. To go beyond this worst-case analysis, we use the observed gradient norms to derive schedules refined for any particular task. These refined schedules exhibit learning rate warm-up and rapid learning rate annealing near the end of training. Ours is the first systematic approach to automatically yield both of these properties. We perform the most comprehensive evaluation of learning rate schedules to date, evaluating across 10 diverse deep learning problems, a series of LLMs, and a suite of logistic regression problems. We validate that overall, the linear-decay schedule matches or outperforms all commonly used default schedules including cosine annealing, and that our schedule refinement method gives further improvements.
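The linear decay schedule mentioned above sets the step size proportionally to 1 - t/T. A minimal sketch follows, assuming a base learning rate `base_lr` as the proportionality constant; the function name `linear_decay_lr` and its argument checks are illustrative and not taken from the paper. The refined, gradient-norm-based schedules are not reproduced here.

```python
def linear_decay_lr(base_lr, t, T):
    """Learning rate at iteration t of T under linear decay.

    Implements base_lr * (1 - t / T), where t is the current iteration
    and T is the total number of steps.
    """
    if T <= 0:
        raise ValueError("T must be positive")
    if not 0 <= t <= T:
        raise ValueError("t must satisfy 0 <= t <= T")
    return base_lr * (1.0 - t / T)


# Schedule for a hypothetical 100-step run with base learning rate 0.1.
schedule = [linear_decay_lr(0.1, t, 100) for t in range(101)]
```

The rate starts at `base_lr` when t = 0 and reaches exactly zero at t = T, so the final update in a run of T steps (t = T - 1) still uses a small positive step size.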
Think Before You Act: Decision Transformers with Internal Working Memory
Large language model (LLM)-based decision-making agents have shown the ability to generalize across multiple tasks. However, their performance relies on massive data and compute. We argue that this inefficiency stems from the forgetting phenomenon, in which a model memorizes its behaviors in parameters throughout training. As a result, training on a new task may deteriorate the model's performance on previous tasks. In contrast to LLMs' implicit memory mechanism, the human brain utilizes distributed memory storage, which helps manage and organize multiple skills efficiently, mitigating the forgetting phenomenon. Thus inspired, we propose an internal working memory module to store, blend, and retrieve information for different downstream tasks. Evaluation results show that the proposed method improves training efficiency and generalization in both Atari games and meta-world object manipulation tasks. Moreover, we demonstrate that memory fine-tuning further enhances the adaptability of the proposed architecture.