""" Medical Data Collection Pipeline for Synthex MVP Collects medical text from free sources for training data """ import requests import pandas as pd from datasets import load_dataset import time import json from pathlib import Path from typing import List, Dict, Any import logging import sys from tqdm import tqdm from bs4 import BeautifulSoup import re from datetime import datetime # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(sys.stdout), logging.FileHandler('data_collection.log') ] ) logger = logging.getLogger(__name__) class MedicalDataCollector: def __init__(self, output_dir: str = "data/raw"): self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.stats = { "total_samples": 0, "sources": {}, "errors": [], "start_time": datetime.now() } logger.info(f"Initialized MedicalDataCollector with output directory: {self.output_dir}") def collect_huggingface_datasets(self) -> Dict[str, List]: """Collect medical datasets from Hugging Face Hub""" # Only include datasets that are known to exist and are medical-related datasets_to_collect = [ "medical_questions_pairs", "medalpaca/medical_meadow_medical_flashcards", "gamino/wiki_medical_terms", ("pubmed_qa", "pqa_artificial") # pubmed_qa requires a config ] collected_data = {} for dataset_entry in tqdm(datasets_to_collect, desc="Collecting Hugging Face datasets"): try: if isinstance(dataset_entry, tuple): dataset_name, config = dataset_entry logger.info(f"Loading dataset: {dataset_name} with config: {config}") dataset = load_dataset(dataset_name, config, split="train") dataset_key = f"{dataset_name}_{config}" else: dataset_name = dataset_entry logger.info(f"Loading dataset: {dataset_name}") dataset = load_dataset(dataset_name, split="train") dataset_key = dataset_name # Convert to list of dictionaries data_list = [] for item in dataset: processed_item = self._process_dataset_item(item) if processed_item: data_list.append(processed_item) if data_list: collected_data[dataset_key] = data_list self.stats["sources"][dataset_key] = len(data_list) self.stats["total_samples"] += len(data_list) # Save to file output_file = self.output_dir / f"{dataset_key.replace('/', '_')}.json" with open(output_file, 'w', encoding='utf-8') as f: json.dump(data_list, f, indent=2, ensure_ascii=False) logger.info(f"Saved {len(data_list)} samples from {dataset_key} to {output_file}") else: logger.warning(f"No valid data found in dataset: {dataset_key}") time.sleep(1) # Be respectful to APIs except Exception as e: error_msg = f"Failed to load {dataset_entry}: {str(e)}" logger.error(error_msg, exc_info=True) self.stats["errors"].append(error_msg) continue return collected_data def collect_pubmed_abstracts(self, queries: List[str] = None, max_results: int = 1000) -> List[Dict]: """Collect PubMed abstracts via API""" if queries is None: queries = [ "clinical notes", "medical case reports", "patient discharge summaries", "medical laboratory reports", "medical imaging reports" ] all_abstracts = [] for query in tqdm(queries, desc="Collecting PubMed abstracts"): try: abstracts = self._collect_pubmed_query(query, max_results) all_abstracts.extend(abstracts) self.stats["sources"]["pubmed_" + query.replace(" ", "_")] = len(abstracts) self.stats["total_samples"] += len(abstracts) except Exception as e: error_msg = f"Failed to collect PubMed abstracts for {query}: {str(e)}" logger.error(error_msg) self.stats["errors"].append(error_msg) continue # Save all abstracts if all_abstracts: output_file = self.output_dir / "pubmed_abstracts.json" with open(output_file, 'w', encoding='utf-8') as f: json.dump(all_abstracts, f, indent=2, ensure_ascii=False) return all_abstracts def _collect_pubmed_query(self, query: str, max_results: int) -> List[Dict]: """Collect PubMed abstracts for a specific query""" base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/" search_url = f"{base_url}esearch.fcgi" search_params = { "db": "pubmed", "term": query, "retmax": max_results, "retmode": "json", "sort": "relevance" } try: response = requests.get(search_url, params=search_params) response.raise_for_status() # Raise exception for bad status codes search_results = response.json() # Check rate limits rate_limit = int(response.headers.get('X-RateLimit-Limit', '3')) rate_remaining = int(response.headers.get('X-RateLimit-Remaining', '0')) logger.info(f"Rate limit: {rate_remaining}/{rate_limit} requests remaining") if rate_remaining <= 1: logger.warning("Rate limit nearly reached, waiting 60 seconds") time.sleep(60) except requests.exceptions.RequestException as e: logger.error(f"Failed to fetch PubMed search results for query '{query}': {str(e)}") return [] except json.JSONDecodeError as e: logger.error(f"Failed to parse PubMed search results for query '{query}': {str(e)}") return [] if "esearchresult" not in search_results: logger.warning(f"No search results found for query '{query}'") return [] id_list = search_results["esearchresult"]["idlist"] abstracts = [] batch_size = 100 for i in range(0, len(id_list), batch_size): batch_ids = id_list[i:i+batch_size] ids_str = ",".join(batch_ids) fetch_url = f"{base_url}efetch.fcgi" fetch_params = { "db": "pubmed", "id": ids_str, "retmode": "xml" } try: response = requests.get(fetch_url, params=fetch_params) response.raise_for_status() # Check rate limits rate_limit = int(response.headers.get('X-RateLimit-Limit', '3')) rate_remaining = int(response.headers.get('X-RateLimit-Remaining', '0')) logger.info(f"Rate limit: {rate_remaining}/{rate_limit} requests remaining") if rate_remaining <= 1: logger.warning("Rate limit nearly reached, waiting 60 seconds") time.sleep(60) # Parse XML with proper features soup = BeautifulSoup(response.text, 'lxml', features="xml") except requests.exceptions.RequestException as e: logger.error(f"Failed to fetch PubMed article batch {i//batch_size + 1}: {str(e)}") continue except Exception as e: logger.error(f"Failed to parse PubMed article batch {i//batch_size + 1}: {str(e)}") continue for article in soup.find_all('PubmedArticle'): try: abstract = article.find('Abstract') if abstract: abstract_text = abstract.get_text().strip() if len(abstract_text) > 100: # Filter out very short abstracts title = article.find('ArticleTitle') if not title: continue title_text = title.get_text().strip() pub_date = article.find('PubDate') year = "Unknown" if pub_date and pub_date.find('Year'): year = pub_date.find('Year').get_text().strip() abstracts.append({ "title": title_text, "abstract": abstract_text, "year": year, "source": "pubmed", "query": query }) except Exception as e: logger.debug(f"Failed to process article in batch {i//batch_size + 1}: {str(e)}") continue # Always wait between batches to respect rate limits time.sleep(1) logger.info(f"Collected {len(abstracts)} abstracts for query '{query}'") return abstracts def create_training_dataset(self) -> pd.DataFrame: """Combine all collected data into training dataset""" all_texts = [] # Load all collected datasets for json_file in tqdm(list(self.output_dir.glob("*.json")), desc="Processing collected data"): try: with open(json_file, 'r', encoding='utf-8') as f: data = json.load(f) # Extract text content for item in data: text_content = self._extract_text_content(item) if text_content: processed_text = self._clean_text(text_content) if processed_text: all_texts.append({ "text": processed_text, "source": json_file.stem, "length": len(processed_text), "type": self._determine_text_type(processed_text) }) except Exception as e: error_msg = f"Failed to process {json_file}: {str(e)}" logger.error(error_msg) self.stats["errors"].append(error_msg) continue # Create DataFrame df = pd.DataFrame(all_texts) # Basic filtering df = df[df['length'] > 100] # Remove very short texts df = df[df['length'] < 5000] # Remove very long texts # Remove duplicates df = df.drop_duplicates(subset=['text']) # Save processed dataset output_file = self.output_dir.parent / "processed" / "training_data.csv" output_file.parent.mkdir(exist_ok=True) df.to_csv(output_file, index=False, encoding='utf-8') # Update stats self.stats["final_samples"] = len(df) self.stats["text_types"] = df['type'].value_counts().to_dict() logger.info(f"Created training dataset with {len(df)} samples") return df def _process_dataset_item(self, item: Dict) -> Dict: """Process and validate a dataset item""" try: # Extract text content text = self._extract_text_content(item) if not text or len(text) < 100: return None # Clean text cleaned_text = self._clean_text(text) if not cleaned_text: return None # Create processed item processed = { "text": cleaned_text, "source": "huggingface", "type": self._determine_text_type(cleaned_text) } # Add metadata if available for key in ['title', 'question', 'answer', 'instruction']: if key in item: processed[key] = str(item[key]) return processed except Exception: return None def _extract_text_content(self, item: Dict) -> str: """Extract relevant text content from dataset item""" # Common text fields in medical datasets text_fields = ['text', 'content', 'abstract', 'question', 'answer', 'instruction', 'output', 'input', 'context'] for field in text_fields: if field in item and item[field]: return str(item[field]) # Fallback: combine multiple fields combined_text = "" for key, value in item.items(): if isinstance(value, str) and len(value) > 20: combined_text += f"{value} " return combined_text.strip() def _clean_text(self, text: str) -> str: """Clean and normalize text""" if not text: return "" # Remove special characters and normalize whitespace text = re.sub(r'[^\w\s.,;:!?()-]', ' ', text) text = re.sub(r'\s+', ' ', text) # Remove common noise text = re.sub(r'http\S+', '', text) text = re.sub(r'www\S+', '', text) text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '', text) return text.strip() def _determine_text_type(self, text: str) -> str: """Determine the type of medical text""" text = text.lower() if any(term in text for term in ['discharge', 'summary', 'discharge summary']): return 'discharge_summary' elif any(term in text for term in ['lab', 'laboratory', 'test results']): return 'lab_report' elif any(term in text for term in ['prescription', 'medication', 'drug']): return 'prescription' elif any(term in text for term in ['question', 'answer', 'qa']): return 'medical_qa' else: return 'clinical_note' def generate_report(self) -> Dict: """Generate a report of the data collection process""" # Convert all datetime objects to strings for k, v in self.stats.items(): if isinstance(v, datetime): self.stats[k] = str(v) self.stats["end_time"] = str(datetime.now()) if isinstance(self.stats["start_time"], datetime): self.stats["start_time"] = str(self.stats["start_time"]) # Calculate duration as string try: start_dt = datetime.fromisoformat(self.stats["start_time"]) end_dt = datetime.fromisoformat(self.stats["end_time"]) self.stats["duration"] = str(end_dt - start_dt) except Exception: self.stats["duration"] = "unknown" report_file = self.output_dir.parent / "reports" / "collection_report.json" report_file.parent.mkdir(exist_ok=True) with open(report_file, 'w', encoding='utf-8') as f: json.dump(self.stats, f, indent=2, ensure_ascii=False) return self.stats def main(): """Run data collection pipeline""" try: collector = MedicalDataCollector() # Collect from Hugging Face logger.info("Starting Hugging Face dataset collection...") hf_data = collector.collect_huggingface_datasets() # Collect from PubMed logger.info("Starting PubMed collection...") pubmed_data = collector.collect_pubmed_abstracts() # Create training dataset logger.info("Creating training dataset...") training_df = collector.create_training_dataset() # Generate report report = collector.generate_report() # Print summary logger.info("\nData Collection Summary:") logger.info(f"Total samples collected: {report['total_samples']}") logger.info(f"Final training samples: {report['final_samples']}") logger.info(f"Duration: {report['duration']}") logger.info("\nText types distribution:") for type_, count in report['text_types'].items(): logger.info(f"- {type_}: {count}") if report['errors']: logger.warning(f"\nEncountered {len(report['errors'])} errors during collection") except Exception as e: logger.error(f"Data collection failed: {str(e)}", exc_info=True) sys.exit(1) if __name__ == "__main__": main()