oucgc1996 committed on
Commit
541538c
·
verified ·
1 Parent(s): 5075097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -6,9 +6,11 @@ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
 
9
 
10
  is_stopped = False
11
 
 
12
  def temperature_sampling(logits, temperature):
13
  logits = logits / temperature
14
  probabilities = torch.softmax(logits, dim=-1)
@@ -20,6 +22,24 @@ def stop_generation():
20
  is_stopped = True
21
  return "Generation stopped."
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
24
  if seed =='random':
25
  seed = random.randint(0,100000)
@@ -106,30 +126,7 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
106
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
107
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
108
 
109
- gen_length = len(input_text)
110
- length = gen_length - sum(1 for x in input_text if x != '[MASK]')
111
- for i in range(length):
112
- if is_stopped:
113
- return "output.csv", pd.DataFrame()
114
-
115
- _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
116
- idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
117
- idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
118
- attn_idx = torch.tensor(attn_idx).to(device)
119
-
120
- mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
121
- mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
122
-
123
- logits = model(idx_seq,idx_msa, attn_idx)
124
- mask_logits = logits[0, mask_position.item(), :]
125
-
126
- predicted_token_id = temperature_sampling(mask_logits, τ)
127
-
128
- predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
129
- input_text[mask_position.item()] = predicted_token
130
- padded_seq[mask_position.item()] = predicted_token.strip()
131
- new_seq = padded_seq
132
- generated_seq = input_text
133
 
134
  generated_seq[1] = "[MASK]"
135
  input_ids = vocab_mlm.__getitem__(generated_seq)
 
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
9
+ import numba
10
 
11
  is_stopped = False
12
 
13
+ @numba.jit(nopython=True)
14
  def temperature_sampling(logits, temperature):
15
  logits = logits / temperature
16
  probabilities = torch.softmax(logits, dim=-1)
 
22
  is_stopped = True
23
  return "Generation stopped."
24
 
25
+ @numba.jit(nopython=True)
26
+ def generate_sequence(input_text, model, vocab_mlm, idx_msa, τ):
27
+ gen_length = len(input_text)
28
+ length = gen_length - sum(1 for x in input_text if x != '[MASK]')
29
+ for i in range(length):
30
+ mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
31
+ mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
32
+
33
+ logits = model(idx_seq, idx_msa, attn_idx)
34
+ mask_logits = logits[0, mask_position.item(), :]
35
+
36
+ predicted_token_id = temperature_sampling(mask_logits, τ)
37
+ predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
38
+ input_text[mask_position.item()] = predicted_token
39
+ padded_seq[mask_position.item()] = predicted_token.strip()
40
+ new_seq = padded_seq
41
+ return input_text
42
+
43
  def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
44
  if seed =='random':
45
  seed = random.randint(0,100000)
 
126
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
127
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
128
 
129
+ generated_seq = generate_sequence(input_text, model, vocab_mlm, idx_msa, τ)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  generated_seq[1] = "[MASK]"
132
  input_ids = vocab_mlm.__getitem__(generated_seq)