# NOTE: this copy was captured from a Hugging Face Spaces page (status banner: "Spaces: Sleeping").
"""
Synthex Medical Text Generator - MVP Streamlit App

Deploy this on Hugging Face Spaces for free hosting.
"""
import json
import logging
import os
import sys
import time
from datetime import datetime

import pandas as pd
import streamlit as st

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add the bundled ``src`` directory to the Python path. Streamlit re-executes
# this whole script on every user interaction, so only append the entry once;
# an unconditional append would grow sys.path with duplicates on each rerun.
# abspath() keeps the entry valid regardless of the current working directory.
_SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

# Import the medical generator (must come after the sys.path setup above)
from src.generation.medical_generator import MedicalTextGenerator, DEFAULT_GEMINI_API_KEY
# Page config. Keep this ahead of any other st.* call; older Streamlit
# versions reject set_page_config once another command has run.
st.set_page_config(
    page_title="Synthex Medical Text Generator",
    page_icon="🏥",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS injected once per run. The classes only style HTML that is
# rendered inside a single st.markdown call (e.g. the page header).
_CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 3rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .sub-header {
        font-size: 1.5rem;
        color: #666;
        text-align: center;
        margin-bottom: 3rem;
    }
    .record-container {
        background-color: #f8f9fa;
        padding: 1rem;
        border-radius: 0.5rem;
        border-left: 4px solid #1f77b4;
        margin: 1rem 0;
    }
    .stats-container {
        background-color: #e8f4fd;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Initialize session state. Streamlit re-runs this script on every
# interaction, so only seed keys that are not present yet; the tuple (and its
# empty list) is rebuilt on each run, so no list object is shared.
for _state_key, _initial_value in (
    ('generated_records', []),
    ('total_generated', 0),
    ('generator', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Header: each div is opened and closed within one st.markdown call, so the
# .main-header / .sub-header CSS classes actually apply to the text.
st.markdown(
    '<div class="main-header">🏥 Synthex Medical Text Generator</div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="sub-header">Generate synthetic medical records for AI training and testing</div>',
    unsafe_allow_html=True,
)
# Sidebar: generation settings. The values bound here (gemini_api_key,
# record_type, quantity, use_gemini, include_metadata, export_format) are
# read by the main panel further down the script.
with st.sidebar:
    st.header("⚙️ Configuration")

    # API Key input (pre-filled from the GEMINI_API_KEY environment variable if set)
    gemini_api_key = st.text_input(
        "Gemini API Key",
        value=os.getenv('GEMINI_API_KEY', ''),
        type="password",
        help="Enter your Google Gemini API key for better generation quality",
    )

    # Record type selection; options are shown as e.g. "Clinical Note"
    record_type = st.selectbox(
        "Select Record Type",
        ["clinical_note", "discharge_summary", "lab_report", "prescription", "patient_intake"],
        format_func=lambda x: x.replace("_", " ").title(),
    )

    # Quantity: 1..20 records per click, default 5
    quantity = st.slider("Number of Records", 1, 20, 5)

    # Generation method
    use_gemini = st.checkbox(
        "Use Gemini API",
        value=bool(gemini_api_key),  # Only default to True if an API key is available
        help="Uses Google Gemini API for better quality generation",
    )

    # Advanced options
    with st.expander("Advanced Options"):
        include_metadata = st.checkbox("Include Metadata", value=True)
        export_format = st.selectbox("Export Format", ["JSON", "CSV", "TXT"])
def _get_generator(api_key):
    """Return the session's MedicalTextGenerator, (re)building it when needed.

    The generator is cached in ``st.session_state`` so it survives reruns, but
    it is rebuilt whenever the sidebar API key differs from the key it was
    built with. Otherwise a key entered (or corrected) after the first
    generation would be silently ignored.

    Args:
        api_key: Gemini API key from the sidebar (may be an empty string).

    Returns:
        The cached ``MedicalTextGenerator`` instance.

    Raises:
        Whatever ``MedicalTextGenerator`` raises while being constructed; the
        previously cached generator (if any) is left in place in that case.
    """
    if (
        st.session_state.get('generator') is None
        or st.session_state.get('generator_api_key') != api_key
    ):
        with st.spinner("Initializing medical text generator..."):
            st.session_state.generator = MedicalTextGenerator(gemini_api_key=api_key)
        st.session_state.generator_api_key = api_key
    return st.session_state.generator


def _generate_records(generator, record_type, quantity, use_gemini):
    """Generate ``quantity`` records, showing progress and per-record errors.

    A failing record is logged and reported in the UI but does not abort the
    batch.

    Returns:
        The list of successfully generated records (possibly empty).
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    records = []

    for i in range(quantity):
        status_text.text(f"Generating record {i+1} of {quantity}...")
        try:
            record = generator.generate_record(record_type, use_gemini=use_gemini)
        except Exception as e:
            logger.exception("Failed to generate record %d", i + 1)
            st.error(f"Failed to generate record {i+1}: {str(e)}")
        else:
            records.append(record)

        # Advance the bar only once this attempt has finished.
        progress_bar.progress((i + 1) / quantity)

        # Rate limiting between Gemini calls (after failures too, so an error
        # response is not immediately retried); skip the wait after the last one.
        if use_gemini and i + 1 < quantity:
            time.sleep(1)

    if records:
        status_text.text("✅ Generation complete!")
    else:
        status_text.text("No records were generated.")
    progress_bar.progress(1.0)
    return records


# Main content: wide left column for generation/records, narrow right one for stats.
col1, col2 = st.columns([2, 1])

with col1:
    st.header("📝 Generate Medical Records")

    if st.button("🚀 Generate Records", type="primary", use_container_width=True):
        try:
            generator = _get_generator(gemini_api_key)
        except Exception as e:
            st.error(f"Error initializing generator: {str(e)}")
            st.stop()  # halts this script run

        generated_records = _generate_records(generator, record_type, quantity, use_gemini)

        # Update session state
        if generated_records:
            st.session_state.generated_records.extend(generated_records)
            st.session_state.total_generated += len(generated_records)
            st.success(f"Successfully generated {len(generated_records)} medical records!")
        else:
            st.warning("No records were generated. See the errors above.")
# Display generated records (in the same left column as the generate controls).
with col1:
    if st.session_state.generated_records:
        st.header("📋 Generated Records")
        all_records = st.session_state.generated_records

        # Filters. Options are sorted so their order is stable and readable
        # (iteration order of a set is arbitrary).
        col_filter1, col_filter2 = st.columns(2)
        with col_filter1:
            filter_type = st.selectbox(
                "Filter by Type",
                ["All"] + sorted({r['type'] for r in all_records}),
            )
        with col_filter2:
            records_per_page = st.selectbox("Records per page", [5, 10, 20, 50])

        # Filter records
        if filter_type == "All":
            filtered_records = all_records
        else:
            filtered_records = [r for r in all_records if r['type'] == filter_type]

        # Pagination (ceiling division). filtered_records is non-empty here,
        # because the filter options are taken from the records themselves.
        total_pages = (len(filtered_records) - 1) // records_per_page + 1
        start_idx = 0
        if total_pages > 1:
            page = st.selectbox("Page", range(1, total_pages + 1))
            start_idx = (page - 1) * records_per_page
        page_records = filtered_records[start_idx:start_idx + records_per_page]

        # Display records
        for offset, record in enumerate(page_records):
            type_label = record['type'].replace('_', ' ').title()
            with st.expander(f"Record {record['id']} - {type_label}"):
                if include_metadata:
                    col_meta1, col_meta2, col_meta3 = st.columns(3)
                    with col_meta1:
                        st.metric("Type", type_label)
                    with col_meta2:
                        st.metric("Generated", record['timestamp'])
                    with col_meta3:
                        st.metric("Source", record['source'])

                # Not wrapped in st.markdown('<div ...>') / ('</div>') calls: each
                # st.markdown renders its own element, so such tags never enclose
                # the widget. The key uses the record's position in the filtered
                # list (plus the filter), not its slot on the page, so paging or
                # filtering cannot reuse another record's widget state.
                st.text_area(
                    "Content",
                    record['text'],
                    height=200,
                    key=f"record_{filter_type}_{start_idx + offset}",
                )
def _build_export(records, export_format):
    """Serialize ``records`` for download in the chosen export format.

    Args:
        records: Generated record dicts; the TXT format reads their 'id',
            'type' and 'text' keys.
        export_format: One of "JSON", "CSV" or "TXT".

    Returns:
        A ``(label, data, extension, mime)`` tuple for ``st.download_button``,
        or ``None`` if the format is not recognised.
    """
    if export_format == "JSON":
        return "Download JSON", json.dumps(records, indent=2), "json", "application/json"
    if export_format == "CSV":
        return "Download CSV", pd.DataFrame(records).to_csv(index=False), "csv", "text/csv"
    if export_format == "TXT":
        txt = "\n\n".join(
            f"Record {r['id']} ({r['type']}):\n{r['text']}" for r in records
        )
        return "Download TXT", txt, "txt", "text/plain"
    return None


with col2:
    st.header("📊 Statistics")
    # No st.markdown('<div class="stats-container">') / ('</div>') wrapper here:
    # HTML opened in one st.markdown call cannot enclose widgets rendered by
    # later calls, so it would only produce an empty styled box.

    session_records = st.session_state.generated_records

    # Total records
    st.metric("Total Records Generated", st.session_state.total_generated)

    # Record type distribution
    if session_records:
        type_counts = pd.Series([r['type'] for r in session_records]).value_counts()
        st.subheader("Record Type Distribution")
        st.bar_chart(type_counts)

    # Export options
    st.subheader("Export Data")
    export = _build_export(session_records, export_format) if session_records else None
    if export is not None:
        label, data, extension, mime = export
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        st.download_button(
            label,
            data,
            file_name=f"medical_records_{timestamp}.{extension}",
            mime=mime,
        )