Update app.py
Browse files
app.py
CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
|
|
12 |
seed = random.randint(0,100000)
|
13 |
setup_seed(seed)
|
14 |
|
15 |
-
device = torch.device("cpu")
|
16 |
-
vocab_mlm = create_vocab()
|
17 |
-
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
18 |
-
save_path = 'mlm-model-27.pt'
|
19 |
-
train_seqs = pd.read_csv('C0_seq.csv')
|
20 |
-
train_seq = train_seqs['Seq'].tolist()
|
21 |
-
model = torch.load(save_path, map_location=torch.device('cpu'))
|
22 |
-
model = model.to(device)
|
23 |
-
|
24 |
def temperature_sampling(logits, temperature):
|
25 |
logits = logits / temperature
|
26 |
probabilities = torch.softmax(logits, dim=-1)
|
@@ -32,11 +23,20 @@ def stop_generation():
|
|
32 |
is_stopped = True
|
33 |
return "Generation stopped."
|
34 |
|
35 |
-
def CTXGen(X1, X2, τ, g_num, length_range):
|
36 |
global is_stopped
|
37 |
is_stopped = False
|
38 |
start, end = length_range
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
msa_data = pd.read_csv('conoData_C0.csv')
|
41 |
msa = msa_data['Sequences'].tolist()
|
42 |
msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
|
@@ -158,6 +158,7 @@ with gr.Blocks() as demo:
|
|
158 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
159 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
160 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
|
|
161 |
with gr.Row():
|
162 |
start_button = gr.Button("Start Generation")
|
163 |
stop_button = gr.Button("Stop Generation")
|
@@ -166,7 +167,7 @@ with gr.Blocks() as demo:
|
|
166 |
with gr.Row():
|
167 |
output_file = gr.File(label="Download generated conotoxins")
|
168 |
|
169 |
-
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range], outputs=[output_df, output_file])
|
170 |
stop_button.click(stop_generation, outputs=None)
|
171 |
|
172 |
demo.launch()
|
|
|
12 |
seed = random.randint(0,100000)
|
13 |
setup_seed(seed)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def temperature_sampling(logits, temperature):
|
16 |
logits = logits / temperature
|
17 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
23 |
is_stopped = True
|
24 |
return "Generation stopped."
|
25 |
|
26 |
+
def CTXGen(X1, X2, τ, g_num, length_range, model_name):
|
27 |
global is_stopped
|
28 |
is_stopped = False
|
29 |
start, end = length_range
|
30 |
|
31 |
+
device = torch.device("cpu")
|
32 |
+
vocab_mlm = create_vocab()
|
33 |
+
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
34 |
+
save_path = model_name
|
35 |
+
train_seqs = pd.read_csv('C0_seq.csv')
|
36 |
+
train_seq = train_seqs['Seq'].tolist()
|
37 |
+
model = torch.load(save_path, map_location=torch.device('cpu'))
|
38 |
+
model = model.to(device)
|
39 |
+
|
40 |
msa_data = pd.read_csv('conoData_C0.csv')
|
41 |
msa = msa_data['Sequences'].tolist()
|
42 |
msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
|
|
|
158 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
159 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
160 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
161 |
+
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
162 |
with gr.Row():
|
163 |
start_button = gr.Button("Start Generation")
|
164 |
stop_button = gr.Button("Stop Generation")
|
|
|
167 |
with gr.Row():
|
168 |
output_file = gr.File(label="Download generated conotoxins")
|
169 |
|
170 |
+
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name], outputs=[output_df, output_file])
|
171 |
stop_button.click(stop_generation, outputs=None)
|
172 |
|
173 |
demo.launch()
|