Update app.py
Browse files
app.py
CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
|
|
12 |
seed = random.randint(0,100000)
|
13 |
setup_seed(seed)
|
14 |
|
15 |
-
device = torch.device("cpu")
|
16 |
-
vocab_mlm = create_vocab()
|
17 |
-
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
18 |
-
save_path = 'mlm-model-27.pt'
|
19 |
-
train_seqs = pd.read_csv('C0_seq.csv')
|
20 |
-
train_seq = train_seqs['Seq'].tolist()
|
21 |
-
model = torch.load(save_path, map_location=torch.device('cpu'))
|
22 |
-
model = model.to(device)
|
23 |
-
|
24 |
def temperature_sampling(logits, temperature):
|
25 |
logits = logits / temperature
|
26 |
probabilities = torch.softmax(logits, dim=-1)
|
@@ -32,7 +23,16 @@ def stop_generation():
|
|
32 |
is_stopped = True
|
33 |
return "Generation stopped."
|
34 |
|
35 |
-
def CTXGen(X0, X1, X2, τ, g_num):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
global is_stopped
|
37 |
is_stopped = False
|
38 |
X3 = "X" * len(X0)
|
@@ -178,6 +178,7 @@ with gr.Blocks() as demo:
|
|
178 |
'<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
|
179 |
'<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
|
180 |
X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
|
|
|
181 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
182 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
183 |
with gr.Row():
|
@@ -188,7 +189,7 @@ with gr.Blocks() as demo:
|
|
188 |
with gr.Row():
|
189 |
output_file = gr.File(label="Download generated conotoxins")
|
190 |
|
191 |
-
start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num], outputs=[output_df, output_file])
|
192 |
stop_button.click(stop_generation, outputs=None)
|
193 |
|
194 |
demo.launch()
|
|
|
12 |
seed = random.randint(0,100000)
|
13 |
setup_seed(seed)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def temperature_sampling(logits, temperature):
|
16 |
logits = logits / temperature
|
17 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
23 |
is_stopped = True
|
24 |
return "Generation stopped."
|
25 |
|
26 |
+
def CTXGen(X0, X1, X2, τ, g_num, model_name):
|
27 |
+
device = torch.device("cpu")
|
28 |
+
vocab_mlm = create_vocab()
|
29 |
+
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
30 |
+
save_path = model_name
|
31 |
+
train_seqs = pd.read_csv('C0_seq.csv')
|
32 |
+
train_seq = train_seqs['Seq'].tolist()
|
33 |
+
model = torch.load(save_path, map_location=torch.device('cpu'))
|
34 |
+
model = model.to(device)
|
35 |
+
|
36 |
global is_stopped
|
37 |
is_stopped = False
|
38 |
X3 = "X" * len(X0)
|
|
|
178 |
'<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
|
179 |
'<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
|
180 |
X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
|
181 |
+
model_name = gr.Dropdown(choices=['model_final','model_C1','model_C2','model_C3','model_C4','model_C5','model_mlm'], label="Model")
|
182 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
183 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
184 |
with gr.Row():
|
|
|
189 |
with gr.Row():
|
190 |
output_file = gr.File(label="Download generated conotoxins")
|
191 |
|
192 |
+
start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num, model_name], outputs=[output_df, output_file])
|
193 |
stop_button.click(stop_generation, outputs=None)
|
194 |
|
195 |
demo.launch()
|