fffiloni committed
Commit 5cb656f · verified · 1 Parent(s): a784a45

use torch autocast for llama model generation

Files changed (1):
  1. model.py +18 -16
model.py CHANGED
@@ -215,22 +215,24 @@ class SALMONN(nn.Module):
         embeds = torch.cat([bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
         atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
 
-        # generate
-        output = self.llama_model.generate(
-            inputs_embeds=embeds,
-            max_length=max_length,
-            num_beams=num_beams,
-            do_sample=do_sample,
-            min_length=min_length,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            temperature=temperature,
-            attention_mask=atts,
-            bos_token_id=self.llama_tokenizer.bos_token_id,
-            eos_token_id=self.llama_tokenizer.eos_token_id,
-            pad_token_id=self.llama_tokenizer.pad_token_id
-        )
+        from torch import autocast
+
+        with autocast(device_type="cuda", dtype=torch.float16):
+            output = self.llama_model.generate(
+                inputs_embeds=embeds,
+                max_length=max_length,
+                num_beams=num_beams,
+                do_sample=do_sample,
+                min_length=min_length,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                temperature=temperature,
+                attention_mask=atts,
+                bos_token_id=self.llama_tokenizer.bos_token_id,
+                eos_token_id=self.llama_tokenizer.eos_token_id,
+                pad_token_id=self.llama_tokenizer.pad_token_id
+            )
 
         output_text = self.llama_tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
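
For reference, a minimal standalone sketch of the same pattern outside this repo: wrapping a Hugging Face `generate` call in `torch.autocast` so eligible CUDA ops run in float16. The checkpoint name and prompt below are placeholders for illustration, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# hypothetical checkpoint, chosen only to demonstrate the autocast pattern
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

# torch.autocast runs autocast-eligible ops (matmuls, attention) in fp16
# on CUDA for the duration of the block; generate itself is unchanged,
# only the compute dtype inside it.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model.generate(**inputs, max_length=64)

print(tokenizer.batch_decode(output, skip_special_tokens=True))

Compared with casting the whole model to half precision, autocast keeps the weights in their original dtype and only downcasts the compute where it is numerically safe, which is why it is a common low-risk way to cut memory use and speed up decoding.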