Update README.md
README.md
@@ -89,13 +89,131 @@ headline = tokenizer.decode(outputs[0], skip_special_tokens=True)
- GPU training with gradient checkpointing
- Parallel data loading with 8 workers

## Evaluation

### ROUGE Score Comparison

| Metric  | Base Model | Finetuned Model | Improvement |
|---------|------------|-----------------|-------------|
| ROUGE-1 | 2.85       | 4.67            | +1.82       |
| ROUGE-2 | 0.25       | 0.41            | +0.16       |
| ROUGE-L | 2.84       | 4.65            | +1.81       |
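
A minimal sketch of how such scores can be computed with the Hugging Face `evaluate` package (the file names are hypothetical placeholders, not artifacts shipped with this repo):

```python
# pip install evaluate rouge_score
import evaluate

rouge = evaluate.load("rouge")

# Hypothetical files: one generated headline / one reference headline per line.
with open("predictions.txt", encoding="utf-8") as f:
    predictions = [line.strip() for line in f]
with open("references.txt", encoding="utf-8") as f:
    references = [line.strip() for line in f]

# Returns rouge1 / rouge2 / rougeL / rougeLsum F-measures in [0, 1].
scores = rouge.compute(predictions=predictions, references=references)
print({k: round(v * 100, 2) for k, v in scores.items()})
```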

### Model Prediction Comparison (Evaluated by a Larger Model)

| Category           | Count | Percentage |
|--------------------|-------|------------|
| Total samples      | 5962  | 100%       |
| Same predictions   | 1     | 0.02%      |
| Better predictions | 4697  | 78.78%     |
| Worse predictions  | 1264  | 21.20%     |
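
One way to produce such a comparison is a pairwise LLM-as-judge loop. The sketch below is an assumption about the setup: the prompt wording and the `ask_judge` callable are ours, not the exact evaluation harness used here.

```python
# Hypothetical pairwise LLM-as-judge loop; the prompt wording and `ask_judge`
# (a callable that queries the larger evaluator model) are assumptions.
from collections import Counter

JUDGE_PROMPT = (
    "Article:\n{article}\n\n"
    "Headline A (base model): {base}\n"
    "Headline B (finetuned model): {tuned}\n\n"
    "Which headline is more relevant, factual, and interesting?\n"
    "Answer with exactly one word: BETTER (B wins), WORSE (A wins), or SAME."
)

def compare_predictions(samples, ask_judge):
    """samples: iterable of (article, base_headline, finetuned_headline)."""
    counts = Counter()
    for article, base, tuned in samples:
        reply = ask_judge(JUDGE_PROMPT.format(article=article, base=base, tuned=tuned))
        verdict = reply.strip().split()[0].upper()
        key = {"BETTER": "better", "WORSE": "worse"}.get(verdict, "same")
        counts[key] += 1
    total = sum(counts.values())
    for key in ("same", "better", "worse"):
        print(f"{key:>6} predictions: {counts[key]} ({counts[key] / total:.2%})")
    return counts
```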

### Evaluation Methods

- ROUGE scores for headline similarity
- Human evaluation for headline appropriateness

## Inference

#### Running the model on a GPU using different precisions

* _Using `torch.float16`_

```python
# pip install accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("saidines12/telugu-news-headline-generation")
model = AutoModelForCausalLM.from_pretrained(
    "saidines12/telugu-news-headline-generation",
    device_map="auto",
    torch_dtype=torch.float16,
)

input_text = "Generate relevant, interesting, factual short headline from this news article in telugu language\n <Your Telugu news article text here>"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
```

* _Using `torch.bfloat16`_

```python
# pip install accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("saidines12/telugu-news-headline-generation")
model = AutoModelForCausalLM.from_pretrained(
    "saidines12/telugu-news-headline-generation",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Generate relevant, interesting, factual short headline from this news article in telugu language\n <Your Telugu news article text here>"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
```

#### Quantized Versions through `bitsandbytes`

* _Using 8-bit precision (int8)_

```python
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained("saidines12/telugu-news-headline-generation")
model = AutoModelForCausalLM.from_pretrained(
    "saidines12/telugu-news-headline-generation",
    quantization_config=quantization_config,
)

input_text = "Generate relevant, interesting, factual short headline from this news article in telugu language\n <Your Telugu news article text here>"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
```

* _Using 4-bit precision_

```python
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("saidines12/telugu-news-headline-generation")
model = AutoModelForCausalLM.from_pretrained(
    "saidines12/telugu-news-headline-generation",
    quantization_config=quantization_config,
)

input_text = "Generate relevant, interesting, factual short headline from this news article in telugu language\n <Your Telugu news article text here>"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
```
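
Optionally, the 4-bit path can be tuned further. The settings below (NF4 quantization with bfloat16 compute) are common choices, not requirements stated by this model card:

```python
# Optional 4-bit refinements; these kwargs are standard BitsAndBytesConfig
# options, but the specific values are suggestions rather than this repo's setup.
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```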

#### Other optimizations

* _Flash Attention 2_

First, make sure to install `flash-attn` in your environment: `pip install flash-attn`

```diff
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
+   attn_implementation="flash_attention_2"
).to(0)
```
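
Note that FlashAttention-2 requires a half-precision dtype (`float16` or `bfloat16`) and a sufficiently recent GPU (Ampere or newer), so it pairs naturally with the precision examples above.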

### Inputs and outputs

* **Input:** Text string: the headline-generation instruction followed by the Telugu news article to be summarized.
* **Output:** Generated Telugu-language text in response to the input: a short headline for the given article.
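
For convenience, the instruction prefix used throughout the examples above can be wrapped in a small helper (the function below is a hypothetical addition, not part of the repo):

```python
# Hypothetical helper that assembles the prompt format used in the examples above.
INSTRUCTION = (
    "Generate relevant, interesting, factual short headline "
    "from this news article in telugu language\n "
)

def build_prompt(article: str) -> str:
    """Prepend the headline-generation instruction to a Telugu news article."""
    return INSTRUCTION + article
```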
## Technical Specifications