import os
import json
import logging
import re
from pathlib import Path
from typing import List, Dict, Any

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from unstructured.partition.pdf import partition_pdf
from bloatectomy import bloatectomy

from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict, NotRequired

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("health-agent")

load_dotenv()
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", r"D:\DEV PATEL\2025\HealthCareAI\reports"))
SSRI_FILE = Path(os.getenv("SSRI_FILE", r"D:\DEV PATEL\2025\HealthCareAI\medicationCategories\SSRI_list.txt"))
MISC_FILE = Path(os.getenv("MISC_FILE", r"D:\DEV PATEL\2025\HealthCareAI\medicationCategories\MISC_list.txt"))
GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)

llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=None,
)
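
# A minimal .env sketch for local runs (values are illustrative; adjust the
# paths for your machine). These are the variables read above and in __main__:
#
#   GROQ_API_KEY=<your-groq-api-key>
#   LLM_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
#   REPORTS_ROOT=D:\DEV PATEL\2025\HealthCareAI\reports
#   SSRI_FILE=D:\DEV PATEL\2025\HealthCareAI\medicationCategories\SSRI_list.txt
#   MISC_FILE=D:\DEV PATEL\2025\HealthCareAI\medicationCategories\MISC_list.txt
#   PORT=5000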

NODE_BASE_INSTRUCTIONS = """
You are HealthAI — a clinical assistant producing JSON for downstream processing.
Produce only valid JSON (no extra text). Follow field types exactly. If data is missing, return empty strings or empty arrays.
Be conservative: do not assert diagnoses; provide suggestions and ask for physician confirmation where needed.
"""

agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
Fix missing quotes, trailing commas, unescaped newlines, and stray assistant labels, and ensure schema compliance.
""")

def extract_json_from_llm_response(raw_response: str) -> dict:
    """
    Try extracting a JSON object from raw LLM text. Performs common cleanups seen in LLM outputs.
    Raises json.JSONDecodeError if parsing still fails.
    """
    # Prefer a fenced ```json ... ``` block if one exists; otherwise use the whole response.
    md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = md.group(1).strip() if md else raw_response

    # Trim to the outermost {...} span.
    first, last = json_string.find('{'), json_string.rfind('}')
    if 0 <= first < last:
        json_string = json_string[first:last + 1]

    # Strip stray words immediately before an opening brace (e.g. 'assistant {' -> '{'),
    # orphaned "assistant": keys, and booleans with a trailing quote (true" -> true).
    json_string = re.sub(r'\b\w+\s*{', '{', json_string)
    json_string = re.sub(r'"assistant"\s*:', '', json_string)
    json_string = re.sub(r'\b(false|true)"', r'\1', json_string)

    # Escape unescaped double quotes inside the "logic" field's string value.
    def _esc(m):
        prefix, body = m.group(1), m.group(2)
        return prefix + body.replace('"', r'\"')

    json_string = re.sub(
        r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
        _esc,
        json_string,
    )

    # Remove trailing and doubled commas.
    json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
    json_string = re.sub(r',\s*,', ',', json_string)

    # If there are more closing braces than opening ones, trim the excess from the end.
    ob, cb = json_string.count('{'), json_string.count('}')
    if cb > ob:
        excess = cb - ob
        json_string = json_string.rstrip()[:-excess]

    # Escape raw newlines that appear inside string literals.
    def _escape_newlines_in_strings(s: str) -> str:
        return re.sub(
            r'"((?:[^"\\]|\\.)*?)"',
            lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
            s,
            flags=re.DOTALL,
        )

    json_string = _escape_newlines_in_strings(json_string)

    return json.loads(json_string)
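
# Illustrative repair (this helper is exercised indirectly via call_node_agent):
#
#   raw = 'Sure!\n```json\n{"patientDetails": {"name": "A", "age": "42",},}\n```'
#   extract_json_from_llm_response(raw)
#   # -> {"patientDetails": {"name": "A", "age": "42"}}
#
# The fenced block is extracted first; the trailing commas are then removed
# before json.loads runs.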

def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """
    Uses the bloatectomy class to remove duplicate note sections.
    style: 'highlight' | 'bold' | 'remov'; we use 'remov' to delete duplicates.
    Returns the cleaned text as a single string.
    """
    try:
        b = bloatectomy(text, style=style, output="html")
        tokens = getattr(b, "tokens", None)
        if not tokens:
            return text
        return "\n".join(tokens)
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
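
# Illustrative behavior (assuming bloatectomy's `tokens` attribute holds the
# deduplicated note sections, as the getattr above expects):
#
#   clean_notes_with_bloatectomy("HPI: stable.\n\nHPI: stable.")
#   # -> the note with the duplicated section removed, or the original text
#   #    unchanged if bloatectomy raises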

def readDrugs_from_file(path: Path):
    """Parse a medication list file into ({generic: full_line}, [generics])."""
    if not path.exists():
        return {}, []
    txt = path.read_text(encoding="utf-8", errors="ignore")
    # Each line is expected to start with the generic name, followed by a '|'.
    generics = re.findall(r"^(.*?)\|", txt, re.MULTILINE)
    generics = [g.lower() for g in generics if g]
    lines = [ln.strip().lower() for ln in txt.splitlines() if ln.strip()]
    return dict(zip(generics, lines)), generics
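
# Assumed list-file format (a sketch; the real SSRI_list.txt / MISC_list.txt
# may differ): one medication per line, generic name first, alternatives
# separated by '|' so the whole line doubles as a regex alternation in
# addToDrugs_line below, e.g.
#
#   fluoxetine|prozac
#   sertraline|zoloft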

def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str, str], genList: List[str]) -> List[int]:
    """Set the flag for each generic whose list-file line (used as a regex) matches this note line."""
    gen_index = {g: i for i, g in enumerate(genList)}
    for generic, pattern_line in listing.items():
        try:
            if re.search(pattern_line, line, re.I):
                idx = gen_index.get(generic)
                if idx is not None:
                    drugs_flags[idx] = 1
        except re.error:
            # The raw file line may not be a valid regex; skip it.
            continue
    return drugs_flags

def extract_medications_from_text(text: str) -> List[str]:
    """Combine list-file matching with lightweight regex heuristics to find medication mentions."""
    ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
    misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
    combined_map = {**ssri_map, **misc_map}
    combined_generics = []
    if ssri_generics:
        combined_generics.extend(ssri_generics)
    if misc_generics:
        combined_generics.extend(misc_generics)

    flags = [0] * len(combined_generics)
    meds_found = set()
    for ln in text.splitlines():
        ln = ln.strip()
        if not ln:
            continue
        if combined_map:
            flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
        # Heuristic 1: explicit "Rx:/Drug:/Medication:" style prefixes.
        m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
        if m:
            meds_found.add(m.group(2).strip())
        # Heuristic 2: capitalized tokens followed by a dosage unit (e.g. "Metformin 500 mg").
        m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
        for s in m2:
            if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
                meds_found.add(s.strip())
    for i, f in enumerate(flags):
        if f == 1:
            meds_found.add(combined_generics[i])
    return list(meds_found)
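
# Illustrative call (hypothetical note text; list-file hits depend on the
# configured SSRI/MISC files):
#
#   extract_medications_from_text("Medication: Sertraline\nMetformin 500 mg daily")
#   # -> e.g. ["Sertraline", "Metformin 500 mg"] (order varies; built from a set)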

PATIENT_NODE_PROMPT = """
You will extract patientDetails from the provided document texts.
Return ONLY JSON with this exact shape:
{ "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
Fill fields using text evidence or leave them as empty strings.
"""

DOCTOR_NODE_PROMPT = """
You will extract doctorDetails found in the documents.
Return ONLY JSON with this exact shape:
{ "doctorDetails": {"referredBy": ""} }
"""

TEST_REPORT_NODE_PROMPT = """
You will extract per-test structured results from the documents.
Return ONLY JSON with this exact shape:
{
  "reports": [
    {
      "testName": "",
      "dateReported": "",
      "timeReported": "",
      "abnormalFindings": [
        {"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
      ],
      "interpretation": "",
      "trends": []
    }
  ]
}
- Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
- For 'result', prefer numeric values; if a value is not numeric, keep the original string.
- Use statuses: Low, High, Borderline, Positive, Negative, Normal.
"""

ANALYSIS_NODE_PROMPT = """
You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
Return ONLY JSON:
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "", "risk_prediction": "", "drug_interaction": "" } }
Be conservative, evidence-based, and suggest follow-up steps for physicians.
"""

CONDITION_LOOP_NODE_PROMPT = """
Validation and condition node:
Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
Task: Check that the required keys exist and that each report has at least a testName and an abnormalFindings list.
Return ONLY JSON:
{ "valid": true, "missing": [] }
If fields are missing, list their keys in 'missing'. Do NOT modify content.
"""

def _raw_text_from_agent_response(resp: Any) -> str:
    """Normalize the various shapes a LangGraph agent response can take into plain text."""
    if isinstance(resp, str):
        return resp
    if hasattr(resp, "content"):
        return resp.content
    if isinstance(resp, dict):
        msgs = resp.get("messages")
        if msgs:
            last_msg = msgs[-1]
            if isinstance(last_msg, str):
                return last_msg
            if hasattr(last_msg, "content"):
                return last_msg.content
            if isinstance(last_msg, dict):
                return last_msg.get("content", "")
            return str(last_msg)
        return json.dumps(resp)
    return str(resp)


def call_node_agent(node_prompt: str, payload: dict) -> dict:
    """
    Call the generic agent with a targeted node prompt and the payload.
    Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
    """
    raw = ""  # keep defined even if agent.invoke itself raises
    try:
        content = {
            "prompt": node_prompt,
            "payload": payload,
        }
        resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
        raw = _raw_text_from_agent_response(resp)
        return extract_json_from_llm_response(raw)
    except Exception as e:
        logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
        try:
            resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
            r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
            rtxt = _raw_text_from_agent_response(r)
            return extract_json_from_llm_response(rtxt)
        except Exception as e2:
            logger.exception("JSON resolver also failed: %s", e2)
            return {}
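
# Sketch of a direct call (hypothetical payload; in this module the function is
# only invoked through the graph nodes below):
#
#   call_node_agent(PATIENT_NODE_PROMPT, {"documents": [{"cleaned_text": "..."}]})
#   # -> parsed dict on success, {} if both parsing attempts fail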

class State(TypedDict):
    patient_meta: NotRequired[Dict[str, Any]]
    patient_id: str
    documents: List[Dict[str, Any]]
    medications: List[str]
    patientDetails: NotRequired[Dict[str, Any]]
    doctorDetails: NotRequired[Dict[str, Any]]
    reports: NotRequired[List[Dict[str, Any]]]
    overallAnalysis: NotRequired[Dict[str, Any]]
    valid: NotRequired[bool]
    missing: NotRequired[List[str]]
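
# Each node below returns a partial state update (a single-key dict) that
# LangGraph merges into State before the next node runs. A minimal input state
# (hypothetical values) looks like:
#
#   {"patient_id": "P001", "patient_meta": {}, "documents": [], "medications": []}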

def patient_details_node(state: State) -> dict:
    payload = {
        "patient_meta": state.get("patient_meta", {}),
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running patient_details_node")
    out = call_node_agent(PATIENT_NODE_PROMPT, payload)
    return {"patientDetails": out.get("patientDetails", {}) if isinstance(out, dict) else {}}


def doctor_details_node(state: State) -> dict:
    payload = {
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running doctor_details_node")
    out = call_node_agent(DOCTOR_NODE_PROMPT, payload)
    return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}


def test_report_node(state: State) -> dict:
    payload = {
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running test_report_node")
    out = call_node_agent(TEST_REPORT_NODE_PROMPT, payload)
    return {"reports": out.get("reports", []) if isinstance(out, dict) else []}


def analysis_node(state: State) -> dict:
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running analysis_node")
    out = call_node_agent(ANALYSIS_NODE_PROMPT, payload)
    return {"overallAnalysis": out.get("overallAnalysis", {}) if isinstance(out, dict) else {}}


def condition_loop_node(state: State) -> dict:
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "overallAnalysis": state.get("overallAnalysis", {}),
    }
    logger.info("Running condition_loop_node (validation)")
    out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)
    if isinstance(out, dict) and "valid" in out:
        return {"valid": bool(out.get("valid")), "missing": out.get("missing", [])}
    # Fallback: if the LLM response is unusable, run a local structural check.
    missing = []
    if not state.get("patientDetails"):
        missing.append("patientDetails")
    if not state.get("reports"):
        missing.append("reports")
    return {"valid": len(missing) == 0, "missing": missing}

graph_builder = StateGraph(State)

graph_builder.add_node("patient_details", patient_details_node)
graph_builder.add_node("doctor_details", doctor_details_node)
graph_builder.add_node("test_report", test_report_node)
graph_builder.add_node("analysis", analysis_node)
graph_builder.add_node("condition_loop", condition_loop_node)

# Linear pipeline:
# START -> patient_details -> doctor_details -> test_report -> analysis -> condition_loop -> END
graph_builder.add_edge(START, "patient_details")
graph_builder.add_edge("patient_details", "doctor_details")
graph_builder.add_edge("doctor_details", "test_report")
graph_builder.add_edge("test_report", "analysis")
graph_builder.add_edge("analysis", "condition_loop")
graph_builder.add_edge("condition_loop", END)

graph = graph_builder.compile()

BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)


@app.route("/", methods=["GET"])
def serve_frontend():
    try:
        return app.send_static_file("frontend.html")
    except Exception:
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404

@app.route("/process_reports", methods=["POST"])
def process_reports():
    data = request.get_json(force=True)
    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})

    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400

    patient_folder = REPORTS_ROOT / str(patient_id)
    if not patient_folder.exists() or not patient_folder.is_dir():
        return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404

    documents = []
    combined_text_parts = []

    for fname in filenames:
        file_path = patient_folder / fname
        if not file_path.exists():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
        except Exception:
            logger.exception("Failed to parse PDF %s", file_path)
            page_text = ""
        cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": cleaned,
        })
        combined_text_parts.append(cleaned)

    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(combined_text_parts)
    meds = extract_medications_from_text(combined_text)

    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds,
    }

    try:
        result_state = graph.invoke(initial_state)

        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info("Validation failed; missing keys: %s", missing)
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []

            # Re-run analysis and validation once with the patched state.
            result_state.update(analysis_node(result_state))
            cond = condition_loop_node(result_state)
            result_state.update(cond)

        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", []),
            },
        }
        return jsonify(safe_response), 200

    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
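
# Example request (hypothetical id/filename; the PDF must already exist under
# REPORTS_ROOT/<patient_id>/):
#
#   import requests
#   requests.post(
#       "http://localhost:5000/process_reports",
#       json={"patient_id": "P001", "filenames": ["cbc.pdf"]},
#   ).json()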

@app.route("/ping", methods=["GET"])
def ping():
    return jsonify({"status": "ok"})


if __name__ == "__main__":
    port = int(os.getenv("PORT", 5000))
    app.run(host="0.0.0.0", port=port, debug=True)