How NVIDIA Pruned and Distilled Llama 3.1 to Create Minitron 4B and 8B
- Rifx.Online
- 10 Nov, 2024
The new models use state-of-the-art pruning and distillation techniques.
We are regularly dazzled by the advancements in large language models (LLMs), particularly those with a massive number of parameters. However, running inference on models with 70B+ parameters is cost-prohibitive for most organizations. As a result, smaller language models (SLMs), which make inference workloads far more cost-effective, have been gaining influence. Pretraining SLMs from scratch is not always possible, however, because of major challenges in data collection, pretraining pipelines, and more. A popular alternative has been to start with a larger LLM and compress it into a smaller model, and pruning and distillation are two of the most popular techniques for doing so. Recently, NVIDIA released two models, Minitron-8B and Minitron-4B, built by pruning and distilling larger models, with Llama 3.1 8B serving as the starting point for the 4B model.
Minitron focuses on reducing the size of AI models through pruning and distillation, making them more efficient without sacrificing too much accuracy. Pruning reduces a model’s size by either cutting layers (depth pruning) or removing neurons, attention heads, or embedding channels (width pruning). To recover some lost accuracy, retraining is often necessary after pruning.
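To make the two pruning axes concrete, here is a rough parameter count for a Llama-style decoder. This is a minimal sketch under stated assumptions. The formula ignores normalization weights and counts untied input and output embeddings. The base configuration mirrors the published Llama 3.1 8B shape, and the two pruned configurations are hypothetical targets chosen for illustration.

```python
# Simplified parameter count for a Llama-style decoder.
# Assumptions: norm weights and biases are ignored; input embedding and
# output head are untied (counted separately).
def decoder_params(n_layers, d_model, d_ff, n_heads, n_kv_heads, head_dim, vocab):
    q_dim = n_heads * head_dim            # width of query / attention output
    kv_dim = n_kv_heads * head_dim        # width of keys / values (grouped-query attention)
    attn = 2 * d_model * q_dim + 2 * d_model * kv_dim   # Wq, Wo, Wk, Wv
    mlp = 3 * d_model * d_ff              # gate, up and down projections
    embed = 2 * vocab * d_model           # input embedding + output head
    return embed + n_layers * (attn + mlp)

# Illustrative Llama-3.1-8B-like configuration (assumed values).
base = dict(n_layers=32, d_model=4096, d_ff=14336, n_heads=32,
            n_kv_heads=8, head_dim=128, vocab=128256)

# Hypothetical pruning targets.
depth_pruned = {**base, "n_layers": 16}                  # drop half the layers
width_pruned = {**base, "d_model": 3072, "d_ff": 9216}   # shrink widths, keep all layers

for name, cfg in [("base", base), ("depth", depth_pruned), ("width", width_pruned)]:
    print(f"{name:>5}: {decoder_params(**cfg) / 1e9:.2f}B")
```

The script prints about 8.03B for the base configuration. Dropping half the layers gives 4.54B, and shrinking the hidden and MLP widths gives 4.51B, so either route lands in the 4B class.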
Distillation is a related technique where a smaller model, known as the student, learns from a larger, complex model called the teacher. The goal is to create a more compact model that retains much of the predictive capability of the larger one, while being faster and less demanding on resources.
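A common way to express the student's learning signal is the KL divergence between the teacher's and the student's temperature-softened output distributions. The following is a generic plain-Python illustration of that loss. The temperature value and the example logits are assumptions, not settings taken from Minitron.

```python
import math

def softmax(logits, temperature=1.0):
    # A temperature above 1 softens the distribution, exposing the teacher's
    # relative preferences among non-top tokens.
    scaled = [z / temperature for z in logits]
    m = max(scaled)                          # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)  # teacher "soft targets"
    q = softmax(student_logits, temperature)  # student predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2              # common T^2 scaling convention

teacher = [4.0, 2.5, 0.5, -1.0]
close_student = [3.8, 2.4, 0.7, -0.9]
far_student = [0.0, 3.0, 1.0, 2.0]
print(distillation_loss(close_student, teacher))  # small
print(distillation_loss(far_student, teacher))    # larger
```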
Approaches to Distillation: Classical vs. SDG Fine-tuning
Minitron identifies two key styles of distillation. One is synthetic data generation (SDG) fine-tuning, in which a smaller, pretrained student model is refined on data generated by a larger teacher model. In this method, the student mimics only the final token predicted by the teacher. This is the approach found in many popular tutorials and AI platforms.
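Here is a minimal sketch of the SDG-style training signal. It assumes a simplified setting where the teacher's "generated data" is just its top-scoring token at each position, whereas real pipelines decode full sequences. The helper names `sdg_targets` and `sdg_loss` are hypothetical. The point is that the student only ever sees token IDs and is trained with ordinary cross-entropy.

```python
import math

def log_softmax(logits):
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def sdg_targets(teacher_logits_seq):
    # Hypothetical helper: the teacher "generates" data by keeping only
    # its top token at each position.
    return [max(range(len(z)), key=z.__getitem__) for z in teacher_logits_seq]

def sdg_loss(student_logits_seq, target_ids):
    # Ordinary next-token cross-entropy against teacher-generated tokens.
    losses = [-log_softmax(z)[t] for z, t in zip(student_logits_seq, target_ids)]
    return sum(losses) / len(losses)

teacher_seq = [[2.0, 0.5, 0.1], [0.0, 3.0, 1.0]]
targets = sdg_targets(teacher_seq)  # [0, 1]
agreeing_student = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
disagreeing_student = [[0.0, 3.0, 0.0], [3.0, 0.0, 0.0]]
print(targets, sdg_loss(agreeing_student, targets), sdg_loss(disagreeing_student, targets))
```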
The other approach, classical knowledge distillation, is more involved. Instead of focusing solely on the predicted token, the student model tries to replicate the teacher's output logits and various internal states. This technique provides more detailed feedback during training, resulting in better accuracy. However, implementing it requires specific support in the training framework, because it involves handling large volumes of data from the teacher's internal states.
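To see why matching the teacher's full distribution gives richer feedback, compare the gradient that each kind of target produces on the student's logits. This sketch covers only the logit part of classical distillation, not the intermediate-state terms. The example logits are made up: the teacher rates tokens 0 and 1 as almost equally likely.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_wrt_logits(student_logits, target_probs):
    # Gradient of cross-entropy(target, softmax(z)) with respect to z
    # is softmax(z) - target.
    q = softmax(student_logits)
    return [qi - ti for qi, ti in zip(q, target_probs)]

teacher_logits = [2.0, 1.9, -3.0]  # teacher finds tokens 0 and 1 almost equally likely
student_logits = [0.0, 0.0, 0.0]   # untrained student: uniform prediction

top = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
hard_target = [1.0 if i == top else 0.0 for i in range(len(teacher_logits))]  # SDG-style
soft_target = softmax(teacher_logits)                                       # classical KD

hard_grad = grad_wrt_logits(student_logits, hard_target)
soft_grad = grad_wrt_logits(student_logits, soft_target)
# Under gradient descent, a positive gradient pushes that logit down.
print("hard:", [round(g, 3) for g in hard_grad])
print("soft:", [round(g, 3) for g in soft_grad])
```

With the hard target, the gradient on token 1 is positive. Gradient descent therefore pushes that token down, even though the teacher considers it nearly as good as token 0. With the soft target, the same gradient is negative, and the student is nudged toward the teacher's actual preference.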
These two methods aren’t mutually exclusive but can complement each other. Minitron’s main emphasis is on the classical knowledge distillation approach.
Pruning and Distillation Workflow
To create more efficient models, Minitron combines pruning with classical knowledge distillation. Starting with a larger model, such as a 15B parameter model, Minitron evaluates the importance of different components — layers, neurons, and more — then reduces the model to a smaller size, like an 8B model. The smaller model undergoes a light retraining process where it learns from the original, larger model. This process can be repeated to further reduce the model size, eventually producing even smaller versions, such as a 4B model.
The pruning and distillation process is iterative, with each smaller model serving as the basis for the next round of compression and retraining.
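The iterative structure can be sketched as a simple loop. Everything here is a placeholder. `Model`, `prune`, and `distill` are hypothetical stand-ins that only track sizes and which model acted as teacher. The sketch also assumes that each round's teacher is the model it was pruned from, which follows the iterative description rather than a confirmed detail of NVIDIA's setup.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Model:
    name: str
    params_b: float                # size in billions of parameters
    teacher: Optional[str] = None  # model that supervised its retraining

def prune(model: Model, target_b: float) -> Model:
    # Placeholder: a real implementation would rank layers, neurons, heads and
    # embedding channels by importance and remove the least important ones.
    return Model(name=f"pruned-{target_b:g}B", params_b=target_b)

def distill(student: Model, teacher: Model) -> Model:
    # Placeholder for light retraining with a distillation loss.
    return Model(student.name, student.params_b, teacher=teacher.name)

def compress(model: Model, targets_b: List[float]) -> List[Model]:
    """Iteratively prune and distill; each result seeds the next round."""
    chain = [model]
    for target in targets_b:
        parent = chain[-1]
        child = distill(prune(parent, target), teacher=parent)
        chain.append(child)
    return chain

chain = compress(Model("base-15B", 15.0), [8.0, 4.0])
for m in chain:
    print(m)
```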
Pruning Impact
Pruning a model effectively requires understanding which parts of it are essential. Minitron uses an approach based on activation data to estimate the importance of various components — layers, neurons, attention heads, and embedding channels — using a small dataset. This method only requires forward propagation, making it simpler and more cost-effective than techniques that rely on backward propagation and gradient calculations.
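Here is a sketch of forward-only importance estimation for the neurons of a single layer. It assumes a toy linear-plus-ReLU layer and uses mean absolute activation as the aggregation function. Minitron also scores heads, embedding channels, and layers, and its exact aggregation is not reproduced here. No gradients are computed anywhere.

```python
def forward_hidden(weights, x):
    # One linear layer + ReLU: a forward pass is all that is needed.
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in weights]

def neuron_importance(weights, calibration_inputs):
    # Aggregate |activation| per neuron over a small calibration set.
    # Mean absolute activation is one simple choice (an assumption here).
    totals = [0.0] * len(weights)
    for x in calibration_inputs:
        for j, a in enumerate(forward_hidden(weights, x)):
            totals[j] += abs(a)
    return [t / len(calibration_inputs) for t in totals]

def top_k_neurons(scores, k):
    # Indices of the k highest-scoring neurons, returned in original order.
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:k])

weights = [[1.0, 0.0], [0.0, 1.0], [0.1, 0.1], [-1.0, -1.0]]  # 4 neurons, 2 inputs
calibration = [[1.0, 2.0], [3.0, 1.0], [2.0, 2.0]]
scores = neuron_importance(weights, calibration)
keep = top_k_neurons(scores, k=2)
print(scores, keep)
```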
While it’s possible to alternate between pruning and importance estimation for different parts of the model, Minitron found that a single round of importance estimation was sufficient in most cases.
Retraining Using Classical Knowledge Distillation
After pruning, Minitron retrains the smaller model using classical knowledge distillation. This involves minimizing losses at several points in the model, including the embedding output, the logits, and the outputs of intermediate transformer blocks. The student model learns from the unpruned teacher model by comparing outputs at different layers.
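The multi-stage retraining objective can be sketched as a weighted sum of three terms: a logit KL term, and mean-squared-error terms on hidden states and embedding outputs. Several parts are hypothetical: the weights, the `layer_map` alignment between student and teacher layers, and the toy vectors. The sketch also assumes that student and teacher hidden vectors have the same width, as after depth pruning. A width-pruned student would need an extra projection, which is omitted.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def kl_div(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def combined_distillation_loss(student, teacher, layer_map,
                               w_logit=1.0, w_hidden=1.0, w_embed=1.0):
    """student/teacher: dicts with 'embed', 'hidden' (one vector per layer), 'logits'.
    layer_map: (student_layer, teacher_layer) pairs to align after depth pruning."""
    logit_term = kl_div(softmax(teacher["logits"]), softmax(student["logits"]))
    hidden_term = sum(mse(student["hidden"][s], teacher["hidden"][t])
                      for s, t in layer_map) / len(layer_map)
    embed_term = mse(student["embed"], teacher["embed"])
    return w_logit * logit_term + w_hidden * hidden_term + w_embed * embed_term

teacher = {
    "embed": [0.1, 0.2],
    "hidden": [[0.5, -0.5], [1.0, 0.0], [0.2, 0.3], [0.9, -0.1]],  # 4 teacher layers
    "logits": [2.0, 0.5, -1.0],
}
student = {
    "embed": [0.1, 0.25],
    "hidden": [[0.4, -0.5], [0.8, -0.1]],  # 2 student layers after depth pruning
    "logits": [1.5, 0.7, -0.8],
}
layer_map = [(0, 1), (1, 3)]  # hypothetical alignment
print(combined_distillation_loss(student, teacher, layer_map))
```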
From extensive experimentation, Minitron has distilled several best practices for compressing language models:
· Model Sizing: Start by training the largest model, then gradually prune and distill it to create smaller versions.
· Pruning Strategy: Prefer width pruning over depth pruning, especially for models up to 15B parameters. Single-shot importance estimation is usually sufficient.
· Retraining: Retrain using distillation loss instead of conventional training. When pruning layers significantly, use a combination of losses from logits, intermediate states, and embeddings. For smaller reductions in depth, stick to logit-only distillation.
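The loss-selection guidance in the retraining bullet can be written as a small helper. The 25% cut-off for a "significant" depth reduction is a hypothetical number chosen for illustration, since the guidance itself does not give a threshold.

```python
def choose_distillation_losses(teacher_layers, student_layers, depth_threshold=0.25):
    # depth_threshold is a hypothetical cut-off for "significant" depth reduction.
    removed_fraction = 1 - student_layers / teacher_layers
    if removed_fraction >= depth_threshold:
        # Large depth cut: supervise logits, intermediate states and embeddings.
        return ["logits", "intermediate_states", "embeddings"]
    # Small or no depth cut: logit-only distillation.
    return ["logits"]

print(choose_distillation_losses(32, 16))
print(choose_distillation_losses(32, 30))
```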
Minitron applied these techniques to the Llama 3.1 model family, which includes models ranging from 405B to 8B parameters. Specifically, they focused on distilling the 8B model to a more efficient 4B version.
Fine-tuning the Teacher
Before pruning, Minitron fine-tuned the 8B model to account for shifts in the data distribution from the original training set. Without this step, the teacher model may not offer the best guidance to the student during distillation.
Depth Pruning
To reduce the 8B model to 4B, Minitron pruned 16 of its 32 layers, assessing their importance by removing them one by one and tracking the impact on performance. They found that layers at both the beginning and end of the model were most critical to maintaining accuracy. Based on this analysis, Minitron removed a specific set of layers for the final 4B model.
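Layer-by-layer ablation can be sketched with a toy residual network. Each layer here is a hypothetical scalar block. "Impact on performance" is approximated by how much the output changes when a layer is skipped; this is an assumption standing in for the validation loss or benchmark accuracy a real study would track.

```python
def make_layer(scale):
    # Toy residual block: returns an update proportional to its input.
    return lambda x: scale * x

def run(layers, x, skip=None):
    # Residual stream: x <- x + layer(x), optionally skipping one layer.
    for i, layer in enumerate(layers):
        if i != skip:
            x = x + layer(x)
    return x

def layer_importance(layers, calibration):
    # Importance of layer i = mean squared change in output when it is removed.
    full = [run(layers, x) for x in calibration]
    scores = []
    for i in range(len(layers)):
        ablated = [run(layers, x, skip=i) for x in calibration]
        scores.append(sum((a - f) ** 2 for a, f in zip(ablated, full)) / len(calibration))
    return scores

layers = [make_layer(s) for s in (0.5, 0.01, 0.2)]
scores = layer_importance(layers, calibration=[1.0, 2.0])
drop_order = sorted(range(len(layers)), key=scores.__getitem__)  # least important first
print(scores, drop_order)
```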
Width Pruning
In addition to depth pruning, Minitron also pruned along the width dimension, targeting attention heads, embedding channels, and MLP hidden dimensions. Retraining then recovered some of the performance lost in the initial pruning step. Interestingly, although width pruning initially led to higher loss than depth pruning, retraining allowed the width-pruned model to recover more effectively over time.
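Width pruning amounts to slicing weight matrices consistently. This sketch assumes a toy two-matrix ReLU MLP rather than Llama's gated MLP. It keeps selected hidden neurons by taking the matching rows of the input projection and columns of the output projection. Attention heads and embedding channels are pruned with the same kind of slicing along other axes.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def mlp_forward(w_in, w_out, x):
    # w_in: (hidden x d_model), w_out: (d_model x hidden)
    return matvec(w_out, relu(matvec(w_in, x)))

def prune_mlp(w_in, w_out, keep):
    # Keep selected hidden neurons: rows of w_in and the matching columns of w_out.
    new_in = [w_in[j] for j in keep]
    new_out = [[row[j] for j in keep] for row in w_out]
    return new_in, new_out

w_in = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3 hidden neurons, d_model = 2
w_out = [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0]]
x = [1.0, 2.0]
pruned_in, pruned_out = prune_mlp(w_in, w_out, keep=[0, 1])
print(mlp_forward(w_in, w_out, x), mlp_forward(pruned_in, pruned_out, x))
```

In this example the dropped neuron never fires for the given input, so the pruned MLP reproduces the original output exactly. In practice, pruned neurons do carry some signal, which is why the retraining step is needed.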
The Results
NVIDIA evaluated the Minitron models on several benchmarks, with results that matched the performance of the baseline models.
Minitron 4B and 8B showcase the potential of pruning and distillation for building smaller, more efficient models. The approach also has major challenges, but I think that, overall, it sets an important baseline for the industry.