Update app.py
app.py
CHANGED
@@ -6,9 +6,11 @@ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
 import gradio as gr
 from gradio_rangeslider import RangeSlider
 import time
+import numba
 
 is_stopped = False
 
+@numba.jit(nopython=True)
 def temperature_sampling(logits, temperature):
     logits = logits / temperature
     probabilities = torch.softmax(logits, dim=-1)
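This hunk cuts off inside temperature_sampling. For reference, a minimal sketch of the complete sampler, assuming the body ends by drawing a single token id from the softmax distribution (the torch.multinomial draw is my assumption; the diff does not show the function's tail):

import torch

def temperature_sampling(logits, temperature):
    # temperature < 1 sharpens the distribution, > 1 flattens it
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    # assumed tail: draw one token id from the categorical distribution;
    # the caller converts it with int(predicted_token_id), so returning a
    # one-element tensor is compatible
    return torch.multinomial(probabilities, num_samples=1)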
@@ -20,6 +22,31 @@ def stop_generation():
     is_stopped = True
     return "Generation stopped."
 
+@numba.jit(nopython=True)
+def generate_sequence(input_text, padded_seq, seq, new_seq, model, vocab_mlm, τ, device):
+    gen_length = len(input_text)
+    length = gen_length - sum(1 for x in input_text if x != '[MASK]')
+    for i in range(length):
+        # re-encode each step so the model sees tokens filled in earlier iterations
+        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+        idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
+        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+        attn_idx = torch.tensor(attn_idx).to(device)
+
+        # pick one remaining [MASK] position at random and fill it
+        mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
+        mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
+
+        logits = model(idx_seq, idx_msa, attn_idx)
+        mask_logits = logits[0, mask_position.item(), :]
+
+        predicted_token_id = temperature_sampling(mask_logits, τ)
+        predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
+        input_text[mask_position.item()] = predicted_token
+        padded_seq[mask_position.item()] = predicted_token.strip()
+        new_seq = padded_seq
+    return input_text
+
 def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
     if seed =='random':
         seed = random.randint(0,100000)
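One caveat worth illustrating: numba.jit(nopython=True) compiles a pure-Python/NumPy subset and cannot compile torch tensor operations or model calls, so the decorated functions above would only compile given a NumPy formulation. A sketch of a numba-compatible sampler (the name temperature_sampling_np and the inverse-CDF draw are mine, not part of this commit):

import numba
import numpy as np

@numba.jit(nopython=True)
def temperature_sampling_np(logits, temperature):
    # numerically stable softmax over temperature-scaled logits
    scaled = logits / temperature
    scaled = scaled - scaled.max()
    probs = np.exp(scaled)
    probs = probs / probs.sum()
    # inverse-CDF sampling: walk the CDF until it passes a uniform draw
    r = np.random.random()
    acc = 0.0
    for i in range(probs.shape[0]):
        acc += probs[i]
        if r < acc:
            return i
    return probs.shape[0] - 1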
@@ -106,30 +133,7 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
     padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
     input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
 
-
-    length = gen_length - sum(1 for x in input_text if x != '[MASK]')
-    for i in range(length):
-        if is_stopped:
-            return "output.csv", pd.DataFrame()
-
-        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-        idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
-        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-        attn_idx = torch.tensor(attn_idx).to(device)
-
-        mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
-        mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-        logits = model(idx_seq,idx_msa, attn_idx)
-        mask_logits = logits[0, mask_position.item(), :]
-
-        predicted_token_id = temperature_sampling(mask_logits, τ)
-
-        predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
-        input_text[mask_position.item()] = predicted_token
-        padded_seq[mask_position.item()] = predicted_token.strip()
-        new_seq = padded_seq
-        generated_seq = input_text
+    generated_seq = generate_sequence(input_text, padded_seq, seq, new_seq, model, vocab_mlm, τ, device)
 
     generated_seq[1] = "[MASK]"
     input_ids = vocab_mlm.__getitem__(generated_seq)