"""
Basic Medical Text Generator for Synthex MVP.

Generates fictional medical records using the Gemini API or a Hugging Face
text-generation model, falling back to filling built-in templates with
random fictional values.
"""
import json
import logging
import os
import random
import sys
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import google.generativeai as genai
from transformers import pipeline
# Send log records to stdout formatted as "time - logger - level - message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Gemini API key taken from the GEMINI_API_KEY environment variable;
# an empty string means no key is configured.
DEFAULT_GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
class MedicalTextGenerator:
    """Generate synthetic, completely fictional medical records.

    Backends are tried in order of preference: Gemini (only when requested
    and successfully configured), then the Hugging Face text-generation
    pipeline, then plain template filling with random values. The returned
    record's ``source`` field names the backend that actually produced it.
    """

    # Pools of fictional values used to fill template placeholders.
    _PATIENT_NAMES = (
        "John Smith", "Jane Doe", "Robert Johnson", "Mary Wilson", "Emily Clark",
        "Michael Brown", "Linda Lee", "David Kim", "Sarah Patel", "James Chen",
    )
    _GENDERS = ("Male", "Female", "Other")
    _CHIEF_COMPLAINTS = (
        "chest pain", "shortness of breath", "abdominal pain", "headache",
        "fever", "fatigue", "dizziness", "back pain", "cough", "palpitations",
    )
    _DIAGNOSES = (
        "Hypertension", "Type 2 Diabetes", "Pneumonia", "Migraine",
        "Gastroenteritis", "Anxiety", "Asthma", "COVID-19", "Anemia", "Hyperlipidemia",
    )
    _ADDRESSES = ("123 Main St", "456 Oak Ave", "789 Pine Rd", "101 Maple Dr", "202 Elm St")
    _MEDICATIONS = ("Amoxicillin", "Lisinopril", "Metformin", "Atorvastatin", "Albuterol")
    _DOSE_AMOUNTS = ("1 tablet", "2 tablets", "1 capsule")
    _DOSE_FREQUENCIES = ("daily", "twice daily", "three times daily")
    _RELATIONSHIPS = ("Spouse", "Parent", "Sibling")
    _INSURERS = ("Blue Cross", "Aetna", "United Healthcare", "Cigna")
    _CURRENT_MEDICATIONS = ("None", "Aspirin", "Metformin", "Lisinopril")
    _ALLERGIES = ("None", "Penicillin", "Sulfa", "Peanuts")

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        *,
        gemini_model_name: str = "gemini-pro",
        hf_model_name: str = "distilgpt2",
    ):
        """Initialize the generator and load whichever backends are available.

        Args:
            gemini_api_key: Gemini API key; when omitted or empty,
                DEFAULT_GEMINI_API_KEY is used instead.
            gemini_model_name: Model id passed to ``genai.GenerativeModel``.
            hf_model_name: Model id for the Hugging Face text-generation pipeline.
        """
        self.gemini_api_key = gemini_api_key or DEFAULT_GEMINI_API_KEY
        if not self.gemini_api_key:
            logger.warning("No Gemini API key provided. Using Hugging Face model only.")
        self.gemini_model_name = gemini_model_name
        self.hf_model_name = hf_model_name
        # Never assigned by this class; retained so existing readers of the
        # attribute keep working. The loaded pipeline lives in hf_generator.
        self.hf_model = None
        self.hf_generator = None
        self.gemini_model = None
        # Bookkeeping for _generate_id's per-second uniqueness check.
        self._id_second = None
        self._ids_this_second = set()

        self._setup_models()

        # Raw templates keyed by record type; placeholders use {name} syntax.
        self.templates = {
            "clinical_note": self._get_clinical_note_template(),
            "discharge_summary": self._get_discharge_summary_template(),
            "lab_report": self._get_lab_report_template(),
            "prescription": self._get_prescription_template(),
            "patient_intake": self._get_patient_intake_template(),
        }

    def _setup_models(self):
        """Load the Hugging Face pipeline and, if a key is set, the Gemini model.

        Each backend is loaded independently. A failure is logged and leaves
        the corresponding attribute as None, so generation falls back further.
        """
        try:
            logger.info("Loading Hugging Face model %s...", self.hf_model_name)
            self.hf_generator = pipeline(
                "text-generation",
                model=self.hf_model_name,
                max_length=512,
                do_sample=True,
                temperature=0.7,
                device=-1,  # force CPU usage to avoid CUDA issues
                truncation=True,
            )
            logger.info("Hugging Face model loaded successfully")
        except Exception as e:
            logger.error("Failed to load Hugging Face model: %s", e)
            self.hf_generator = None
            logger.info("Falling back to template-based generation")

        if not self.gemini_api_key:
            return
        try:
            genai.configure(api_key=self.gemini_api_key)
            # Listing models is an extra network call whose only purpose is
            # diagnostics, so skip it unless the output would be logged.
            if logger.isEnabledFor(logging.DEBUG):
                for m in genai.list_models():
                    logger.debug("Available model: %s", m.name)
            self.gemini_model = genai.GenerativeModel(self.gemini_model_name)
            logger.info("Gemini model loaded successfully")
        except Exception as e:
            logger.error("Failed to load Gemini model: %s", e)
            self.gemini_model = None
            logger.info("Gemini API will not be available")

    def generate_record(self, record_type: str, use_gemini: bool = False) -> Dict:
        """Generate one synthetic medical record.

        Args:
            record_type: A key of ``self.templates``.
            use_gemini: Try Gemini first (only if it was configured).

        Returns:
            Dict with ``id``, ``type``, ``text``, ``timestamp`` and ``source``.
            ``source`` is "Gemini", "Hugging Face" or "Template": whichever
            backend actually produced ``text``.

        Raises:
            ValueError: if ``record_type`` is unknown.
            RuntimeError: if every generation method fails.
        """
        if record_type not in self.templates:
            raise ValueError(f"Unknown record type: {record_type}")
        content, source = self._generate_content(self.templates[record_type], use_gemini)
        return {
            "id": self._generate_id(),
            "type": record_type,
            "text": content,
            "timestamp": datetime.now().isoformat(),
            "source": source,
        }

    def _generate_content(self, template: str, use_gemini: bool) -> Tuple[str, str]:
        """Run the backend fallback chain and return ``(text, source)``.

        Raises:
            RuntimeError: if even template filling fails.
        """
        backends = []
        if use_gemini and self.gemini_model is not None:
            backends.append(("Gemini", self._generate_with_gemini))
        if self.hf_generator is not None:
            backends.append(("Hugging Face", self._generate_with_huggingface))

        for source, generate in backends:
            try:
                content = generate(template)
            except Exception as e:  # any backend failure -> try the next one
                logger.error("%s generation failed: %s", source, e)
                continue
            logger.info("Successfully generated record using %s", source)
            return content, source

        try:
            content = self._generate_with_template(template)
        except Exception as e:
            logger.error("Template generation failed: %s", e)
            raise RuntimeError("All generation methods failed") from e
        logger.info("Successfully generated record using template")
        return content, "Template"

    def _generate_with_gemini(self, template: str) -> str:
        """Ask Gemini to write a fictional record that follows ``template``.

        API errors propagate; the caller logs them and falls back.
        """
        prompt = f"""
Generate a realistic but completely fictional medical record using this template:
{template}
Requirements:
- Use fictional patient names and details
- Include medically accurate terminology
- Make it realistic but not based on any real patient
- Include specific medical details and measurements
- Follow standard medical documentation format
"""
        response = self.gemini_model.generate_content(prompt)
        return response.text

    def _generate_with_huggingface(self, template: str) -> str:
        """Fill ``template`` with fake data and let the HF model continue it.

        Only the first 100 characters of the filled template are used as the
        prompt, so the result is that prefix plus model output. Errors
        propagate so the caller can fall back and label ``source`` correctly.
        """
        filled = self._fill_template(template, self._make_fake_data())
        prompt = filled[:100] + "..."
        generated = self.hf_generator(
            prompt,
            max_length=400,
            num_return_sequences=1,
            # GPT-2-style tokenizers have no pad token; pad with EOS instead.
            pad_token_id=self.hf_generator.tokenizer.eos_token_id,
            truncation=True,
        )
        return generated[0]["generated_text"]

    def _generate_with_template(self, template: str) -> str:
        """Fallback: fill every placeholder of ``template`` with random fictional values."""
        return self._fill_template(template, self._make_fake_data())

    @staticmethod
    def _fill_template(template: str, data: Dict[str, object]) -> str:
        """Replace each ``{key}`` in ``template`` with ``str(data[key])``.

        Placeholders without a matching key are left untouched.
        """
        for key, value in data.items():
            template = template.replace(f"{{{key}}}", str(value))
        return template

    @staticmethod
    def _random_phone() -> str:
        """Return a random phone number formatted like ``(555)-123-4567``."""
        return f"({random.randint(200, 999)})-{random.randint(100, 999)}-{random.randint(1000, 9999)}"

    def _make_fake_data(self) -> Dict[str, object]:
        """Return one random fictional value for every placeholder used by the templates."""
        patient_name = random.choice(self._PATIENT_NAMES)
        first_name, _, last_name = patient_name.partition(" ")
        # Emergency contact: a random first name sharing the patient's surname.
        contact_first_name = random.choice(self._PATIENT_NAMES).split()[0]
        return {
            # Demographics / contact details
            "patient_name": patient_name,
            "age": random.randint(18, 90),
            "gender": random.choice(self._GENDERS),
            "date": time.strftime("%Y-%m-%d"),
            "address": random.choice(self._ADDRESSES),
            "phone": self._random_phone(),
            # example.com is reserved for documentation, so no real inbox is used.
            "email": f"{first_name}.{last_name}@example.com".lower(),
            # Presentation and vitals
            "chief_complaint": random.choice(self._CHIEF_COMPLAINTS),
            "blood_pressure": f"{random.randint(110, 160)}/{random.randint(60, 100)}",
            "heart_rate": random.randint(55, 120),
            "temperature": round(random.uniform(97.0, 104.0), 1),
            "diagnosis": random.choice(self._DIAGNOSES),
            # Lab values
            "wbc": random.randint(4, 11),
            "rbc": round(random.uniform(4.0, 5.5), 2),
            "hemoglobin": round(random.uniform(12.0, 16.0), 1),
            "platelets": random.randint(150, 450),
            "glucose": random.randint(70, 140),
            "bun": random.randint(7, 20),
            "creatinine": round(random.uniform(0.6, 1.2), 2),
            # Prescription
            "medication": random.choice(self._MEDICATIONS),
            "dose_amount": random.choice(self._DOSE_AMOUNTS),
            "dose_frequency": random.choice(self._DOSE_FREQUENCIES),
            "quantity": random.randint(30, 90),
            "refills": random.randint(0, 3),
            # Intake form
            "emergency_contact_name": f"{contact_first_name} {last_name}",
            "emergency_contact_relationship": random.choice(self._RELATIONSHIPS),
            "emergency_contact_phone": self._random_phone(),
            "insurance_provider": random.choice(self._INSURERS),
            "policy_number": random.randint(100000000, 999999999),
            "group_number": random.randint(10000, 99999),
            "current_medications": random.choice(self._CURRENT_MEDICATIONS),
            "allergies": random.choice(self._ALLERGIES),
        }

    def batch_generate(self, record_type: str, count: int = 10, use_gemini: bool = False) -> List[Dict]:
        """Generate up to ``count`` records; individual failures are logged and skipped."""
        records = []
        for i in range(1, count + 1):
            try:
                record = self.generate_record(record_type, use_gemini)
            except Exception as e:
                logger.error("Failed to generate record %d: %s", i, e)
                continue
            records.append(record)
            # Respect API rate limits, but only when Gemini was actually called.
            if record["source"] == "Gemini":
                time.sleep(1)
            logger.info("Generated record %d/%d", i, count)
        return records

    def _generate_id(self) -> str:
        """Return a record ID ``SYN-<unix seconds>-<4 random digits>``.

        IDs issued within the same second are tracked so none is repeated.
        Only the current second's IDs are kept, which bounds memory.
        """
        while True:
            second = int(time.time())
            if second != self._id_second:
                self._id_second = second
                self._ids_this_second = set()
            candidate = f"SYN-{second}-{random.randint(1000, 9999)}"
            if candidate not in self._ids_this_second:
                self._ids_this_second.add(candidate)
                return candidate

    def _get_clinical_note_template(self) -> str:
        """Return the clinical-note template (``{placeholder}`` fields)."""
        return """
CLINICAL NOTE
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}

Chief Complaint:
{chief_complaint}

Vital Signs:
- Blood Pressure: {blood_pressure} mmHg
- Heart Rate: {heart_rate} bpm
- Temperature: {temperature}°F

Assessment:
{diagnosis}

Plan:
1. Follow-up in 2 weeks
2. Continue current medications
3. Monitor symptoms

Provider: Dr. Smith
"""

    def _get_discharge_summary_template(self) -> str:
        """Return the discharge-summary template (``{placeholder}`` fields)."""
        return """
DISCHARGE SUMMARY
Patient: {patient_name}
Age: {age}
Gender: {gender}
Admission Date: {date}
Discharge Date: {date}

Reason for Admission:
{chief_complaint}

Hospital Course:
Patient was admitted for {chief_complaint}. During hospitalization, patient was treated with appropriate medications and showed improvement.

Final Diagnosis:
{diagnosis}

Discharge Medications:
1. Medication A - 1 tablet daily
2. Medication B - 2 tablets twice daily

Follow-up:
- Primary Care Provider: Dr. Johnson
- Appointment: 2 weeks from discharge

Discharge Instructions:
1. Take medications as prescribed
2. Follow up with primary care provider
3. Call if symptoms worsen

Discharging Provider: Dr. Smith
"""

    def _get_lab_report_template(self) -> str:
        """Return the lab-report template; lab values are named placeholders."""
        return """
LABORATORY REPORT
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}

Test Results:
Complete Blood Count (CBC):
- White Blood Cells: {wbc} K/uL
- Red Blood Cells: {rbc} M/uL
- Hemoglobin: {hemoglobin} g/dL
- Platelets: {platelets} K/uL

Basic Metabolic Panel:
- Glucose: {glucose} mg/dL
- BUN: {bun} mg/dL
- Creatinine: {creatinine} mg/dL

Interpretation:
Results are within normal limits.

Lab Director: Dr. Wilson
"""

    def _get_prescription_template(self) -> str:
        """Return the prescription template (``{placeholder}`` fields)."""
        return """
PRESCRIPTION
Patient: {patient_name}
Age: {age}
Gender: {gender}
Date: {date}

Prescription:
{diagnosis} - {medication}
Dosage: {dose_amount} {dose_frequency}
Quantity: {quantity} tablets
Refills: {refills}

Prescribing Provider: Dr. Smith
DEA Number: AB1234567
"""

    def _get_patient_intake_template(self) -> str:
        """Return the patient-intake-form template (``{placeholder}`` fields)."""
        return """
PATIENT INTAKE FORM
Personal Information:
Name: {patient_name}
Age: {age}
Gender: {gender}
Address: {address}
Phone: {phone}
Email: {email}

Emergency Contact:
Name: {emergency_contact_name}
Phone: {emergency_contact_phone}
Relationship: {emergency_contact_relationship}

Insurance Information:
Provider: {insurance_provider}
Policy Number: {policy_number}
Group Number: {group_number}

Medical History:
Chief Complaint: {chief_complaint}
Current Medications: {current_medications}
Allergies: {allergies}

Vital Signs:
Blood Pressure: {blood_pressure} mmHg
Heart Rate: {heart_rate} bpm
Temperature: {temperature}°F

Intake Date: {date}
Intake Provider: Dr. Smith
"""
def main():
    """Smoke test: print one synthetic record of each template type as indented JSON."""
    gen = MedicalTextGenerator()
    for kind in gen.templates:
        print(f"\nGenerating {kind}...")
        print(json.dumps(gen.generate_record(kind), indent=2))


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    main()