oucgc1996 commited on
Commit
4e6a9e7
·
verified ·
1 Parent(s): 9008ac0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
12
  seed = random.randint(0,100000)
13
  setup_seed(seed)
14
 
15
- device = torch.device("cpu")
16
- vocab_mlm = create_vocab()
17
- vocab_mlm = add_tokens_to_vocab(vocab_mlm)
18
- save_path = 'mlm-model-27.pt'
19
- train_seqs = pd.read_csv('C0_seq.csv')
20
- train_seq = train_seqs['Seq'].tolist()
21
- model = torch.load(save_path, map_location=torch.device('cpu'))
22
- model = model.to(device)
23
-
24
  def temperature_sampling(logits, temperature):
25
  logits = logits / temperature
26
  probabilities = torch.softmax(logits, dim=-1)
@@ -32,11 +23,20 @@ def stop_generation():
32
  is_stopped = True
33
  return "Generation stopped."
34
 
35
- def CTXGen(X1, X2, τ, g_num, length_range):
36
  global is_stopped
37
  is_stopped = False
38
  start, end = length_range
39
 
 
 
 
 
 
 
 
 
 
40
  msa_data = pd.read_csv('conoData_C0.csv')
41
  msa = msa_data['Sequences'].tolist()
42
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
@@ -158,6 +158,7 @@ with gr.Blocks() as demo:
158
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
159
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
160
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
 
161
  with gr.Row():
162
  start_button = gr.Button("Start Generation")
163
  stop_button = gr.Button("Stop Generation")
@@ -166,7 +167,7 @@ with gr.Blocks() as demo:
166
  with gr.Row():
167
  output_file = gr.File(label="Download generated conotoxins")
168
 
169
- start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range], outputs=[output_df, output_file])
170
  stop_button.click(stop_generation, outputs=None)
171
 
172
  demo.launch()
 
12
  seed = random.randint(0,100000)
13
  setup_seed(seed)
14
 
 
 
 
 
 
 
 
 
 
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
 
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
+ def CTXGen(X1, X2, τ, g_num, length_range, model_name):
27
  global is_stopped
28
  is_stopped = False
29
  start, end = length_range
30
 
31
+ device = torch.device("cpu")
32
+ vocab_mlm = create_vocab()
33
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
34
+ save_path = model_name
35
+ train_seqs = pd.read_csv('C0_seq.csv')
36
+ train_seq = train_seqs['Seq'].tolist()
37
+ model = torch.load(save_path, map_location=torch.device('cpu'))
38
+ model = model.to(device)
39
+
40
  msa_data = pd.read_csv('conoData_C0.csv')
41
  msa = msa_data['Sequences'].tolist()
42
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
 
158
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
159
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
160
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
161
+ model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
162
  with gr.Row():
163
  start_button = gr.Button("Start Generation")
164
  stop_button = gr.Button("Stop Generation")
 
167
  with gr.Row():
168
  output_file = gr.File(label="Download generated conotoxins")
169
 
170
+ start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name], outputs=[output_df, output_file])
171
  stop_button.click(stop_generation, outputs=None)
172
 
173
  demo.launch()