use torch autocast for llama model generation
model.py CHANGED

@@ -215,22 +215,24 @@ class SALMONN(nn.Module):
         embeds = torch.cat([bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
         atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
 
-        # generate
-        output = self.llama_model.generate(
-            inputs_embeds=embeds,
-            max_length=max_length,
-            num_beams=num_beams,
-            do_sample=do_sample,
-            min_length=min_length,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            temperature=temperature,
-            attention_mask=atts,
-            bos_token_id=self.llama_tokenizer.bos_token_id,
-            eos_token_id=self.llama_tokenizer.eos_token_id,
-            pad_token_id=self.llama_tokenizer.pad_token_id
-        )
+        from torch.amp import autocast
+
+        with autocast(device_type="cuda", dtype=torch.float16):
+            output = self.llama_model.generate(
+                inputs_embeds=embeds,
+                max_length=max_length,
+                num_beams=num_beams,
+                do_sample=do_sample,
+                min_length=min_length,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                temperature=temperature,
+                attention_mask=atts,
+                bos_token_id=self.llama_tokenizer.bos_token_id,
+                eos_token_id=self.llama_tokenizer.eos_token_id,
+                pad_token_id=self.llama_tokenizer.pad_token_id
+            )
 
         output_text = self.llama_tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
 
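The pattern this commit applies generalizes beyond SALMONN. Below is a minimal, self-contained sketch of running Hugging Face generation under torch.amp.autocast; the "gpt2" checkpoint, the prompt, and the decoding parameters are stand-ins for illustration, not the model this Space actually loads:

import torch
from torch.amp import autocast
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

inputs = tokenizer("The audio describes", return_tensors="pt").to(device)

# autocast changes only the compute dtype of eligible ops (matmuls,
# attention) while the weights stay in their stored precision; float16
# autocast is CUDA-only, so fall back to bfloat16 on CPU.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
with torch.no_grad(), autocast(device_type=device, dtype=amp_dtype):
    output = model.generate(
        **inputs,
        max_length=64,
        num_beams=4,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])

One API detail worth noting: torch.cuda.amp.autocast does not accept a device_type argument; that keyword belongs to the device-generic torch.amp.autocast, which is why the import in the diff above comes from torch.amp.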