Spaces:
Sleeping
Sleeping
""" | |
Medical Data Collection Pipeline for Synthex MVP | |
Collects medical text from free sources for training data | |
""" | |
import requests | |
import pandas as pd | |
from datasets import load_dataset | |
import time | |
import json | |
from pathlib import Path | |
from typing import List, Dict, Any | |
import logging | |
import sys | |
from tqdm import tqdm | |
from bs4 import BeautifulSoup | |
import re | |
from datetime import datetime | |
# Setup logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s', | |
handlers=[ | |
logging.StreamHandler(sys.stdout), | |
logging.FileHandler('data_collection.log') | |
] | |
) | |
logger = logging.getLogger(__name__) | |
class MedicalDataCollector: | |
def __init__(self, output_dir: str = "data/raw"): | |
self.output_dir = Path(output_dir) | |
self.output_dir.mkdir(parents=True, exist_ok=True) | |
self.stats = { | |
"total_samples": 0, | |
"sources": {}, | |
"errors": [], | |
"start_time": datetime.now() | |
} | |
logger.info(f"Initialized MedicalDataCollector with output directory: {self.output_dir}") | |
def collect_huggingface_datasets(self) -> Dict[str, List]: | |
"""Collect medical datasets from Hugging Face Hub""" | |
# Only include datasets that are known to exist and are medical-related | |
datasets_to_collect = [ | |
"medical_questions_pairs", | |
"medalpaca/medical_meadow_medical_flashcards", | |
"gamino/wiki_medical_terms", | |
("pubmed_qa", "pqa_artificial") # pubmed_qa requires a config | |
] | |
collected_data = {} | |
for dataset_entry in tqdm(datasets_to_collect, desc="Collecting Hugging Face datasets"): | |
try: | |
if isinstance(dataset_entry, tuple): | |
dataset_name, config = dataset_entry | |
logger.info(f"Loading dataset: {dataset_name} with config: {config}") | |
dataset = load_dataset(dataset_name, config, split="train") | |
dataset_key = f"{dataset_name}_{config}" | |
else: | |
dataset_name = dataset_entry | |
logger.info(f"Loading dataset: {dataset_name}") | |
dataset = load_dataset(dataset_name, split="train") | |
dataset_key = dataset_name | |
# Convert to list of dictionaries | |
data_list = [] | |
for item in dataset: | |
processed_item = self._process_dataset_item(item) | |
if processed_item: | |
data_list.append(processed_item) | |
if data_list: | |
collected_data[dataset_key] = data_list | |
self.stats["sources"][dataset_key] = len(data_list) | |
self.stats["total_samples"] += len(data_list) | |
# Save to file | |
output_file = self.output_dir / f"{dataset_key.replace('/', '_')}.json" | |
with open(output_file, 'w', encoding='utf-8') as f: | |
json.dump(data_list, f, indent=2, ensure_ascii=False) | |
logger.info(f"Saved {len(data_list)} samples from {dataset_key} to {output_file}") | |
else: | |
logger.warning(f"No valid data found in dataset: {dataset_key}") | |
time.sleep(1) # Be respectful to APIs | |
except Exception as e: | |
error_msg = f"Failed to load {dataset_entry}: {str(e)}" | |
logger.error(error_msg, exc_info=True) | |
self.stats["errors"].append(error_msg) | |
continue | |
return collected_data | |
def collect_pubmed_abstracts(self, queries: List[str] = None, max_results: int = 1000) -> List[Dict]: | |
"""Collect PubMed abstracts via API""" | |
if queries is None: | |
queries = [ | |
"clinical notes", | |
"medical case reports", | |
"patient discharge summaries", | |
"medical laboratory reports", | |
"medical imaging reports" | |
] | |
all_abstracts = [] | |
for query in tqdm(queries, desc="Collecting PubMed abstracts"): | |
try: | |
abstracts = self._collect_pubmed_query(query, max_results) | |
all_abstracts.extend(abstracts) | |
self.stats["sources"]["pubmed_" + query.replace(" ", "_")] = len(abstracts) | |
self.stats["total_samples"] += len(abstracts) | |
except Exception as e: | |
error_msg = f"Failed to collect PubMed abstracts for {query}: {str(e)}" | |
logger.error(error_msg) | |
self.stats["errors"].append(error_msg) | |
continue | |
# Save all abstracts | |
if all_abstracts: | |
output_file = self.output_dir / "pubmed_abstracts.json" | |
with open(output_file, 'w', encoding='utf-8') as f: | |
json.dump(all_abstracts, f, indent=2, ensure_ascii=False) | |
return all_abstracts | |
def _collect_pubmed_query(self, query: str, max_results: int) -> List[Dict]: | |
"""Collect PubMed abstracts for a specific query""" | |
base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/" | |
search_url = f"{base_url}esearch.fcgi" | |
search_params = { | |
"db": "pubmed", | |
"term": query, | |
"retmax": max_results, | |
"retmode": "json", | |
"sort": "relevance" | |
} | |
try: | |
response = requests.get(search_url, params=search_params) | |
response.raise_for_status() # Raise exception for bad status codes | |
search_results = response.json() | |
# Check rate limits | |
rate_limit = int(response.headers.get('X-RateLimit-Limit', '3')) | |
rate_remaining = int(response.headers.get('X-RateLimit-Remaining', '0')) | |
logger.info(f"Rate limit: {rate_remaining}/{rate_limit} requests remaining") | |
if rate_remaining <= 1: | |
logger.warning("Rate limit nearly reached, waiting 60 seconds") | |
time.sleep(60) | |
except requests.exceptions.RequestException as e: | |
logger.error(f"Failed to fetch PubMed search results for query '{query}': {str(e)}") | |
return [] | |
except json.JSONDecodeError as e: | |
logger.error(f"Failed to parse PubMed search results for query '{query}': {str(e)}") | |
return [] | |
if "esearchresult" not in search_results: | |
logger.warning(f"No search results found for query '{query}'") | |
return [] | |
id_list = search_results["esearchresult"]["idlist"] | |
abstracts = [] | |
batch_size = 100 | |
for i in range(0, len(id_list), batch_size): | |
batch_ids = id_list[i:i+batch_size] | |
ids_str = ",".join(batch_ids) | |
fetch_url = f"{base_url}efetch.fcgi" | |
fetch_params = { | |
"db": "pubmed", | |
"id": ids_str, | |
"retmode": "xml" | |
} | |
try: | |
response = requests.get(fetch_url, params=fetch_params) | |
response.raise_for_status() | |
# Check rate limits | |
rate_limit = int(response.headers.get('X-RateLimit-Limit', '3')) | |
rate_remaining = int(response.headers.get('X-RateLimit-Remaining', '0')) | |
logger.info(f"Rate limit: {rate_remaining}/{rate_limit} requests remaining") | |
if rate_remaining <= 1: | |
logger.warning("Rate limit nearly reached, waiting 60 seconds") | |
time.sleep(60) | |
# Parse XML with proper features | |
soup = BeautifulSoup(response.text, 'lxml', features="xml") | |
except requests.exceptions.RequestException as e: | |
logger.error(f"Failed to fetch PubMed article batch {i//batch_size + 1}: {str(e)}") | |
continue | |
except Exception as e: | |
logger.error(f"Failed to parse PubMed article batch {i//batch_size + 1}: {str(e)}") | |
continue | |
for article in soup.find_all('PubmedArticle'): | |
try: | |
abstract = article.find('Abstract') | |
if abstract: | |
abstract_text = abstract.get_text().strip() | |
if len(abstract_text) > 100: # Filter out very short abstracts | |
title = article.find('ArticleTitle') | |
if not title: | |
continue | |
title_text = title.get_text().strip() | |
pub_date = article.find('PubDate') | |
year = "Unknown" | |
if pub_date and pub_date.find('Year'): | |
year = pub_date.find('Year').get_text().strip() | |
abstracts.append({ | |
"title": title_text, | |
"abstract": abstract_text, | |
"year": year, | |
"source": "pubmed", | |
"query": query | |
}) | |
except Exception as e: | |
logger.debug(f"Failed to process article in batch {i//batch_size + 1}: {str(e)}") | |
continue | |
# Always wait between batches to respect rate limits | |
time.sleep(1) | |
logger.info(f"Collected {len(abstracts)} abstracts for query '{query}'") | |
return abstracts | |
def create_training_dataset(self) -> pd.DataFrame: | |
"""Combine all collected data into training dataset""" | |
all_texts = [] | |
# Load all collected datasets | |
for json_file in tqdm(list(self.output_dir.glob("*.json")), desc="Processing collected data"): | |
try: | |
with open(json_file, 'r', encoding='utf-8') as f: | |
data = json.load(f) | |
# Extract text content | |
for item in data: | |
text_content = self._extract_text_content(item) | |
if text_content: | |
processed_text = self._clean_text(text_content) | |
if processed_text: | |
all_texts.append({ | |
"text": processed_text, | |
"source": json_file.stem, | |
"length": len(processed_text), | |
"type": self._determine_text_type(processed_text) | |
}) | |
except Exception as e: | |
error_msg = f"Failed to process {json_file}: {str(e)}" | |
logger.error(error_msg) | |
self.stats["errors"].append(error_msg) | |
continue | |
# Create DataFrame | |
df = pd.DataFrame(all_texts) | |
# Basic filtering | |
df = df[df['length'] > 100] # Remove very short texts | |
df = df[df['length'] < 5000] # Remove very long texts | |
# Remove duplicates | |
df = df.drop_duplicates(subset=['text']) | |
# Save processed dataset | |
output_file = self.output_dir.parent / "processed" / "training_data.csv" | |
output_file.parent.mkdir(exist_ok=True) | |
df.to_csv(output_file, index=False, encoding='utf-8') | |
# Update stats | |
self.stats["final_samples"] = len(df) | |
self.stats["text_types"] = df['type'].value_counts().to_dict() | |
logger.info(f"Created training dataset with {len(df)} samples") | |
return df | |
def _process_dataset_item(self, item: Dict) -> Dict: | |
"""Process and validate a dataset item""" | |
try: | |
# Extract text content | |
text = self._extract_text_content(item) | |
if not text or len(text) < 100: | |
return None | |
# Clean text | |
cleaned_text = self._clean_text(text) | |
if not cleaned_text: | |
return None | |
# Create processed item | |
processed = { | |
"text": cleaned_text, | |
"source": "huggingface", | |
"type": self._determine_text_type(cleaned_text) | |
} | |
# Add metadata if available | |
for key in ['title', 'question', 'answer', 'instruction']: | |
if key in item: | |
processed[key] = str(item[key]) | |
return processed | |
except Exception: | |
return None | |
def _extract_text_content(self, item: Dict) -> str: | |
"""Extract relevant text content from dataset item""" | |
# Common text fields in medical datasets | |
text_fields = ['text', 'content', 'abstract', 'question', 'answer', | |
'instruction', 'output', 'input', 'context'] | |
for field in text_fields: | |
if field in item and item[field]: | |
return str(item[field]) | |
# Fallback: combine multiple fields | |
combined_text = "" | |
for key, value in item.items(): | |
if isinstance(value, str) and len(value) > 20: | |
combined_text += f"{value} " | |
return combined_text.strip() | |
def _clean_text(self, text: str) -> str: | |
"""Clean and normalize text""" | |
if not text: | |
return "" | |
# Remove special characters and normalize whitespace | |
text = re.sub(r'[^\w\s.,;:!?()-]', ' ', text) | |
text = re.sub(r'\s+', ' ', text) | |
# Remove common noise | |
text = re.sub(r'http\S+', '', text) | |
text = re.sub(r'www\S+', '', text) | |
text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '', text) | |
return text.strip() | |
def _determine_text_type(self, text: str) -> str: | |
"""Determine the type of medical text""" | |
text = text.lower() | |
if any(term in text for term in ['discharge', 'summary', 'discharge summary']): | |
return 'discharge_summary' | |
elif any(term in text for term in ['lab', 'laboratory', 'test results']): | |
return 'lab_report' | |
elif any(term in text for term in ['prescription', 'medication', 'drug']): | |
return 'prescription' | |
elif any(term in text for term in ['question', 'answer', 'qa']): | |
return 'medical_qa' | |
else: | |
return 'clinical_note' | |
def generate_report(self) -> Dict: | |
"""Generate a report of the data collection process""" | |
# Convert all datetime objects to strings | |
for k, v in self.stats.items(): | |
if isinstance(v, datetime): | |
self.stats[k] = str(v) | |
self.stats["end_time"] = str(datetime.now()) | |
if isinstance(self.stats["start_time"], datetime): | |
self.stats["start_time"] = str(self.stats["start_time"]) | |
# Calculate duration as string | |
try: | |
start_dt = datetime.fromisoformat(self.stats["start_time"]) | |
end_dt = datetime.fromisoformat(self.stats["end_time"]) | |
self.stats["duration"] = str(end_dt - start_dt) | |
except Exception: | |
self.stats["duration"] = "unknown" | |
report_file = self.output_dir.parent / "reports" / "collection_report.json" | |
report_file.parent.mkdir(exist_ok=True) | |
with open(report_file, 'w', encoding='utf-8') as f: | |
json.dump(self.stats, f, indent=2, ensure_ascii=False) | |
return self.stats | |
def main(): | |
"""Run data collection pipeline""" | |
try: | |
collector = MedicalDataCollector() | |
# Collect from Hugging Face | |
logger.info("Starting Hugging Face dataset collection...") | |
hf_data = collector.collect_huggingface_datasets() | |
# Collect from PubMed | |
logger.info("Starting PubMed collection...") | |
pubmed_data = collector.collect_pubmed_abstracts() | |
# Create training dataset | |
logger.info("Creating training dataset...") | |
training_df = collector.create_training_dataset() | |
# Generate report | |
report = collector.generate_report() | |
# Print summary | |
logger.info("\nData Collection Summary:") | |
logger.info(f"Total samples collected: {report['total_samples']}") | |
logger.info(f"Final training samples: {report['final_samples']}") | |
logger.info(f"Duration: {report['duration']}") | |
logger.info("\nText types distribution:") | |
for type_, count in report['text_types'].items(): | |
logger.info(f"- {type_}: {count}") | |
if report['errors']: | |
logger.warning(f"\nEncountered {len(report['errors'])} errors during collection") | |
except Exception as e: | |
logger.error(f"Data collection failed: {str(e)}", exc_info=True) | |
sys.exit(1) | |
if __name__ == "__main__": | |
main() |