---
license: mit
library_name: transformers
pipeline_tag: text-generation
base_model:
- nvidia/Llama-3.1-Minitron-4B-Depth-Base
datasets:
- BAAI/Infinity-Instruct
---

We fine-tune `nvidia/Llama-3.1-Minitron-4B-Depth-Base` with the LLM-Neo method, which combines LoRA with knowledge distillation (KD). The training data consists of 100k samples drawn from `BAAI/Infinity-Instruct`.
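
For intuition, below is a minimal sketch of what a LoRA-plus-KD objective can look like, assuming a standard formulation: cross-entropy on the labels mixed with a temperature-scaled KL term against a frozen teacher. The teacher checkpoint is a placeholder, and `alpha`, `T`, and the LoRA settings are illustrative, not the paper's exact recipe; see the paper and repository for the actual setup.

```python
import torch
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder teacher; the real distillation setup is described in the paper/repo.
teacher = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
)
teacher.eval()
student = AutoModelForCausalLM.from_pretrained(
    "nvidia/Llama-3.1-Minitron-4B-Depth-Base", torch_dtype=torch.bfloat16
)

# LoRA: only the low-rank adapter matrices receive gradients.
lora_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
student = get_peft_model(student, lora_cfg)

def neo_loss(input_ids, attention_mask, labels, alpha=0.5, T=1.0):
    """Convex mix of a KD term (KL against the teacher) and cross-entropy on the labels."""
    out = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    with torch.no_grad():
        teacher_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits
    # KL divergence between temperature-softened teacher and student distributions.
    kd = F.kl_div(
        F.log_softmax(out.logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)
    return alpha * kd + (1 - alpha) * out.loss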
```

This repository contains the model described in the paper [LLM-Neo: Parameter Efficient Knowledge Distillation for Large Language Models](https://hf.co/papers/2411.06839).

The project page is available [here](https://huggingface.co/collections/yang31210999/llm-neo-66e3c882f5579b829ff57eba) and the GitHub repository is available [here](https://github.com/yang3121099/LLM-Neo).
## Basic Usage

This example generates text with the model. You'll need to install the required libraries first: `pip install transformers accelerate torch`.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_path = "yang31210999/Llama-3.1-Minitron-4B-Depth-Neo-10w"

# Load the tokenizer and model; device_map="auto" places the weights on the available GPU(s).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16
)

# Tokenize the prompt and move it to the same device as the model.
prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Sample up to 50 new tokens with moderate temperature.
generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.7)

outputs = model.generate(**inputs, generation_config=generation_config)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(generated_text)
```
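
For deterministic output, set `do_sample=False` (greedy decoding); `max_new_tokens` caps only the length of the generated continuation, not the prompt.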

## Benchmarks

In this section, we report results for `Llama-3.1-Minitron-4B-Depth-Neo-10w` on standard automatic benchmarks. All evaluations use the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) library.
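
For reproducibility, a sketch using the harness's Python entry point (after `pip install lm-eval`) is shown below. The few-shot settings follow the table, but the batch size is illustrative, and task identifiers can vary across harness versions:

```python
# Sketch of reproducing the evaluation with lm-evaluation-harness's Python API.
import lm_eval

# 0-shot tasks from the table below.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=yang31210999/Llama-3.1-Minitron-4B-Depth-Neo-10w,dtype=bfloat16",
    tasks=["mmlu", "cmmlu", "ceval-valid"],
    num_fewshot=0,
    batch_size=8,
)
print(results["results"])

# BBH is reported 3-shot, so run it separately.
bbh = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=yang31210999/Llama-3.1-Minitron-4B-Depth-Neo-10w,dtype=bfloat16",
    tasks=["bbh"],
    num_fewshot=3,
    batch_size=8,
)
print(bbh["results"])
```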

### Evaluation results

| Category | Benchmark | Version | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| BBH | BBH (General) | N/A | 3 | exact_match | 0.4729 | ± 0.0055 |
| BBH | BBH (Boolean Expressions) | 2 | 3 | exact_match | 0.8120 | ± 0.0248 |
| BBH | BBH (Date Understanding) | 2 | 3 | exact_match | 0.6600 | ± 0.0300 |
| CEVAL | CEVAL (General) | N/A | 0 | acc | 0.4413 | ± 0.0135 |
| CEVAL | CEVAL (Accountant) | 1 | 0 | acc | 0.3469 | ± 0.0687 |
| CEVAL | CEVAL (Advanced Mathematics) | 1 | 0 | acc | 0.4737 | ± 0.1177 |
| CEVAL | CEVAL (Art Studies) | 1 | 0 | acc | 0.4545 | ± 0.0880 |
| MMLU | MMLU (General) | N/A | 0 | acc | 0.6048 | ± 0.0039 |
| MMLU | MMLU (Humanities) | N/A | 0 | acc | 0.5552 | ± 0.0067 |
| MMLU | MMLU (STEM) | N/A | 0 | acc | 0.5214 | ± 0.0086 |
| CMMLU | CMMLU (General) | N/A | 0 | acc | 0.3548 | ± 0.0044 |
| CMMLU | CMMLU (Normalized) | N/A | 0 | acc_norm | 0.3548 | ± 0.0044 |