oucgc1996 commited on
Commit
f9915b0
·
verified ·
1 Parent(s): 25bc2a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -47,7 +47,6 @@ def CTXGen(X1, X2, τ, g_num, length_range):
47
 
48
  model.eval()
49
  with torch.no_grad():
50
- new_seq = None
51
  IDs = []
52
  generated_seqs = []
53
  generated_seqs_FINAL = []
@@ -64,6 +63,7 @@ def CTXGen(X1, X2, τ, g_num, length_range):
64
  '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
65
  start_time = time.time()
66
  while count < gen_num:
 
67
  if is_stopped:
68
  return pd.DataFrame(), "output.csv"
69
 
@@ -109,7 +109,7 @@ def CTXGen(X1, X2, τ, g_num, length_range):
109
  input_ids = vocab_mlm.__getitem__(generated_seq)
110
  logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
  cls_mask_logits = logits[0, 1, :]
112
- cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
113
 
114
  generated_seq[2] = "[MASK]"
115
  input_ids = vocab_mlm.__getitem__(generated_seq)
 
47
 
48
  model.eval()
49
  with torch.no_grad():
 
50
  IDs = []
51
  generated_seqs = []
52
  generated_seqs_FINAL = []
 
63
  '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
64
  start_time = time.time()
65
  while count < gen_num:
66
+ new_seq = None
67
  if is_stopped:
68
  return pd.DataFrame(), "output.csv"
69
 
 
109
  input_ids = vocab_mlm.__getitem__(generated_seq)
110
  logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
  cls_mask_logits = logits[0, 1, :]
112
+ cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=85)
113
 
114
  generated_seq[2] = "[MASK]"
115
  input_ids = vocab_mlm.__getitem__(generated_seq)