GradientGuru committed
Commit 92dd329
Parent: 85278de

cache alibi_mask to accelerate training

Files changed (1): modeling_baichuan.py (+8, -2)
modeling_baichuan.py CHANGED
@@ -249,7 +249,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
         self.gradient_checkpointing = config.gradient_checkpointing
         self.post_init()
         self.max_cache_pos = config.model_max_length
-        self.first_run = True
+        self.first_run = True
+        self.alibi_mask = None
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -306,8 +307,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+        if self.training:
+            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
+                self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+            alibi_mask = self.alibi_mask
+        else:
+            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
 
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
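
For readers skimming the diff: during training, batches are typically padded to one fixed sequence length, so the ALiBi attention bias is identical from step to step; caching it on the module and rebuilding it only when that length changes avoids redundant work. At inference time the key/value cache grows token by token, so the mask is still recomputed on every call. Below is a minimal, self-contained sketch of that caching pattern, not the model's actual code: build_alibi_mask and MaskCachingModule are hypothetical names, and build_alibi_mask is only a simplified stand-in for get_alibi_mask in modeling_baichuan.py.

# A minimal sketch of the caching pattern this commit introduces.
# build_alibi_mask and MaskCachingModule are illustrative stand-ins,
# not Baichuan's actual API.
import torch
import torch.nn as nn


def build_alibi_mask(n_heads: int, seq_len: int) -> torch.Tensor:
    # Simplified ALiBi-style bias of shape (n_heads, seq_len, seq_len):
    # per-head linear distance penalties plus a causal future mask.
    slopes = 2.0 ** (-8.0 * torch.arange(1, n_heads + 1) / n_heads)
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]          # j - i, <= 0 in the past
    bias = slopes[:, None, None] * distance[None, :, :]
    future = torch.full((seq_len, seq_len), float("-inf")).triu(1)
    return bias + future


class MaskCachingModule(nn.Module):
    def __init__(self, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.alibi_mask = None  # built lazily, reused across training steps

    def get_mask(self, seq_len: int) -> torch.Tensor:
        if self.training:
            # Training batches usually share one padded length, so rebuild
            # only when that length actually changes.
            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_len:
                self.alibi_mask = build_alibi_mask(self.n_heads, seq_len)
            return self.alibi_mask
        # Inference: the KV cache grows each step, so recompute every call.
        return build_alibi_mask(self.n_heads, seq_len)


m = MaskCachingModule(n_heads=4).train()
first = m.get_mask(16)
assert m.get_mask(16) is first            # cache hit: same tensor reused
assert m.get_mask(32).shape[-1] == 32     # new length forces a rebuild

One caveat worth noting: because the cached tensor lives on the module, it persists across batches; the shape check on the last dimension is what keeps a sequence-length change from silently reusing a stale mask.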