anaghanagesh commited on
Commit
04d92a2
·
verified ·
1 Parent(s): 000942b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -47
app.py CHANGED
@@ -4,9 +4,8 @@ import base64
4
  from io import BytesIO
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from rdkit import Chem
7
- from rdkit.Chem import AllChem
8
  import torch
9
- import py3Dmol
10
 
11
  # Load models
12
  bio_gpt = pipeline("text-generation", model="microsoft/BioGPT-Large")
@@ -14,26 +13,28 @@ chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base
14
  chemberta_model = AutoModelForCausalLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
15
  compliance_qa = pipeline("question-answering", model="nlpaueb/legal-bert-base-uncased")
16
 
17
- # Functions
18
- def extract_insights(prompt):
19
- try:
20
- result = bio_gpt(prompt, max_length=200, do_sample=True)
21
- return result[0]['generated_text']
22
- except Exception as e:
23
- return f"Error: {str(e)}"
24
 
25
  def generate_molecule():
26
  sample_smiles = ["CCO", "CCN", "C1=CC=CC=C1", "C(C(=O)O)N", "CC(C)CC"]
27
  return random.choice(sample_smiles)
28
 
29
  def predict_properties(smiles):
30
- try:
31
- inputs = chemberta_tokenizer(smiles, return_tensors="pt")
32
- with torch.no_grad():
33
- outputs = chemberta_model(**inputs)
34
- return round(outputs.logits.mean().item(), 3)
35
- except Exception as e:
36
- return f"Error: {str(e)}"
 
 
 
 
37
 
38
  def mol_to_3d_html(smiles):
39
  try:
@@ -41,53 +42,66 @@ def mol_to_3d_html(smiles):
41
  mol = Chem.AddHs(mol)
42
  AllChem.EmbedMolecule(mol, AllChem.ETKDG())
43
  AllChem.UFFOptimizeMolecule(mol)
44
- block = Chem.MolToMolBlock(mol)
45
 
46
- view = py3Dmol.view(width=400, height=400)
47
- view.addModel(block, "mol")
48
- view.setStyle({"stick": {}})
49
- view.setBackgroundColor("white")
50
- view.zoomTo()
51
-
52
- return view._make_html()
 
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
- return f"<p>Error generating 3D structure: {str(e)}</p>"
55
 
56
- def check_compliance(question, context):
57
- try:
58
- return compliance_qa(question=question, context=context)['answer']
59
- except Exception as e:
60
- return f"Error: {str(e)}"
61
 
62
- # Main function
63
- def run_pipeline(disease, symptoms):
64
- insights = extract_insights(f"Recent treatments for {disease} with symptoms: {symptoms}")
65
  smiles = generate_molecule()
66
- score = predict_properties(smiles)
67
- mol3d_html = mol_to_3d_html(smiles)
68
- compliance = check_compliance(
69
- "What does FDA require for drug testing?",
70
- "FDA requires extensive testing for new drug candidates including Phase I, II, and III clinical trials."
 
 
 
 
 
 
 
71
  )
72
- return insights, smiles, mol3d_html, score, compliance
73
 
74
- # Gradio Interface
75
  demo = gr.Interface(
76
- fn=run_pipeline,
77
  inputs=[
78
  gr.Textbox(label="🦠 Disease", value="lung cancer"),
79
  gr.Textbox(label="🩺 Symptoms", value="shortness of breath, weight loss")
80
  ],
81
  outputs=[
82
  gr.Textbox(label="📜 Literature Insights"),
83
- gr.Textbox(label="🧪 Generated SMILES"),
84
- gr.HTML(label="🧬 3D Molecule Viewer"),
85
- gr.Textbox(label="📊 Molecular Property Score (ChemBERTa)"),
86
- gr.Textbox(label="⚖️ Legal Compliance (FDA)")
 
87
  ],
88
  title="🧬 AI-Driven Drug Discovery System",
89
- description="Input a disease and its symptoms to discover potential drug candidates using LLMs and molecule modeling."
90
  )
91
 
92
  demo.launch()
93
-
 
4
  from io import BytesIO
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from rdkit import Chem
7
+ from rdkit.Chem import Draw, AllChem
8
  import torch
 
9
 
10
  # Load models
11
  bio_gpt = pipeline("text-generation", model="microsoft/BioGPT-Large")
 
13
  chemberta_model = AutoModelForCausalLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
14
  compliance_qa = pipeline("question-answering", model="nlpaueb/legal-bert-base-uncased")
15
 
16
+ # --- Functions ---
17
+ def extract_insights(disease, symptoms):
18
+ prompt = f"Recent treatments for {disease} with symptoms: {symptoms}"
19
+ result = bio_gpt(prompt, max_length=200, do_sample=True)
20
+ return result[0]['generated_text']
 
 
21
 
22
  def generate_molecule():
23
  sample_smiles = ["CCO", "CCN", "C1=CC=CC=C1", "C(C(=O)O)N", "CC(C)CC"]
24
  return random.choice(sample_smiles)
25
 
26
  def predict_properties(smiles):
27
+ inputs = chemberta_tokenizer(smiles, return_tensors="pt")
28
+ with torch.no_grad():
29
+ outputs = chemberta_model(**inputs)
30
+ return round(outputs.logits.mean().item(), 3)
31
+
32
+ def visualize_2d(smiles):
33
+ mol = Chem.MolFromSmiles(smiles)
34
+ img = Draw.MolToImage(mol, size=(300, 300))
35
+ buf = BytesIO()
36
+ img.save(buf, format="PNG")
37
+ return base64.b64encode(buf.getvalue()).decode()
38
 
39
  def mol_to_3d_html(smiles):
40
  try:
 
42
  mol = Chem.AddHs(mol)
43
  AllChem.EmbedMolecule(mol, AllChem.ETKDG())
44
  AllChem.UFFOptimizeMolecule(mol)
45
+ mol_block = Chem.MolToMolBlock(mol)
46
 
47
+ encoded_block = mol_block.replace("\n", "\\n")
48
+ viewer_div = f"""
49
+ <div id="molviewer" style="width: 400px; height: 400px;"></div>
50
+ <script src="https://3Dmol.org/build/3Dmol-min.js"></script>
51
+ <script>
52
+ let element = document.getElementById("molviewer");
53
+ let config = {{ backgroundColor: "white" }};
54
+ let viewer = $3Dmol.createViewer(element, config);
55
+ let molData = `{encoded_block}`;
56
+ viewer.addModel(molData, "mol");
57
+ viewer.setStyle({{}}, {{stick:{{}}}});
58
+ viewer.zoomTo();
59
+ viewer.render();
60
+ </script>
61
+ """
62
+ return viewer_div
63
  except Exception as e:
64
+ return f"<p>Error generating 3D molecule: {str(e)}</p>"
65
 
66
+ def check_compliance():
67
+ context = "FDA requires extensive testing for new drug candidates including Phase I, II, and III clinical trials."
68
+ question = "What does FDA require for drug testing?"
69
+ return compliance_qa(question=question, context=context)['answer']
 
70
 
71
+ # --- Gradio UI ---
72
+ def run_discovery(disease, symptoms):
73
+ insights = extract_insights(disease, symptoms)
74
  smiles = generate_molecule()
75
+ prop_score = predict_properties(smiles)
76
+ img_base64 = visualize_2d(smiles)
77
+ compliance = check_compliance()
78
+ html_3d = mol_to_3d_html(smiles)
79
+
80
+ return (
81
+ insights,
82
+ f"SMILES: {smiles}",
83
+ f"ChemBERTa Property Score: {prop_score}",
84
+ f"<img src='data:image/png;base64,{img_base64}' width='300'/>",
85
+ html_3d,
86
+ compliance
87
  )
 
88
 
 
89
  demo = gr.Interface(
90
+ fn=run_discovery,
91
  inputs=[
92
  gr.Textbox(label="🦠 Disease", value="lung cancer"),
93
  gr.Textbox(label="🩺 Symptoms", value="shortness of breath, weight loss")
94
  ],
95
  outputs=[
96
  gr.Textbox(label="📜 Literature Insights"),
97
+ gr.Textbox(label="🧪 SMILES String"),
98
+ gr.Textbox(label="🧬 Property Score"),
99
+ gr.HTML(label="🧫 2D Molecule Structure"),
100
+ gr.HTML(label="🔬 3D Molecule Viewer"),
101
+ gr.Textbox(label="⚖️ FDA Compliance Summary")
102
  ],
103
  title="🧬 AI-Driven Drug Discovery System",
104
+ description="Enter disease and symptoms to generate drug candidates using BioGPT, ChemBERTa, and LegalBERT. View 2D and animated 3D molecules!"
105
  )
106
 
107
  demo.launch()