habdine commited on
Commit
b8cc45f
1 Parent(s): e1f8bd2

Update modeling_prot2text.py

Browse files
Files changed (1) hide show
  1. modeling_prot2text.py +2 -18
modeling_prot2text.py CHANGED
@@ -323,8 +323,8 @@ class Prot2TextModel(PreTrainedModel):
323
  tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
324
  encoder_outputs=encoder_state,
325
  use_cache=True,
326
- output_attentions=True,
327
- output_scores=True,
328
  return_dict_in_generate=True,
329
  encoder_attention_mask=inputs['attention_mask'],
330
  length_penalty=1.0,
@@ -333,22 +333,6 @@ class Prot2TextModel(PreTrainedModel):
333
  num_beams=1)
334
 
335
  generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
336
- print(tok_ids.get('scores')[0].size())
337
- m = torch.nn.Softmax()
338
- att_w = []
339
- print(len(gpdb.sequence[0]))
340
- score = 0
341
- for i in range(len(tok_ids.get('cross_attentions'))):
342
- att_w.append(torch.mul(tok_ids.get('cross_attentions')[i][-1].squeeze().mean(dim=0), inputs['attention_mask'][-1].squeeze())[:len(gpdb.sequence[0])].tolist())
343
- score += np.log(torch.max(m(tok_ids.get('scores')[i]).squeeze()).item())
344
- score = score / len(tok_ids.get('cross_attentions'))
345
- # print(str(score))
346
-
347
- # import seaborn as sns
348
- # import matplotlib.pylab as plt
349
- # plt.figure().set_figwidth(150)
350
- # ax = sns.heatmap(att_w, cmap="YlGnBu", robust=True, xticklabels=gpdb.sequence[0])#, yticklabels=generated[0])
351
- # plt.savefig("seaborn_plot.png")
352
 
353
  os.remove(structure_filename)
354
  os.remove(graph_filename)
 
323
  tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
324
  encoder_outputs=encoder_state,
325
  use_cache=True,
326
+ output_attentions=False,
327
+ output_scores=False,
328
  return_dict_in_generate=True,
329
  encoder_attention_mask=inputs['attention_mask'],
330
  length_penalty=1.0,
 
333
  num_beams=1)
334
 
335
  generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  os.remove(structure_filename)
338
  os.remove(graph_filename)