Commit
·
1492bd8
1
Parent(s):
2ae2dab
update readme
Browse files
README.md
CHANGED
@@ -66,10 +66,13 @@ v0.23.0 <br>
|
|
66 |
* [Human] <br>
|
67 |
|
68 |
## Medusa Speculative Decoding and Post Training Quantization
|
69 |
-
Synthesized data was obtained from a FP8 quantized version of Meta-Llama-3.1-8B-Instruct, which is then used to finetune the Medusa heads. This model was then obtained by quantizing the weights and activations of Meta-Llama-3.1-8B-Instruct together with the Medusa heads to FP8 data type, ready for inference with TensorRT-LLM in Medusa speculative decoding mode. Only the weights and activations of the linear operators within transformers blocks and Medusa heads are quantized. This optimization reduces the number of bits per parameter from 16 to 8, reducing the disk size and GPU memory requirements by approximately 50%.
|
|
|
|
|
70 |
|
71 |
## Usage
|
72 |
-
To
|
|
|
73 |
```python
|
74 |
### Generate Text Using Medusa Decoding
|
75 |
|
@@ -106,7 +109,7 @@ def main():
|
|
106 |
[4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], \
|
107 |
[0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]
|
108 |
)
|
109 |
-
llm = LLM(model="
|
110 |
build_config=build_config,
|
111 |
speculative_config=speculative_config)
|
112 |
|
@@ -124,10 +127,8 @@ if __name__ == '__main__':
|
|
124 |
|
125 |
```
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
* Throughputs evaluation:
|
130 |
-
Please refer to the [TensorRT-LLM benchmarking documentation](https://github.com/NVIDIA/TensorRT-LLM/blob/main/benchmarks/Suite.md) for details.
|
131 |
|
132 |
## Evaluation
|
133 |
The accuracy (MMLU, 5-shot) and Medusa acceptance rate benchmark results are presented in the table below:
|
@@ -143,4 +144,3 @@ The accuracy (MMLU, 5-shot) and Medusa acceptance rate benchmark results are pre
|
|
143 |
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
|
144 |
|
145 |
Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.NVIDIA.com/en-us/support/submit-security-vulnerability/).
|
146 |
-
|
|
|
66 |
* [Human] <br>
|
67 |
|
68 |
## Medusa Speculative Decoding and Post Training Quantization
|
69 |
+
Synthesized data was obtained from a FP8 quantized version of Meta-Llama-3.1-8B-Instruct, which is then used to finetune the Medusa heads. This model was then obtained by quantizing the weights and activations of Meta-Llama-3.1-8B-Instruct together with the Medusa heads to FP8 data type, ready for inference with TensorRT-LLM in Medusa speculative decoding mode. Only the weights and activations of the linear operators within transformers blocks and Medusa heads are quantized. This optimization reduces the number of bits per parameter from 16 to 8, reducing the disk size and GPU memory requirements by approximately 50%.
|
70 |
+
|
71 |
+
Medusa heads are used to predict candidate tokens beyond the next token. In the generation step, each Medusa head generates a distribution of tokens beyond the previous. Then a tree-based attention mechanism samples some candidate sequences for the original model to validate. The longest accepted candidate sequence is selected so that more than 1 token is returned in the generation step. The number of tokens generated in each step is called acceptance rate.
|
72 |
|
73 |
## Usage
|
74 |
+
To run inference with [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (supported from [v0.17](https://github.com/NVIDIA/TensorRT-LLM/tree/v0.17.0)), we recommend using LLM APIs as shown in [this example](https://github.com/NVIDIA/TensorRT-LLM/blob/v0.17.0/examples/llm-api/llm_medusa_decoding.py#L34) with ` python llm_medusa_decoding.py --use_modelopt_ckpt` or below. The LLM APIs abstract away steps like checkpoint conversion, engine building, and inference.
|
75 |
+
|
76 |
```python
|
77 |
### Generate Text Using Medusa Decoding
|
78 |
|
|
|
109 |
[4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], \
|
110 |
[0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]
|
111 |
)
|
112 |
+
llm = LLM(model="nvidia/Llama-3.1-8B-Medusa-FP8",
|
113 |
build_config=build_config,
|
114 |
speculative_config=speculative_config)
|
115 |
|
|
|
127 |
|
128 |
```
|
129 |
|
130 |
+
Alternatively, you can follow the [sample CLIs for Medusa decoding](https://github.com/NVIDIA/TensorRT-LLM/tree/v0.17.0/examples/medusa#usage) in the TensorRT-LLM GitHub repo.
|
131 |
+
Support in [TensorRT-LLM benchmarking](https://nvidia.github.io/TensorRT-LLM/performance/perf-benchmarking.html) with `trtllm-bench` is coming soon.
|
|
|
|
|
132 |
|
133 |
## Evaluation
|
134 |
The accuracy (MMLU, 5-shot) and Medusa acceptance rate benchmark results are presented in the table below:
|
|
|
144 |
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
|
145 |
|
146 |
Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.NVIDIA.com/en-us/support/submit-security-vulnerability/).
|
|