# syn / src/generation/medical_generator.py
# Committed by theaniketgiri
# Commit 32519eb: Initial commit to Hugging Face Space
"""
Basic Medical Text Generator for Synthex MVP
Uses Hugging Face models and Gemini API
"""
import google.generativeai as genai
from transformers import pipeline
import random
import time
import json
from typing import List, Dict, Optional
import logging
from datetime import datetime
import os
import sys
# Configure root logging at import time: INFO level, timestamped records, all written to stdout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
logger = logging.getLogger(__name__)

# Gemini API key read from the environment; an empty string when the variable is unset.
DEFAULT_GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', '')
class MedicalTextGenerator:
    """Generate synthetic, entirely fictional medical records.

    Back-ends tried by :meth:`generate_record`, in order of preference:

    1. Gemini -- only when ``use_gemini=True`` and a Gemini model was configured;
    2. the local Hugging Face ``distilgpt2`` text-generation pipeline;
    3. plain template filling with randomly chosen fake values.

    Each entry of :attr:`templates` is a text template whose ``{placeholder}``
    fields are filled from :meth:`_generate_fake_data`.
    """

    # Pools of fictional values used to fill template placeholders.
    _PATIENT_NAMES = (
        "John Smith", "Jane Doe", "Robert Johnson", "Mary Wilson", "Emily Clark",
        "Michael Brown", "Linda Lee", "David Kim", "Sarah Patel", "James Chen",
    )
    _GENDERS = ("Male", "Female", "Other")
    _CHIEF_COMPLAINTS = (
        "chest pain", "shortness of breath", "abdominal pain", "headache",
        "fever", "fatigue", "dizziness", "back pain", "cough", "palpitations",
    )
    _DIAGNOSES = (
        "Hypertension", "Type 2 Diabetes", "Pneumonia", "Migraine",
        "Gastroenteritis", "Anxiety", "Asthma", "COVID-19", "Anemia", "Hyperlipidemia",
    )
    _ADDRESSES = ("123 Main St", "456 Oak Ave", "789 Pine Rd", "101 Maple Dr", "202 Elm St")
    _MEDICATIONS = ("Amoxicillin", "Lisinopril", "Metformin", "Atorvastatin", "Albuterol")
    _DOSES = ("1 tablet", "2 tablets", "1 capsule")
    _FREQUENCIES = ("daily", "twice daily", "three times daily")
    _RELATIONSHIPS = ("Spouse", "Parent", "Sibling")
    _INSURERS = ("Blue Cross", "Aetna", "United Healthcare", "Cigna")
    _CURRENT_MEDICATIONS = ("None", "Aspirin", "Metformin", "Lisinopril")
    _ALLERGIES = ("None", "Penicillin", "Sulfa", "Peanuts")
    # GPT-2's <|endoftext|> token id; GPT-2 models have no pad token, so EOS is reused for padding.
    _GPT2_EOS_TOKEN_ID = 50256
    _SECONDS_PER_DAY = 86400

    def __init__(self, gemini_api_key: Optional[str] = None, gemini_model_name: str = "gemini-pro"):
        """Configure the available back-ends and build the record templates.

        Args:
            gemini_api_key: Gemini API key; when omitted or empty, falls back to
                ``DEFAULT_GEMINI_API_KEY``.
            gemini_model_name: Model name passed to ``genai.GenerativeModel``.
                Defaults to the model this class has always used.
        """
        self.gemini_api_key = gemini_api_key or DEFAULT_GEMINI_API_KEY
        if not self.gemini_api_key:
            logger.warning("No Gemini API key provided. Using Hugging Face model only.")
        self.gemini_model_name = gemini_model_name
        # Not read anywhere in this class; kept in case external code inspects it.
        self.hf_model = None
        self.hf_generator = None
        self.gemini_model = None
        # Record IDs already handed out by this instance (see _generate_id).
        self._issued_ids = set()
        # Initialize models
        self._setup_models()
        # Medical record templates, keyed by record type.
        self.templates = {
            "clinical_note": self._get_clinical_note_template(),
            "discharge_summary": self._get_discharge_summary_template(),
            "lab_report": self._get_lab_report_template(),
            "prescription": self._get_prescription_template(),
            "patient_intake": self._get_patient_intake_template(),
        }

    def _setup_models(self):
        """Load the Hugging Face pipeline and, if a key is set, configure Gemini.

        Failures are logged and leave ``hf_generator`` / ``gemini_model`` as
        ``None``, so that :meth:`generate_record` falls back to the next back-end.
        """
        try:
            logger.info("Loading Hugging Face medical model...")
            self.hf_generator = pipeline(
                "text-generation",
                model="distilgpt2",
                max_length=512,
                do_sample=True,
                temperature=0.7,
                device=-1,  # Force CPU usage to avoid CUDA issues
                truncation=True,  # Avoid truncation warnings
            )
            logger.info("Hugging Face model loaded successfully")
        except Exception as e:
            logger.error("Failed to load Hugging Face model: %s", e)
            self.hf_generator = None
            logger.info("Falling back to template-based generation")

        if not self.gemini_api_key:
            return
        try:
            genai.configure(api_key=self.gemini_api_key)
            # Diagnostic only; must never prevent Gemini from being configured.
            self._log_available_gemini_models()
            self.gemini_model = genai.GenerativeModel(self.gemini_model_name)
            logger.info("Gemini model loaded successfully")
        except Exception as e:
            logger.error("Failed to load Gemini model: %s", e)
            self.gemini_model = None
            logger.info("Gemini API will not be available")

    @staticmethod
    def _log_available_gemini_models():
        """Best-effort debug listing of Gemini models; never raises.

        The listing is a network call, so it only runs when DEBUG logging is enabled.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        try:
            for m in genai.list_models():
                logger.debug("Available model: %s", m.name)
        except Exception as e:
            logger.debug("Could not list Gemini models: %s", e)

    def generate_record(self, record_type: str, use_gemini: bool = False) -> Dict:
        """Generate one synthetic medical record.

        Args:
            record_type: A key of :attr:`templates`.
            use_gemini: Try Gemini first (when configured) before the local back-ends.

        Returns:
            Dict with ``id``, ``type``, ``text``, ``timestamp`` (ISO 8601, naive
            local time) and ``source`` -- the back-end that actually produced
            ``text``: ``"Gemini"``, ``"Hugging Face"`` or ``"Template"``.

        Raises:
            ValueError: ``record_type`` is not a known template.
            RuntimeError: every generation method failed.
        """
        if record_type not in self.templates:
            raise ValueError(f"Unknown record type: {record_type}")
        template = self.templates[record_type]
        content: Optional[str] = None
        source = "Template"

        # Try generation methods in order of preference, remembering which one succeeded.
        if use_gemini and self.gemini_model is not None:
            try:
                content = self._generate_with_gemini(template)
                source = "Gemini"
                logger.info("Successfully generated record using Gemini")
            except Exception as e:
                logger.error("Gemini generation failed: %s", e)
                content = None

        if content is None and self.hf_generator is not None:
            try:
                content = self._generate_with_huggingface_strict(template)
                source = "Hugging Face"
                logger.info("Successfully generated record using Hugging Face")
            except Exception as e:
                logger.error("Hugging Face generation failed: %s", e)
                content = None

        if content is None:
            try:
                content = self._generate_with_template(template)
                source = "Template"
                logger.info("Successfully generated record using template")
            except Exception as e:
                logger.error("Template generation failed: %s", e)
                raise RuntimeError("All generation methods failed") from e

        return {
            "id": self._generate_id(),
            "type": record_type,
            "text": content,
            "timestamp": datetime.now().isoformat(),
            "source": source,
        }

    def _generate_with_gemini(self, template: str) -> str:
        """Ask Gemini to write a fictional record following ``template``.

        Raises:
            RuntimeError: Gemini is not configured.
            ValueError: Gemini returned an empty response.
            Exception: any error raised by the Gemini client is propagated.
        """
        if self.gemini_model is None:
            raise RuntimeError("Gemini model is not configured")
        prompt = (
            "\n"
            "Generate a realistic but completely fictional medical record using this template:\n"
            f"{template}\n"
            "Requirements:\n"
            "- Use fictional patient names and details\n"
            "- Include medically accurate terminology\n"
            "- Make it realistic but not based on any real patient\n"
            "- Include specific medical details and measurements\n"
            "- Follow standard medical documentation format\n"
        )
        response = self.gemini_model.generate_content(prompt)
        text = response.text
        # An empty answer is useless as a record; let the caller fall back.
        if not text or not text.strip():
            raise ValueError("Gemini returned an empty response")
        return text

    def _generate_with_huggingface_strict(self, template: str) -> str:
        """Fill ``template`` with fake data, then let distilgpt2 continue its first 100 chars.

        Raises instead of falling back, so callers can tell which back-end produced
        the text.
        """
        if self.hf_generator is None:
            raise RuntimeError("Hugging Face model is not loaded")
        filled_template = self._fill_template(template, self._generate_fake_data())
        # Only a short prefix is used as the prompt to keep generation within max_length.
        prompt = filled_template[:100] + "..."
        generated = self.hf_generator(
            prompt,
            max_length=400,
            num_return_sequences=1,
            pad_token_id=self._GPT2_EOS_TOKEN_ID,
            truncation=True,
        )
        return generated[0]["generated_text"]

    def _generate_with_huggingface(self, template: str) -> str:
        """Hugging Face generation that falls back to plain template filling on any error."""
        try:
            return self._generate_with_huggingface_strict(template)
        except Exception as e:
            logger.error("Hugging Face generation failed: %s", e)
            logger.info("Falling back to template-based generation")
            return self._generate_with_template(template)

    def _generate_with_template(self, template: str) -> str:
        """Fallback: fill every placeholder in ``template`` with random fake values."""
        return self._fill_template(template, self._generate_fake_data())

    @staticmethod
    def _fill_template(template: str, values: Dict[str, object]) -> str:
        """Replace each ``{key}`` in ``template`` with ``str(values[key])``.

        Placeholders without a matching key are left untouched.
        """
        filled = template
        for key, value in values.items():
            filled = filled.replace("{" + key + "}", str(value))
        return filled

    @staticmethod
    def _random_phone() -> str:
        """Return a random phone number formatted as ``(NNN)-NNN-NNNN``."""
        return f"({random.randint(200, 999)})-{random.randint(100, 999)}-{random.randint(1000, 9999)}"

    def _generate_fake_data(self) -> Dict[str, object]:
        """Draw one set of fictional values covering every template placeholder."""
        patient_name = random.choice(self._PATIENT_NAMES)
        # Every entry in _PATIENT_NAMES is "First Last".
        first_name, last_name = patient_name.split(" ", 1)
        other_first_names = [n.split(" ", 1)[0] for n in self._PATIENT_NAMES if n != patient_name]
        now = time.time()
        length_of_stay_days = random.randint(1, 10)
        admitted_at = now - length_of_stay_days * self._SECONDS_PER_DAY
        return {
            # Demographics and contact details.
            "patient_name": patient_name,
            "age": random.randint(18, 90),
            "gender": random.choice(self._GENDERS),
            "address": random.choice(self._ADDRESSES),
            "phone": self._random_phone(),
            # example.com is reserved for documentation and testing (RFC 2606).
            "email": f"{first_name}.{last_name}@example.com".lower(),
            "emergency_contact_name": f"{random.choice(other_first_names)} {last_name}",
            "emergency_contact_phone": self._random_phone(),
            "emergency_contact_relationship": random.choice(self._RELATIONSHIPS),
            # Encounter details.
            "chief_complaint": random.choice(self._CHIEF_COMPLAINTS),
            "blood_pressure": f"{random.randint(110, 160)}/{random.randint(60, 100)}",
            "heart_rate": random.randint(55, 120),
            "temperature": round(random.uniform(97.0, 104.0), 1),
            "diagnosis": random.choice(self._DIAGNOSES),
            "date": time.strftime("%Y-%m-%d", time.localtime(now)),
            "admission_date": time.strftime("%Y-%m-%d", time.localtime(admitted_at)),
            # Laboratory values.
            "wbc": random.randint(4, 11),
            "rbc": round(random.uniform(4.0, 5.5), 2),
            "hemoglobin": round(random.uniform(12.0, 16.0), 1),
            "platelets": random.randint(150, 450),
            "glucose": random.randint(70, 140),
            "bun": random.randint(7, 20),
            "creatinine": round(random.uniform(0.6, 1.2), 2),
            # Prescription.
            "medication": random.choice(self._MEDICATIONS),
            "dose": random.choice(self._DOSES),
            "frequency": random.choice(self._FREQUENCIES),
            "quantity": random.randint(30, 90),
            "refills": random.randint(0, 3),
            # Intake form: insurance and history.
            "insurance_provider": random.choice(self._INSURERS),
            "policy_number": random.randint(100000000, 999999999),
            "group_number": random.randint(10000, 99999),
            "current_medications": random.choice(self._CURRENT_MEDICATIONS),
            "allergies": random.choice(self._ALLERGIES),
        }

    def batch_generate(self, record_type: str, count: int = 10, use_gemini: bool = False) -> List[Dict]:
        """Generate up to ``count`` records; individual failures are logged and skipped."""
        records = []
        for i in range(count):
            try:
                record = self.generate_record(record_type, use_gemini)
                records.append(record)
                # Rate-limit only when requests can actually reach the Gemini API.
                if use_gemini and self.gemini_model is not None:
                    time.sleep(1)
                logger.info("Generated record %d/%d", i + 1, count)
            except Exception as e:
                logger.error("Failed to generate record %d: %s", i + 1, e)
        return records

    def _generate_id(self) -> str:
        """Return a record ID of the form ``SYN-<unix seconds>-<4 digits>``, unique per instance.

        The 4-digit suffix has only 9000 values, so records created within the same
        second collide easily; IDs already issued by this instance are remembered
        and a new suffix is drawn on collision. The set grows by one entry per ID.
        """
        issued = self.__dict__.setdefault("_issued_ids", set())
        while True:
            record_id = f"SYN-{int(time.time())}-{random.randint(1000, 9999)}"
            if record_id not in issued:
                issued.add(record_id)
                return record_id

    def _get_clinical_note_template(self) -> str:
        return """
CLINICAL NOTE
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}
Chief Complaint:
{chief_complaint}
Vital Signs:
- Blood Pressure: {blood_pressure} mmHg
- Heart Rate: {heart_rate} bpm
- Temperature: {temperature}°F
Assessment:
{diagnosis}
Plan:
1. Follow-up in 2 weeks
2. Continue current medications
3. Monitor symptoms
Provider: Dr. Smith
"""

    def _get_discharge_summary_template(self) -> str:
        return """
DISCHARGE SUMMARY
Patient: {patient_name}
Age: {age}
Gender: {gender}
Admission Date: {admission_date}
Discharge Date: {date}
Reason for Admission:
{chief_complaint}
Hospital Course:
Patient was admitted for {chief_complaint}. During hospitalization, patient was treated with appropriate medications and showed improvement.
Final Diagnosis:
{diagnosis}
Discharge Medications:
1. Medication A - 1 tablet daily
2. Medication B - 2 tablets twice daily
Follow-up:
- Primary Care Provider: Dr. Johnson
- Appointment: 2 weeks from discharge
Discharge Instructions:
1. Take medications as prescribed
2. Follow up with primary care provider
3. Call if symptoms worsen
Discharging Provider: Dr. Smith
"""

    def _get_lab_report_template(self) -> str:
        return """
LABORATORY REPORT
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}
Test Results:
Complete Blood Count (CBC):
- White Blood Cells: {wbc} K/uL
- Red Blood Cells: {rbc} M/uL
- Hemoglobin: {hemoglobin} g/dL
- Platelets: {platelets} K/uL
Basic Metabolic Panel:
- Glucose: {glucose} mg/dL
- BUN: {bun} mg/dL
- Creatinine: {creatinine} mg/dL
Interpretation:
Results are within normal limits.
Lab Director: Dr. Wilson
"""

    def _get_prescription_template(self) -> str:
        return """
PRESCRIPTION
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}
Prescription:
{diagnosis} - {medication}
Dosage: {dose} {frequency}
Quantity: {quantity} tablets
Refills: {refills}
Prescribing Provider: Dr. Smith
DEA Number: AB1234567
"""

    def _get_patient_intake_template(self) -> str:
        return """
PATIENT INTAKE FORM
Personal Information:
Name: {patient_name}
Age: {age}
Gender: {gender}
Address: {address}
Phone: {phone}
Email: {email}
Emergency Contact:
Name: {emergency_contact_name}
Phone: {emergency_contact_phone}
Relationship: {emergency_contact_relationship}
Insurance Information:
Provider: {insurance_provider}
Policy Number: {policy_number}
Group Number: {group_number}
Medical History:
Chief Complaint: {chief_complaint}
Current Medications: {current_medications}
Allergies: {allergies}
Vital Signs:
Blood Pressure: {blood_pressure} mmHg
Heart Rate: {heart_rate} bpm
Temperature: {temperature}°F
Intake Date: {date}
Intake Provider: Dr. Smith
"""
def main():
    """Manual smoke test: print one generated record per template type as indented JSON."""
    gen = MedicalTextGenerator()
    for kind in gen.templates:
        print(f"\nGenerating {kind}...")
        print(json.dumps(gen.generate_record(kind), indent=2))


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()