oucgc1996 commited on
Commit
05c30a5
·
verified ·
1 Parent(s): 1be6080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -85,7 +85,14 @@ def CTXGen(X0, X1, X2, τ, g_num):
85
  cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
86
  act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
87
 
 
88
  while count < gen_num:
 
 
 
 
 
 
89
  gen_len = len(X0)
90
  seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
91
  vocab_mlm.token_to_idx["X"] = 4
@@ -97,6 +104,9 @@ def CTXGen(X0, X1, X2, τ, g_num):
97
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
98
 
99
  for i in range(length):
 
 
 
100
  _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
101
  idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
102
  idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
 
85
  cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
86
  act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
87
 
88
+ start_time = time.time()
89
  while count < gen_num:
90
+ if is_stopped:
91
+ return pd.DataFrame(), "output.csv"
92
+
93
+ if time.time() - start_time > 1200:
94
+ break
95
+
96
  gen_len = len(X0)
97
  seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
98
  vocab_mlm.token_to_idx["X"] = 4
 
104
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
105
 
106
  for i in range(length):
107
+ if is_stopped:
108
+ return pd.DataFrame(), "output.csv"
109
+
110
  _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
111
  idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
112
  idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)