oucgc1996 commited on
Commit
15657d4
·
verified ·
1 Parent(s): 05c30a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
12
  seed = random.randint(0,100000)
13
  setup_seed(seed)
14
 
15
- device = torch.device("cpu")
16
- vocab_mlm = create_vocab()
17
- vocab_mlm = add_tokens_to_vocab(vocab_mlm)
18
- save_path = 'mlm-model-27.pt'
19
- train_seqs = pd.read_csv('C0_seq.csv')
20
- train_seq = train_seqs['Seq'].tolist()
21
- model = torch.load(save_path, map_location=torch.device('cpu'))
22
- model = model.to(device)
23
-
24
  def temperature_sampling(logits, temperature):
25
  logits = logits / temperature
26
  probabilities = torch.softmax(logits, dim=-1)
@@ -32,7 +23,16 @@ def stop_generation():
32
  is_stopped = True
33
  return "Generation stopped."
34
 
35
- def CTXGen(X0, X1, X2, τ, g_num):
 
 
 
 
 
 
 
 
 
36
  global is_stopped
37
  is_stopped = False
38
  X3 = "X" * len(X0)
@@ -178,6 +178,7 @@ with gr.Blocks() as demo:
178
  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
179
  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
180
  X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
 
181
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
182
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
183
  with gr.Row():
@@ -188,7 +189,7 @@ with gr.Blocks() as demo:
188
  with gr.Row():
189
  output_file = gr.File(label="Download generated conotoxins")
190
 
191
- start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num], outputs=[output_df, output_file])
192
  stop_button.click(stop_generation, outputs=None)
193
 
194
  demo.launch()
 
12
  seed = random.randint(0,100000)
13
  setup_seed(seed)
14
 
 
 
 
 
 
 
 
 
 
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
 
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
+ def CTXGen(X0, X1, X2, τ, g_num, model_name):
27
+ device = torch.device("cpu")
28
+ vocab_mlm = create_vocab()
29
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
30
+ save_path = model_name
31
+ train_seqs = pd.read_csv('C0_seq.csv')
32
+ train_seq = train_seqs['Seq'].tolist()
33
+ model = torch.load(save_path, map_location=torch.device('cpu'))
34
+ model = model.to(device)
35
+
36
  global is_stopped
37
  is_stopped = False
38
  X3 = "X" * len(X0)
 
178
  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
179
  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
180
  X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
181
+ model_name = gr.Dropdown(choices=['model_final','model_C1','model_C2','model_C3','model_C4','model_C5','model_mlm'], label="Model")
182
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
183
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
184
  with gr.Row():
 
189
  with gr.Row():
190
  output_file = gr.File(label="Download generated conotoxins")
191
 
192
+ start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num, model_name], outputs=[output_df, output_file])
193
  stop_button.click(stop_generation, outputs=None)
194
 
195
  demo.launch()