How NVIDIA Pruned and Distilled Llama 3.1 to Create Minitron 4B and 8B

The new models use state-of-the-art pruning and distillation techniques.

I recently started an AI-focused educational newsletter that already has over 170,000 subscribers. TheSequence is a no-BS (meaning no hype, no news, etc.) ML-oriented newsletter that takes 5 minutes to read. The goal is to keep you up to date with machine learning projects, research papers, and concepts. Please give it a try by subscribing below:

We are regularly dazzled by advances in large language models (LLMs), particularly those with massive parameter counts. However, running 70B+ parameter models for inference is cost-prohibitive for most organizations. As a result, smaller language models (SLMs) have been gaining influence because they make inference workloads far more cost-effective. However, it is not always possible to pretrain SLMs from scratch, given major challenges around data collection, pretraining pipelines, and more. A popular alternative has been to start with a larger LLM and distill it into a smaller model. Pruning and distillation are two of the most popular techniques in this area. Recently, NVIDIA released two models, Minitron-8B and Minitron-4B, built by pruning and distilling Llama 3.1 models.

Minitron focuses on reducing the size of AI models through pruning and distillation, making them more efficient without sacrificing too much accuracy. Pruning reduces a model’s size by either cutting layers (depth pruning) or removing neurons, attention heads, or embedding channels (width pruning). To recover some lost accuracy, retraining is often necessary after pruning.
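To make the two pruning axes concrete, here is a minimal PyTorch sketch (not NVIDIA's implementation) that depth-prunes a toy stack of layers by dropping whole blocks and width-prunes it by keeping only the highest-magnitude output neurons. The layer count, widths, and magnitude-based ranking are illustrative assumptions.

```python
import torch
import torch.nn as nn

hidden = 64
blocks = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(8)])

# Depth pruning: keep a subset of whole layers (every other one, purely for illustration).
kept_layer_ids = [0, 2, 4, 6]
depth_pruned = nn.ModuleList([blocks[i] for i in kept_layer_ids])

# Width pruning: shrink a layer by keeping only its strongest output neurons.
def width_prune(layer: nn.Linear, keep: int) -> nn.Linear:
    # Rank output neurons by weight magnitude (a stand-in for activation-based importance).
    scores = layer.weight.abs().sum(dim=1)
    idx = scores.topk(keep).indices.sort().values
    new_layer = nn.Linear(layer.in_features, keep)
    with torch.no_grad():
        new_layer.weight.copy_(layer.weight[idx])
        new_layer.bias.copy_(layer.bias[idx])
    return new_layer

width_pruned = nn.ModuleList([width_prune(b, keep=32) for b in blocks])
print(len(depth_pruned), width_pruned[0].out_features)  # 4 layers kept vs. every layer 32 wide
```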

Distillation is a related technique where a smaller model, known as the student, learns from a larger, complex model called the teacher. The goal is to create a more compact model that retains much of the predictive capability of the larger one, while being faster and less demanding on resources.
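The simplest form of this idea is logit distillation, where the student matches the teacher's softened output distribution. Below is a minimal sketch, assuming a standard KL-divergence loss with an arbitrary temperature value.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions, then penalize the student's divergence from the teacher.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2

# Toy usage: batch of 4 examples over a 10-token vocabulary.
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
print(distillation_loss(student_logits, teacher_logits))
```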

Approaches to Distillation: Classical vs. SDG Fine-tuning

Minitron identifies two key styles of distillation. One approach is SDG (synthetic data generation) fine-tuning, where a smaller, pretrained student model is refined on data generated by a larger teacher model. In this method, the student mimics only the final token predicted by the teacher, as seen in some popular tutorials and AI platforms.
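A hedged sketch of what SDG fine-tuning can look like with Hugging Face transformers: the teacher generates completions, and the student is fine-tuned on those tokens with an ordinary language-modeling loss. The model paths, prompt set, and hyperparameters are placeholders, not the ones NVIDIA used.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

teacher_name = "path/to/large-teacher"   # placeholder
student_name = "path/to/small-student"   # placeholder

tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher = AutoModelForCausalLM.from_pretrained(teacher_name).eval()
student = AutoModelForCausalLM.from_pretrained(student_name)

prompts = ["Explain pruning in one sentence."]   # placeholder prompt set
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)

for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # The teacher produces the synthetic training text.
        generated = teacher.generate(**inputs, max_new_tokens=64)
    # The student is trained to reproduce the teacher's tokens (standard LM loss).
    outputs = student(input_ids=generated, labels=generated)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```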

The other approach, classical knowledge distillation, is more involved. Instead of focusing solely on the predicted token, the student model tries to replicate various internal states of the teacher model. This technique provides more detailed feedback during training, resulting in better accuracy. However, implementing this method requires specific support in the training framework, as it involves handling large data from the teacher’s internal states.

These two methods aren’t mutually exclusive but can complement each other. Minitron’s main emphasis is on the classical knowledge distillation approach.

Pruning and Distillation Workflow

To create more efficient models, Minitron combines pruning with classical knowledge distillation. Starting with a larger model, such as a 15B parameter model, Minitron evaluates the importance of different components — layers, neurons, and more — then reduces the model to a smaller size, like an 8B model. The smaller model undergoes a light retraining process where it learns from the original, larger model. This process can be repeated to further reduce the model size, eventually producing even smaller versions, such as a 4B model.

The pruning and distillation process is iterative, with each smaller model serving as the basis for the next round of compression and retraining.
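A toy, runnable sketch of that iterative loop: each round prunes the current model's width and briefly distills the pruned student against it. The tiny MLP, random data, and halving schedule are illustrative stand-ins for the real 15B to 8B to 4B pipeline.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_mlp(width):
    return nn.Sequential(nn.Linear(16, width), nn.ReLU(), nn.Linear(width, 8))

def prune_width(model, new_width):
    # Keep the first `new_width` hidden units (a stand-in for importance-based ranking).
    pruned = make_mlp(new_width)
    with torch.no_grad():
        pruned[0].weight.copy_(model[0].weight[:new_width])
        pruned[0].bias.copy_(model[0].bias[:new_width])
        pruned[2].weight.copy_(model[2].weight[:, :new_width])
        pruned[2].bias.copy_(model[2].bias)
    return pruned

def distill(student, teacher, steps=50):
    opt = torch.optim.Adam(student.parameters(), lr=1e-3)
    for _ in range(steps):
        x = torch.randn(32, 16)
        with torch.no_grad():
            target = teacher(x)
        loss = F.mse_loss(student(x), target)   # simple output-matching loss
        opt.zero_grad()
        loss.backward()
        opt.step()
    return student

model = make_mlp(64)                 # stands in for the largest model
for width in (32, 16):               # e.g. 15B -> 8B -> 4B in the real setting
    student = prune_width(model, width)
    model = distill(student, teacher=model)   # the previous model teaches the next
```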

Pruning Impact

Pruning a model effectively requires understanding which parts of it are essential. Minitron uses an approach based on activation data to estimate the importance of various components — layers, neurons, attention heads, and embedding channels — using a small dataset. This method only requires forward propagation, making it simpler and more cost-effective than techniques that rely on backward propagation and gradient calculations.
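One way to implement this kind of activation-based scoring is with forward hooks: run a small calibration batch through the network and record the mean absolute activation of each unit, with no backward pass involved. The toy model and random calibration data below are placeholders for the real LLM and dataset.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 8))
importance = {}

def record(name):
    def hook(module, inputs, output):
        # Average |activation| over the batch for each output unit.
        importance[name] = output.abs().mean(dim=0).detach()
    return hook

handles = [m.register_forward_hook(record(n))
           for n, m in model.named_modules() if isinstance(m, nn.Linear)]

with torch.no_grad():                 # forward propagation only, no gradients
    model(torch.randn(256, 16))       # small calibration dataset

for h in handles:
    h.remove()

# The least important hidden neurons are candidates for width pruning.
print(importance["0"].topk(8, largest=False).indices)
```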

While it’s possible to alternate between pruning and importance estimation for different parts of the model, Minitron found that a single round of importance estimation was sufficient in most cases.

Retraining Using Classical Knowledge Distillation

After pruning, Minitron retrains the smaller model using classical knowledge distillation. This involves teaching the pruned model by minimizing losses at various stages of the model, including the embedding output, logits, and specific losses in the transformer architecture. The student model learns from the unpruned teacher model by comparing outputs at different layers.
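A hedged sketch of such a combined objective, mixing a logit KL term, an embedding-output term, and intermediate hidden-state terms. The loss weights, the student-to-teacher layer mapping, and the assumption that student and teacher share a hidden size are simplifications for illustration.

```python
import torch
import torch.nn.functional as F

def classical_kd_loss(student_out, teacher_out, layer_map, T=1.0,
                      w_logits=1.0, w_hidden=1.0, w_embed=1.0):
    # Logit loss: KL divergence between softened output distributions.
    loss = w_logits * F.kl_div(
        F.log_softmax(student_out["logits"] / T, dim=-1),
        F.softmax(teacher_out["logits"] / T, dim=-1),
        reduction="batchmean") * T ** 2
    # Embedding-output loss (hidden state 0 taken as the embedding output).
    loss = loss + w_embed * F.mse_loss(student_out["hidden_states"][0],
                                       teacher_out["hidden_states"][0])
    # Intermediate-state losses over a chosen student-to-teacher layer mapping.
    for s_layer, t_layer in layer_map:
        loss = loss + w_hidden * F.mse_loss(student_out["hidden_states"][s_layer],
                                            teacher_out["hidden_states"][t_layer])
    return loss

# Toy usage with random tensors standing in for model outputs
# (batch of 2, sequence of 4, hidden size 8, vocab of 10, 5 hidden states).
fake = lambda: {"logits": torch.randn(2, 4, 10),
                "hidden_states": [torch.randn(2, 4, 8) for _ in range(5)]}
print(classical_kd_loss(fake(), fake(), layer_map=[(2, 4), (4, 4)]))
```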

From extensive experimentation, Minitron has distilled several best practices for compressing language models:

· Model Sizing: Start by training the largest model, then gradually prune and distill it to create smaller versions.

· Pruning Strategy: Focus on width pruning over depth pruning, especially for models up to 15B parameters. Single-shot importance estimation is usually sufficient.

· Retraining: Retrain using distillation loss instead of conventional training. When pruning layers significantly, use a combination of losses from logits, intermediate states, and embeddings. For smaller reductions in depth, stick to logit-only distillation.

Minitron applied these techniques to the Llama 3.1 model family, which includes models ranging from 405B to 8B parameters. Specifically, they focused on distilling the 8B model to a more efficient 4B version.

Fine-tuning the Teacher

Before pruning, Minitron fine-tuned the 8B model to account for shifts in the data distribution from the original training set. Without this step, the teacher model may not offer the best guidance to the student during distillation.
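A minimal sketch of this teacher-correction step, assuming a short language-modeling fine-tune of the teacher on the same corpus that will later drive distillation; the model path, text, and hyperparameters are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "path/to/teacher-8b"                      # placeholder
tokenizer = AutoTokenizer.from_pretrained(name)
teacher = AutoModelForCausalLM.from_pretrained(name)
optimizer = torch.optim.AdamW(teacher.parameters(), lr=1e-5)

distillation_texts = ["example document from the distillation corpus"]  # placeholder
for text in distillation_texts:
    batch = tokenizer(text, return_tensors="pt")
    # Standard next-token loss on the distillation-domain text.
    loss = teacher(input_ids=batch["input_ids"], labels=batch["input_ids"]).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```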

Depth Pruning

To reduce the 8B model to 4B, Minitron pruned 16 layers, assessing their importance by removing them one by one and tracking the impact on performance. They found that layers at both the beginning and end of the model were most critical to maintaining accuracy. Based on this analysis, Minitron removed a specific set of layers for the final 4B model.
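A small sketch of that leave-one-out procedure on a toy residual stack: skip one block at a time, measure the loss on a held-out batch, and rank layers by how much their removal hurts. The toy blocks, random data, and the number of layers removed are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Linear(d, d)
    def forward(self, x):
        return x + torch.relu(self.ff(x))

d = 32
blocks = nn.ModuleList([Block(d) for _ in range(8)])
head = nn.Linear(d, 10)
x, y = torch.randn(64, d), torch.randint(0, 10, (64,))

def eval_loss(skip=None):
    h = x
    for i, blk in enumerate(blocks):
        if i != skip:            # drop one block at a time
            h = blk(h)
    return F.cross_entropy(head(h), y).item()

baseline = eval_loss()
# A larger loss increase when a layer is skipped means that layer is more important.
impact = {i: eval_loss(skip=i) - baseline for i in range(len(blocks))}
to_remove = sorted(impact, key=impact.get)[:4]   # the 4 least important layers
print(sorted(to_remove))
```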

Width Pruning

In addition to depth pruning, Minitron also pruned along the width dimension, targeting attention heads, embedding channels, and MLP hidden dimensions. After pruning, retraining helped recover some of the performance lost in the initial pruning step. Interestingly, although width pruning initially led to higher loss than depth pruning, retraining allowed the model to recover more effectively over time.
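For the attention-head case, width pruning amounts to keeping the highest-scoring heads and slicing the projection matrices to match. The sketch below assumes a packed QKV projection and random placeholder importance scores; a real model would also need its head-count configuration updated.

```python
import torch
import torch.nn as nn

d_model, n_heads, head_dim = 64, 8, 8
qkv = nn.Linear(d_model, 3 * d_model)       # packed Q, K, V projection
out_proj = nn.Linear(d_model, d_model)
head_scores = torch.rand(n_heads)           # placeholder importance scores

keep = head_scores.topk(4).indices.sort().values   # keep the 4 strongest heads
# Row indices covered by the kept heads inside each of the Q, K, V slices.
rows = torch.cat([torch.arange(h * head_dim, (h + 1) * head_dim) for h in keep])
qkv_rows = torch.cat([rows + i * d_model for i in range(3)])   # same heads in Q, K, V

pruned_qkv = nn.Linear(d_model, 3 * len(keep) * head_dim)
pruned_out = nn.Linear(len(keep) * head_dim, d_model)
with torch.no_grad():
    pruned_qkv.weight.copy_(qkv.weight[qkv_rows])
    pruned_qkv.bias.copy_(qkv.bias[qkv_rows])
    pruned_out.weight.copy_(out_proj.weight[:, rows])   # prune matching input columns
    pruned_out.bias.copy_(out_proj.bias)
print(pruned_qkv.weight.shape, pruned_out.weight.shape)
```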

The Results

NVIDIA evaluated the Minitron models on several benchmarks, with results that matched the performance of the baseline models.

The Minitron 4B and 8B models showcase the potential of pruning and distillation for building smaller, more efficient models. There are still major challenges with this approach, but overall I think it sets an important baseline for the industry.
