# modules/orchestrator.py
"""
The Central Nervous System of Project Asclepius.
This module is the master conductor, orchestrating high-performance, asynchronous
workflows for each of the application's features. It intelligently sequences
calls to API clients and the Gemini handler to transform user queries into
comprehensive, synthesized reports.
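
Typical use from an async caller (illustrative sketch; the real wiring lives in
the UI layer, and the example inputs below are hypothetical):

    report = await run_symptom_synthesis("persistent dry cough", image_input=None)
    safety = await run_drug_interaction_analysis("warfarin, aspirin")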
"""
import asyncio
import aiohttp
from itertools import chain
from PIL import Image
# Import all our specialized tools
from . import gemini_handler, prompts, utils
# The API client modules live in a top-level `api_clients` package, so they are
# imported absolutely (no leading dot) rather than relative to `modules`.
from api_clients import (
    pubmed_client,
    clinicaltrials_client,
    openfda_client,
    rxnorm_client,
)
# --- Internal Helper for Data Formatting ---
def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
"""
Takes the raw dictionary of API results and formats each entry into a
clean, readable string suitable for injection into a Gemini prompt.
Args:
api_results (dict): The dictionary of results from asyncio.gather.
Returns:
dict[str, str]: A dictionary with the same keys but formatted string values.
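    Example:
        An empty results dict falls back to the "not found" messages, e.g.
        _format_api_data_for_prompt({})['pubmed'] returns
        "No relevant review articles were found on PubMed for this query."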
"""
formatted_strings = {}
# Format PubMed data
pubmed_data = api_results.get('pubmed', [])
if isinstance(pubmed_data, list) and pubmed_data:
lines = [f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url')})" for a in pubmed_data]
formatted_strings['pubmed'] = "\n".join(lines)
else:
formatted_strings['pubmed'] = "No relevant review articles were found on PubMed for this query."
# Format Clinical Trials data
trials_data = api_results.get('trials', [])
if isinstance(trials_data, list) and trials_data:
lines = [f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url')})" for t in trials_data]
formatted_strings['trials'] = "\n".join(lines)
else:
formatted_strings['trials'] = "No actively recruiting clinical trials were found matching this query."
# Format OpenFDA Adverse Events data
# This data often comes from multiple queries, so we flatten it.
fda_data = api_results.get('openfda', [])
if isinstance(fda_data, list):
# The result is a list of lists, so we flatten it
all_events = list(chain.from_iterable(filter(None, fda_data)))
if all_events:
lines = [f"- {evt['term']} (Reported {evt['count']} times)" for evt in all_events]
formatted_strings['openfda'] = "\n".join(lines)
else:
formatted_strings['openfda'] = "No specific adverse event data was found for this query."
else:
formatted_strings['openfda'] = "No specific adverse event data was found for this query."
# Format Vision analysis
vision_data = api_results.get('vision', "")
if isinstance(vision_data, str) and vision_data:
formatted_strings['vision'] = vision_data
elif isinstance(vision_data, Exception):
formatted_strings['vision'] = f"An error occurred during image analysis: {vision_data}"
else:
formatted_strings['vision'] = ""
return formatted_strings
# --- FEATURE 1: Symptom Synthesizer Pipeline ---
async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
"""The complete, asynchronous pipeline for the Symptom Synthesizer tab."""
if not user_query:
return "Please enter a symptom description or a medical question to begin."
# STEP 1: AI-Powered Concept Extraction
# Use Gemini to find the core medical terms in the user's natural language query.
term_prompt = prompts.get_term_extraction_prompt(user_query)
concepts_str = await gemini_handler.generate_text_response(term_prompt)
concepts = utils.safe_literal_eval(concepts_str)
if not isinstance(concepts, list) or not concepts:
concepts = [user_query] # Fallback to the raw query if parsing fails
# Use "OR" for a broader, more inclusive search across APIs
search_query = " OR ".join(f'"{c}"' for c in concepts)
# STEP 2: Massively Parallel Evidence Gathering
# Launch all API calls concurrently for maximum performance.
async with aiohttp.ClientSession() as session:
# Define the portfolio of data we need to collect
tasks = {
"pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
"trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
"openfda": asyncio.gather(*(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts)),
}
# If an image is provided, add the vision analysis to our task portfolio
if image_input:
tasks["vision"] = gemini_handler.analyze_image_with_text(
"In the context of the user query, analyze this image objectively. Describe visual features like color, shape, texture, and patterns. Do not diagnose or offer medical advice.", image_input
)
# Execute all tasks and wait for them all to complete
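        # return_exceptions=True keeps one failed source from cancelling the others;
        # failures surface as Exception objects in their slot and are handled downstream.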
raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
api_data = dict(zip(tasks.keys(), raw_results))
# STEP 3: Data Formatting
# Convert the raw JSON/list results into clean, prompt-ready strings.
formatted_data = _format_api_data_for_prompt(api_data)
# STEP 4: The Grand Synthesis
# Feed all the structured, evidence-based data into Gemini for the final report generation.
synthesis_prompt = prompts.get_synthesis_prompt(
user_query=user_query,
concepts=concepts,
pubmed_data=formatted_data['pubmed'],
trials_data=formatted_data['trials'],
fda_data=formatted_data['openfda'],
vision_analysis=formatted_data['vision']
)
final_report = await gemini_handler.generate_text_response(synthesis_prompt)
# STEP 5: Final Delivery
# Prepend the mandatory disclaimer to the AI-generated report.
return f"{prompts.DISCLAIMER}\n\n{final_report}"
# --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline ---
async def run_drug_interaction_analysis(drug_list_str: str) -> str:
"""The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
if not drug_list_str:
return "Please enter a comma-separated list of medications."
drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
if len(drug_names) < 2:
return "Please enter at least two medications to check for interactions."
# STEP 1: Concurrent Drug Data Gathering
async with aiohttp.ClientSession() as session:
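        # Note: run_interaction_check takes only the drug names and does not use the
        # shared aiohttp session; only the openfda safety lookups share it.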
tasks = {
"interactions": rxnorm_client.run_interaction_check(drug_names),
"safety_profiles": asyncio.gather(*(openfda_client.get_safety_profile(session, name) for name in drug_names))
}
raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
api_data = dict(zip(tasks.keys(), raw_results))
# STEP 2: Data Formatting for AI Synthesis
interaction_data = api_data.get('interactions', [])
if isinstance(interaction_data, Exception):
interaction_data = [{"error": str(interaction_data)}]
safety_profiles = api_data.get('safety_profiles', [])
if isinstance(safety_profiles, Exception):
safety_profiles = [{"error": str(safety_profiles)}]
# Combine safety profiles with their drug names for clarity in the prompt
safety_data_dict = dict(zip(drug_names, safety_profiles))
# Format the complex data into clean strings
interaction_formatted = utils.format_list_as_markdown([str(i) for i in interaction_data]) if interaction_data else "No interactions found."
safety_formatted = "\n".join([f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items()])
# STEP 3: AI-Powered Safety Briefing
synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
drug_names=drug_names,
interaction_data=interaction_formatted,
safety_data=safety_formatted
)
final_report = await gemini_handler.generate_text_response(synthesis_prompt)
# STEP 4: Final Delivery
return f"{prompts.DISCLAIMER}\n\n{final_report}"