"""Gradio app that turns EEG band-power features into a plain-language summary.

An uploaded FIF or CSV recording is loaded into an MNE ``Raw`` object, a Welch
power spectral density is computed, alpha and beta band power are extracted,
and a causal language model is prompted to explain the numbers.
"""

import os

import gradio as gr
import mne
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b-instruct"

# The tokenizer and model are loaded at import time, so importing this module
# downloads/loads the checkpoint before the UI is built.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# SECURITY NOTE: trust_remote_code=True executes Python code shipped with the
# model repository. Only keep it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)


def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside a frequency band.

    The mean is taken jointly over every channel (first axis of ``psd``) and
    every frequency bin whose frequency lies in ``[fmin, fmax]`` (inclusive).

    Args:
        psd: 2-D array indexed as ``psd[channel, frequency_bin]``.
        freqs: 1-D array of bin frequencies matching the second axis of ``psd``.
        fmin: Lower band edge (inclusive).
        fmax: Upper band edge (inclusive).

    Returns:
        The band-averaged power as a Python ``float``.

    Raises:
        ValueError: If no frequency bin falls inside the band. Without this
            check the mean of an empty selection would silently be NaN.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(
            f"No PSD frequency bins fall inside the {fmin}-{fmax} Hz band."
        )
    return float(psd[:, freq_mask].mean())


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load EEG data from a FIF or CSV file into an MNE ``Raw`` object.

    - FIF: read with ``mne.io.read_raw_fif`` (preloaded).
    - CSV: every numeric column other than ``time_col`` becomes an EEG channel;
      non-numeric columns are ignored. If ``time_col`` is present, the sampling
      frequency is estimated as ``1 / mean(diff(time))`` (this treats the time
      values as seconds -- confirm against your data). Otherwise
      ``default_sfreq`` is used.

    Args:
        file_path: Path to the recording; the extension selects the loader.
        default_sfreq: Sampling frequency in Hz used for CSV files that have no
            ``time_col`` column.
        time_col: Name of the CSV column holding sample times.

    Returns:
        ``mne.io.Raw`` for FIF input or ``mne.io.RawArray`` for CSV input.

    Raises:
        ValueError: Unsupported extension, no numeric channel columns, fewer
            than two time points, or a time column / default sampling
            frequency that does not yield a finite positive sampling rate.
    """
    _, file_ext = os.path.splitext(file_path)
    file_ext = file_ext.lower()

    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext != '.csv':
        raise ValueError(
            "Unsupported file format. Please provide a FIF or CSV file."
        )

    df = pd.read_csv(file_path)

    # Channels are the numeric columns; the time column is never a channel.
    ch_names = [
        col
        for col in df.columns
        if col != time_col and pd.api.types.is_numeric_dtype(df[col])
    ]
    if not ch_names:
        raise ValueError("No numeric channel columns found in the CSV file.")

    if time_col in df.columns:
        # Coerce so that a text time column becomes NaN and is rejected below
        # with a clear message instead of failing inside np.diff.
        time = pd.to_numeric(df[time_col], errors="coerce").to_numpy(dtype=float)
        if len(time) < 2:
            raise ValueError(
                "Not enough time points to estimate sampling frequency."
            )
        mean_step = float(np.mean(np.diff(time)))
        if not np.isfinite(mean_step) or mean_step <= 0:
            raise ValueError(
                f"Column '{time_col}' must contain increasing numeric time "
                "values to estimate the sampling frequency."
            )
        sfreq = 1.0 / mean_step
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = float(default_sfreq)
        if not np.isfinite(sfreq) or sfreq <= 0:
            raise ValueError(
                "Default sampling frequency must be a positive number."
            )

    # MNE expects data shaped (n_channels, n_samples).
    data = df[ch_names].to_numpy(dtype=float).T
    info = mne.create_info(
        ch_names=[str(name) for name in ch_names],
        sfreq=sfreq,
        ch_types=['eeg'] * len(ch_names),
    )
    return mne.io.RawArray(data, info)


def _welch_psd(raw, fmin, fmax):
    """Return ``(psd, freqs)`` from Welch's method across MNE versions."""
    if hasattr(raw, "compute_psd"):
        # Current MNE API; the module-level psd_welch function was removed.
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    # Older MNE releases that predate Raw.compute_psd.
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def process_eeg(file, default_sfreq, time_col):
    """Gradio callback: summarize an uploaded EEG recording in plain language.

    Args:
        file: Upload from ``gr.File``; either a filesystem path or an object
            exposing the path as ``.name``, depending on how Gradio passes it.
        default_sfreq: Sampling frequency in Hz (text or number) used for CSV
            files without a time column.
        time_col: Name of the CSV time column.

    Returns:
        The text generated by the language model, without the prompt.

    Raises:
        ValueError: No file was uploaded, ``default_sfreq`` is not numeric, or
            the recording could not be loaded / analysed.
    """
    if file is None:
        raise ValueError("Please upload an EEG file (FIF or CSV).")
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name

    try:
        sfreq = float(default_sfreq)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Default sampling frequency must be a number, got {default_sfreq!r}."
        ) from err

    raw = load_eeg_data(file_path, default_sfreq=sfreq, time_col=time_col)

    psd, freqs = _welch_psd(raw, fmin=1, fmax=40)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # Only report what was actually measured; the ratio gives the model a
    # scale-free quantity to interpret.
    ratio_text = (
        f"{alpha_power / beta_power:.2f}" if beta_power > 0 else "undefined"
    )
    # Scientific notation: MNE keeps EEG in volts, so PSD values are tiny and
    # fixed-point formatting would print them as 0.000.
    data_summary = (
        f"Alpha power: {alpha_power:.3e}, Beta power: {beta_power:.3e}. "
        f"Alpha/beta power ratio: {ratio_text}."
    )

    prompt = f"""You are a neuroscientist analyzing EEG features.

Data Summary: {data_summary}

Provide a concise, user-friendly interpretation of these findings in simple terms.
"""

    encoded = tokenizer(prompt, return_tensors="pt")
    # Forward only the two tensors generate() needs, so any extra tokenizer
    # outputs are never passed as unexpected keyword arguments.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # max_new_tokens bounds the answer only; max_length would also count
        # the prompt tokens and shrink the answer as the prompt grows.
        max_new_tokens=200,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The output sequence starts with the prompt; return only the new tokens.
    generated = outputs[0][input_ids.shape[-1]:]
    summary = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return summary


iface = gr.Interface(
    fn=process_eeg,
    inputs=[
        gr.File(label="Upload your EEG data (FIF or CSV)"),
        gr.Textbox(
            label="Default Sampling Frequency if no time column (Hz)",
            value="256",
        ),
        gr.Textbox(label="Time column name (if exists)", value="time"),
    ],
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=(
        "Upload EEG data in FIF or CSV format. "
        "If CSV, either include a 'time' column or specify a default sampling frequency. "
        "Non-numeric columns will be removed (except the chosen time column)."
    ),
)

if __name__ == "__main__":
    iface.launch()