Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
aa9b17f
1
Parent(s):
eb031b7
n_steps
Browse files
app.py
CHANGED
|
@@ -15,7 +15,6 @@ from src.lightning import DDPM
|
|
| 15 |
from src.linker_size_lightning import SizeClassifier
|
| 16 |
|
| 17 |
N_SAMPLES = 5
|
| 18 |
-
N_STEPS = 10
|
| 19 |
|
| 20 |
parser = argparse.ArgumentParser()
|
| 21 |
parser.add_argument('--ip', type=str, default=None)
|
|
@@ -39,7 +38,6 @@ if not os.path.exists(diffusion_path):
|
|
| 39 |
link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
|
| 40 |
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
| 41 |
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
|
| 42 |
-
ddpm.edm.T = N_STEPS
|
| 43 |
print('Loaded diffusion model')
|
| 44 |
|
| 45 |
|
|
@@ -111,7 +109,7 @@ def draw_sample(idx, out_files):
|
|
| 111 |
return output.IFRAME_TEMPLATE.format(html=html)
|
| 112 |
|
| 113 |
|
| 114 |
-
def generate(input_file):
|
| 115 |
if input_file is None:
|
| 116 |
return ''
|
| 117 |
|
|
@@ -155,6 +153,8 @@ def generate(input_file):
|
|
| 155 |
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
| 156 |
print('Created dataloader')
|
| 157 |
|
|
|
|
|
|
|
| 158 |
for data in dataloader:
|
| 159 |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
| 160 |
print('Generated linker')
|
|
@@ -188,7 +188,8 @@ with demo:
|
|
| 188 |
with gr.Column():
|
| 189 |
gr.Markdown('## Input Fragments')
|
| 190 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
| 191 |
-
input_file = gr.File(file_count='single', label='Input Fragments')
|
|
|
|
| 192 |
examples = gr.Dataset(
|
| 193 |
components=[gr.File(visible=False)],
|
| 194 |
samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
|
|
@@ -219,13 +220,13 @@ with demo:
|
|
| 219 |
outputs=[visualization],
|
| 220 |
)
|
| 221 |
examples.click(
|
| 222 |
-
fn=lambda idx: [f'examples/example_{idx+1}.sdf', show_input(f'examples/example_{idx+1}.sdf')],
|
| 223 |
inputs=[examples],
|
| 224 |
-
outputs=[input_file, visualization]
|
| 225 |
)
|
| 226 |
button.click(
|
| 227 |
fn=generate,
|
| 228 |
-
inputs=[input_file],
|
| 229 |
outputs=[visualization, output_files, samples],
|
| 230 |
)
|
| 231 |
samples.change(
|
|
|
|
| 15 |
from src.linker_size_lightning import SizeClassifier
|
| 16 |
|
| 17 |
N_SAMPLES = 5
|
|
|
|
| 18 |
|
| 19 |
parser = argparse.ArgumentParser()
|
| 20 |
parser.add_argument('--ip', type=str, default=None)
|
|
|
|
| 38 |
link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
|
| 39 |
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
| 40 |
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
|
|
|
|
| 41 |
print('Loaded diffusion model')
|
| 42 |
|
| 43 |
|
|
|
|
| 109 |
return output.IFRAME_TEMPLATE.format(html=html)
|
| 110 |
|
| 111 |
|
| 112 |
+
def generate(input_file, n_steps):
|
| 113 |
if input_file is None:
|
| 114 |
return ''
|
| 115 |
|
|
|
|
| 153 |
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
| 154 |
print('Created dataloader')
|
| 155 |
|
| 156 |
+
ddpm.edm.T = n_steps
|
| 157 |
+
|
| 158 |
for data in dataloader:
|
| 159 |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
| 160 |
print('Generated linker')
|
|
|
|
| 188 |
with gr.Column():
|
| 189 |
gr.Markdown('## Input Fragments')
|
| 190 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
| 191 |
+
input_file = gr.File(file_count='single', label='Input Fragments')
|
| 192 |
+
n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
|
| 193 |
examples = gr.Dataset(
|
| 194 |
components=[gr.File(visible=False)],
|
| 195 |
samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
|
|
|
|
| 220 |
outputs=[visualization],
|
| 221 |
)
|
| 222 |
examples.click(
|
| 223 |
+
fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, show_input(f'examples/example_{idx+1}.sdf')],
|
| 224 |
inputs=[examples],
|
| 225 |
+
outputs=[input_file, n_steps, visualization]
|
| 226 |
)
|
| 227 |
button.click(
|
| 228 |
fn=generate,
|
| 229 |
+
inputs=[input_file, n_steps],
|
| 230 |
outputs=[visualization, output_files, samples],
|
| 231 |
)
|
| 232 |
samples.change(
|