Spaces:
Sleeping
Sleeping
# modules/orchestrator.py
"""
The main conductor. This module sequences the calls to APIs and the AI model.
It contains the core application logic for each feature tab, orchestrating
data fetching, processing, and AI synthesis.
"""
import asyncio | |
import aiohttp | |
import ast | |
from itertools import chain | |
from PIL import Image | |
# Import all our tools
from . import gemini_handler, prompts | |
from .api_clients import ( | |
umls_client, | |
pubmed_client, | |
clinicaltrials_client, | |
openfda_client, | |
rxnorm_client | |
) | |
# --- Helper function for formatting data for prompts ---
def _format_data_for_prompt(data: list | dict, source_name: str) -> str: | |
"""Converts API result lists/dicts into a clean string for Gemini prompts.""" | |
if not data: | |
return f"No data found from {source_name}." | |
report_lines = [f"--- Data from {source_name} ---"] | |
if isinstance(data, list): | |
for item in data: | |
report_lines.append(str(item)) | |
elif isinstance(data, dict): | |
for key, value in data.items(): | |
report_lines.append(f"{key}: {value}") | |
return "\n".join(report_lines) | |
# --- Main Orchestrator for the Symptom Synthesizer ---
async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
    """Run the complete, asynchronous pipeline for the Symptom Synthesizer tab.

    Steps: ask Gemini to extract search concepts from the query, fetch PubMed,
    ClinicalTrials.gov and OpenFDA data concurrently (plus an optional image
    analysis), format everything into text sections, and ask Gemini for a
    final synthesized report.

    Args:
        user_query: The user's symptom description or question. An empty
            value short-circuits with a prompt to enter text.
        image_input: Optional image; when provided it is sent to
            ``gemini_handler.analyze_image_with_text``.

    Returns:
        ``prompts.DISCLAIMER`` followed by the generated report, or a short
        instruction message when ``user_query`` is empty.
    """
    if not user_query:
        return "Please enter a symptom description or a medical question to begin."

    def _or_error(result):
        """Turn a gathered exception into the error-record list used for prompts."""
        if isinstance(result, BaseException):
            return [{"error": str(result)}]
        return result

    # 1. Extract concepts with Gemini
    term_extraction_prompt = prompts.get_term_extraction_prompt(user_query)
    concepts_str = await gemini_handler.generate_text_response(term_extraction_prompt)
    try:
        # The response is parsed as a Python literal; anything that is not a
        # list is discarded below.
        parsed = ast.literal_eval(concepts_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input.
        parsed = None

    concepts = []
    if isinstance(parsed, list):
        # Keep only non-blank strings: " OR ".join() below raises TypeError
        # on any non-string element.
        concepts = [c for c in parsed if isinstance(c, str) and c.strip()]
    if not concepts:
        concepts = [user_query]  # Fallback
    search_query = " OR ".join(concepts)

    # 2. Gather all evidence concurrently
    async with aiohttp.ClientSession() as session:
        tasks = {
            "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
            "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
            # return_exceptions=True so one failing concept lookup does not
            # discard the results of the others.
            "openfda": asyncio.gather(
                *(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts),
                return_exceptions=True,
            ),
        }
        if image_input is not None:
            tasks["vision"] = gemini_handler.analyze_image_with_text(
                "Analyze this image in a medical context. Describe what you see objectively. Do not diagnose.", image_input
            )
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        api_data = dict(zip(tasks.keys(), results))

    # 3. Format all gathered data for the final prompt
    pubmed_formatted = _format_data_for_prompt(_or_error(api_data.get('pubmed')), "PubMed")
    trials_formatted = _format_data_for_prompt(_or_error(api_data.get('trials')), "ClinicalTrials.gov")

    # Flatten the per-concept OpenFDA results. The gathered value may itself
    # be an exception, and individual entries may be exceptions too; neither
    # can be iterated, so they are converted to error records.
    fda_raw = api_data.get('openfda', [])
    if isinstance(fda_raw, BaseException):
        fda_results = [{"error": str(fda_raw)}]
    else:
        fda_results = []
        for per_concept in fda_raw:
            if isinstance(per_concept, BaseException):
                fda_results.append({"error": str(per_concept)})
            elif isinstance(per_concept, (list, tuple)):
                fda_results.extend(per_concept)
            elif per_concept:
                fda_results.append(per_concept)
    fda_formatted = _format_data_for_prompt(fda_results, "OpenFDA Adverse Events")

    vision_formatted = api_data.get('vision', "")
    if isinstance(vision_formatted, BaseException):
        vision_formatted = "Error analyzing image."

    # 4. The Grand Synthesis with Gemini
    synthesis_prompt = prompts.get_synthesis_prompt(
        user_query=user_query,
        concepts=concepts,
        pubmed_data=pubmed_formatted,
        trials_data=trials_formatted,
        fda_data=fda_formatted,
        vision_analysis=vision_formatted,
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)
    return f"{prompts.DISCLAIMER}\n\n{final_report}"
# --- Main Orchestrator for the Drug Interaction Analyzer ---
async def run_drug_interaction_analysis(drug_list_str: str) -> str:
    """Run the complete, asynchronous pipeline for the Drug Interaction Analyzer tab.

    Splits the comma-separated input into drug names, fetches RxNorm
    interaction data and one OpenFDA safety profile per drug concurrently,
    formats both into text sections, and asks Gemini for a safety report.

    Args:
        drug_list_str: Comma-separated medication names. Blank entries are
            ignored; at least two names are required.

    Returns:
        ``prompts.DISCLAIMER`` followed by the generated report, or a short
        instruction message when fewer than two medications were supplied.
    """
    if not drug_list_str:
        return "Please enter a comma-separated list of medications."

    drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
    if len(drug_names) < 2:
        return "Please enter at least two medications to check for interactions."

    # 1. Gather all drug data concurrently
    async with aiohttp.ClientSession() as session:
        tasks = {
            "interactions": rxnorm_client.run_interaction_check(drug_names),
            # return_exceptions=True keeps results positionally aligned with
            # drug_names and lets successful lookups survive a sibling failure.
            "safety_profiles": asyncio.gather(
                *(openfda_client.get_safety_profile(session, name) for name in drug_names),
                return_exceptions=True,
            ),
        }
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        api_data = dict(zip(tasks.keys(), results))

    # 2. Format data for the final prompt
    interaction_data = api_data.get('interactions', [])
    if isinstance(interaction_data, BaseException):
        interaction_data = [{"error": str(interaction_data)}]

    safety_profiles = api_data.get('safety_profiles', [])
    if isinstance(safety_profiles, BaseException):
        # The whole batch failed: attribute the error to every drug. A
        # single-element list here would be truncated by zip() and pin the
        # error on the first drug only.
        safety_profiles = [safety_profiles] * len(drug_names)

    # Combine safety profiles with their drug names, converting per-drug
    # failures into error records.
    safety_data_dict = {
        name: {"error": str(profile)} if isinstance(profile, BaseException) else profile
        for name, profile in zip(drug_names, safety_profiles)
    }

    interaction_formatted = _format_data_for_prompt(interaction_data, "RxNorm Interactions")
    safety_formatted = _format_data_for_prompt(safety_data_dict, "OpenFDA Safety Profiles")

    # 3. Synthesize the safety report with Gemini
    synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
        drug_names=drug_names,
        interaction_data=interaction_formatted,
        safety_data=safety_formatted,
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)
    return f"{prompts.DISCLAIMER}\n\n{final_report}"