|
|
import gradio as gr |
|
|
import librosa |
|
|
import numpy as np |
|
|
import os |
|
|
import hashlib |
|
|
from datetime import datetime |
|
|
from transformers import pipeline |
|
|
import soundfile |
|
|
import torch |
|
|
from tenacity import retry, stop_after_attempt, wait_fixed |
|
|
import logging |
|
|
import tempfile |
|
|
import shutil |
|
|
|
|
|
|
|
|
# Root logging setup: DEBUG and above, timestamped, written both to the
# "voice_analyzer.log" file and to the console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler("voice_analyzer.log"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_whisper_model():
    """Create the speech-recognition pipeline for ``openai/whisper-tiny.en``.

    The pipeline is built with ``device=-1`` and safetensors weights. Any
    failure is logged and re-raised, so the ``retry`` decorator can run the
    function again (up to three attempts, two seconds apart).

    Returns:
        The loaded ``automatic-speech-recognition`` pipeline.
    """
    pipeline_options = {
        "model": "openai/whisper-tiny.en",
        "device": -1,
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        asr_pipeline = pipeline("automatic-speech-recognition", **pipeline_options)
        logger.info("Whisper model loaded successfully")
        return asr_pipeline
    except Exception as load_error:
        logger.error(f"Failed to load Whisper model: {load_error}")
        raise
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_symptom_model():
    """Load the symptom text-classification pipeline, with a generic fallback.

    The primary model is ``abhirajeshbhai/symptom-2-disease-net``. If it
    cannot be loaded, ``distilbert-base-uncased`` is loaded instead and the
    module-level ``is_fallback_model`` flag is set to ``True`` so that
    ``analyze_symptoms`` can mark its predictions as coming from the fallback.

    Returns:
        The loaded ``text-classification`` pipeline (built with ``device=-1``).

    Raises:
        Exception: whatever the fallback load raised, when both models fail.
            The ``retry`` decorator then re-runs the whole function (up to
            three attempts, two seconds apart).
    """
    global is_fallback_model

    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,
            model_kwargs={"use_safetensors": True},
        )
        is_fallback_model = False
        logger.info("Symptom-2-Disease model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Failed to load Symptom-2-Disease model: {str(e)}")
        try:
            # NOTE(review): distilbert-base-uncased is a base checkpoint; its
            # classification head is presumably untrained, so fallback labels
            # should not be read as real diagnoses — confirm before relying on it.
            model = pipeline(
                "text-classification",
                model="distilbert-base-uncased",
                device=-1,
            )
            # Previously nothing recorded that the fallback was in use, so
            # predictions were never tagged "(fallback)".
            is_fallback_model = True
            logger.warning("Fallback to distilbert-base-uncased model")
            return model
        except Exception as fallback_e:
            logger.error(f"Fallback model failed: {str(fallback_e)}")
            raise
|
|
|
|
|
# Module-level model handles. Both start as None; transcribe_audio and
# analyze_symptoms check them and return an error string when a model is
# missing, so a failed load does not stop the module from importing.
whisper = symptom_classifier = None
is_fallback_model = False

try:
    whisper = load_whisper_model()
except Exception as whisper_error:
    logger.error(f"Whisper model initialization failed: {whisper_error}")

try:
    symptom_classifier = load_symptom_model()
except Exception as symptom_error:
    logger.error(f"Symptom model initialization failed: {symptom_error}")
    symptom_classifier, is_fallback_model = None, True
|
|
|
|
|
def compute_file_hash(file_path):
    """Return the hexadecimal MD5 digest of the file at *file_path*.

    The file is read in 4096-byte pieces, so it is never held in memory as a
    whole. If anything goes wrong (for example the file cannot be opened), the
    failure is logged and the string ``"unknown"`` is returned instead.
    """
    try:
        digest = hashlib.md5()
        with open(file_path, "rb") as stream:
            while True:
                piece = stream.read(4096)
                if not piece:
                    # Empty read means end of file.
                    break
                digest.update(piece)
        return digest.hexdigest()
    except Exception as hash_error:
        logger.error(f"Failed to compute file hash: {hash_error}")
        return "unknown"
|
|
|
|
|
def ensure_writable_dir(directory):
    """Create *directory* if needed and check that files can be written in it.

    Writability is probed by creating, writing and deleting a uniquely named
    temporary file inside the directory.

    Args:
        directory: Path of the directory to create and probe.

    Returns:
        True when the directory exists and the probe write succeeded, False
        otherwise (the failure is logged).
    """
    try:
        os.makedirs(directory, exist_ok=True)
        # A unique probe name avoids two problems of a fixed "test_write"
        # file: concurrent callers deleting each other's probe, and an
        # existing file of that name being overwritten and removed.
        with tempfile.NamedTemporaryFile(
            mode="w", prefix="test_write_", dir=directory
        ) as probe:
            probe.write("test")
        logger.debug(f"Directory {directory} is writable")
        return True
    except Exception as e:
        logger.error(f"Directory {directory} not writable: {str(e)}")
        return False
|
|
|
|
|
def transcribe_audio(audio_file):
    """Transcribe an audio file with the module-level Whisper pipeline.

    The audio is loaded at 16 kHz, rejected when it is shorter than 1600
    samples (0.1 s) or has a peak amplitude below 1e-4, written to a temporary
    WAV file and passed to Whisper with beam search (``num_beams=5``).

    Args:
        audio_file: Path of the audio file to transcribe.

    Returns:
        The stripped transcription text, or a string starting with
        ``"Error:"`` describing why no usable transcription was produced.
        The function does not raise; unexpected exceptions are logged and
        returned as ``"Error: <message>"``.
    """
    if whisper is None:
        logger.error("Whisper model not loaded")
        return "Error: Whisper model not loaded"

    temp_path = None  # set once the temp WAV exists; removed in ``finally``
    try:
        logger.debug(f"Transcribing audio: {audio_file}")
        if not os.path.exists(audio_file):
            logger.error(f"Audio file not found: {audio_file}")
            return "Error: Audio file not found"

        audio, sr = librosa.load(audio_file, sr=16000)
        if len(audio) < 1600:
            logger.error("Audio too short")
            return "Error: Audio too short (<0.1s)"
        if np.max(np.abs(audio)) < 1e-4:
            logger.error("Audio too quiet")
            return "Error: Audio too quiet"

        # Reserve a unique file name, then close the handle before soundfile
        # reopens the path (an open handle blocks reopening on some platforms).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
            temp_path = temp_wav.name
        # soundfile.write takes (file, data, samplerate); the earlier call
        # passed (data, samplerate, file) and therefore always failed.
        soundfile.write(temp_path, audio, sr)
        logger.debug(f"Saved temp WAV: {temp_path}")

        with torch.no_grad():
            result = whisper(temp_path, generate_kwargs={"num_beams": 5})
        transcription = result.get("text", "").strip()
        logger.info(f"Transcription: {transcription}")

        if not transcription:
            logger.error("Transcription empty")
            return "Error: Transcription empty"

        # More than five words of which fewer than half are distinct is
        # treated as a degenerate, looping transcription.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            logger.error("Transcription repetitive")
            return "Error: Transcription repetitive"

        return transcription
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return f"Error: {str(e)}"
    finally:
        # Always clean up the temp WAV, including on early returns and errors.
        if temp_path is not None:
            try:
                os.remove(temp_path)
                logger.debug(f"Deleted temp WAV: {temp_path}")
            except OSError as cleanup_error:
                logger.error(f"Failed to delete temp WAV: {str(cleanup_error)}")
|
|
|
|
|
def analyze_symptoms(text):
    """Classify a symptom description with the module-level classifier.

    Args:
        text: Transcribed symptom description. Empty or whitespace-only text,
            and strings starting with ``"Error:"`` (the error format used by
            this module's helpers), are rejected.

    Returns:
        A ``(prediction, score)`` tuple. ``prediction`` is the top label
        (suffixed with ``" (fallback)"`` when ``is_fallback_model`` is set),
        ``"No health condition detected"`` when the model output has an
        unexpected shape, or a string starting with ``"Error:"``. ``score`` is
        the label's score, or ``0.0`` whenever no label is returned.
    """
    if symptom_classifier is None:
        logger.error("Symptom-2-Disease model not loaded")
        return "Error: Symptom-2-Disease model not loaded", 0.0

    try:
        # Match the "Error:" prefix rather than the bare substring, so a
        # genuine transcription that merely contains the word "Error" is
        # still analysed.
        if not text or not text.strip() or text.startswith("Error:"):
            logger.error(f"Invalid transcription: {text}")
            return "Error: No valid transcription", 0.0

        with torch.no_grad():
            result = symptom_classifier(text)
        logger.debug(f"Model output: {result}")

        if not isinstance(result, list) or not result:
            logger.warning("Invalid model output: empty or not a list")
            return "No health condition detected", 0.0

        top = result[0]
        if not isinstance(top, dict) or "label" not in top or "score" not in top:
            logger.warning(f"Invalid result structure: {top}")
            return "No health condition detected", 0.0

        prediction = top["label"]
        score = top["score"]
        if is_fallback_model:
            logger.warning("Using fallback model")
            prediction = f"{prediction} (fallback)"

        logger.info(f"Prediction: {prediction}, Score: {score:.4f}")
        return prediction, score
    except Exception as e:
        logger.error(f"Symptom analysis failed: {str(e)}")
        return f"Error: {str(e)}", 0.0
|
|
|
|
|
def analyze_voice(audio_file):
    """Run the full pipeline on one recording: transcribe, classify, advise.

    The input file is moved to a uniquely named path under the system temp
    directory's ``gradio`` folder, hashed and loaded for logging, transcribed,
    screened for medication questions and finally classified.

    Args:
        audio_file: Path of the recorded or uploaded audio file. A missing or
            empty value is reported as "Audio file not found".

    Returns:
        A feedback sentence for the user, or a string starting with
        ``"Error:"`` when any stage fails. The function does not raise.
    """
    working_copy = None  # path of the moved file; removed in ``finally``
    try:
        logger.debug(f"Starting analysis for: {audio_file}")
        # The falsy check covers a submission with no audio attached, which
        # would otherwise make os.path.exists raise a TypeError.
        if not audio_file or not os.path.exists(audio_file):
            logger.error(f"Audio file not found: {audio_file}")
            return "Error: Audio file not found"

        temp_dir = os.path.join(tempfile.gettempdir(), "gradio")
        if not ensure_writable_dir(temp_dir):
            logger.error(f"Temp directory {temp_dir} not writable")
            return f"Error: Temp directory {temp_dir} not writable"

        # Timestamp with microseconds keeps concurrent uploads that share a
        # base name from colliding.
        unique_path = os.path.join(
            temp_dir,
            f"gradio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}",
        )
        try:
            shutil.move(audio_file, unique_path)
        except Exception as e:
            logger.error(f"Failed to move audio file: {str(e)}")
            return f"Error: Failed to move audio file: {str(e)}"
        working_copy = audio_file = unique_path
        logger.debug(f"Moved to: {audio_file}")

        file_hash = compute_file_hash(audio_file)
        logger.info(f"Processing audio, Hash: {file_hash}")

        # Loaded here only to log shape and duration; transcribe_audio loads
        # the file again itself.
        audio, sr = librosa.load(audio_file, sr=16000)
        logger.info(f"Audio loaded: shape={audio.shape}, SR={sr}, Duration={len(audio)/sr:.2f}s")

        transcription = transcribe_audio(audio_file)
        # Helper errors always start with "Error:"; a prefix test avoids
        # rejecting real speech that happens to contain the word "Error".
        if transcription.startswith("Error:"):
            logger.error(f"Transcription error: {transcription}")
            return transcription

        if any(keyword in transcription.lower() for keyword in ["medicine", "treatment"]):
            logger.warning("Medication query detected")
            return "Error: This tool does not provide medication advice"

        prediction, score = analyze_symptoms(transcription)
        if prediction.startswith("Error:"):
            logger.error(f"Symptom analysis error: {prediction}")
            return prediction

        feedback = (
            "No health condition detected, consult a doctor if symptoms persist."
            if prediction == "No health condition detected"
            else f"Possible {prediction.lower()} detected, consult a doctor."
        )
        logger.info(f"Feedback: {feedback}, Transcription: {transcription}, Prediction: {prediction}, Score: {score:.4f}")
        return feedback
    except Exception as e:
        logger.error(f"Voice analysis failed: {str(e)}")
        return f"Error: {str(e)}"
    finally:
        # Remove the moved file on every exit path; previously it was only
        # deleted after a fully successful analysis.
        if working_copy is not None:
            try:
                os.remove(working_copy)
                logger.debug(f"Deleted audio file: {working_copy}")
            except OSError as cleanup_error:
                logger.error(f"Failed to delete audio file: {str(cleanup_error)}")
|
|
|
|
|
def test_with_sample_audio():
    """Smoke-test the symptom classifier with a fixed mock transcription.

    If ``sample.wav`` is absent from the ``audio_samples`` temp folder, a
    two-second synthetic tone is written to ``dummy_test.wav`` there. The
    audio itself is not transcribed: the classifier is run on the hard-coded
    sentence "I have a cough and sore throat".

    Returns:
        The feedback sentence built from the prediction, or a string starting
        with ``"Error:"`` when setup or classification fails.
    """
    temp_dir = os.path.join(tempfile.gettempdir(), "audio_samples")
    if not ensure_writable_dir(temp_dir):
        logger.error(f"Temp directory {temp_dir} not writable")
        return f"Error: Temp directory {temp_dir} not writable"

    sample_audio_path = os.path.join(temp_dir, "sample.wav")
    generated = False  # only a file created here is deleted afterwards
    if not os.path.exists(sample_audio_path):
        logger.warning("Sample audio not found; generating synthetic audio")
        sr = 16000
        t = np.linspace(0, 2, 2 * sr)
        freq_mod = 440 + 10 * np.sin(2 * np.pi * 0.5 * t)
        amplitude_mod = 0.5 + 0.1 * np.sin(2 * np.pi * 0.3 * t)
        noise = 0.01 * np.random.normal(0, 1, len(t))
        dummy_audio = amplitude_mod * np.sin(2 * np.pi * freq_mod * t) + noise
        sample_audio_path = os.path.join(temp_dir, "dummy_test.wav")
        try:
            # soundfile.write takes (file, data, samplerate); the earlier
            # call passed (data, samplerate, file) and therefore always failed.
            soundfile.write(sample_audio_path, dummy_audio, sr)
            logger.info(f"Generated synthetic audio: {sample_audio_path}")
        except Exception as e:
            logger.error(f"Failed to write synthetic audio: {str(e)}")
            return f"Error: Failed to generate synthetic audio: {str(e)}"
        generated = True

    if not os.path.exists(sample_audio_path):
        logger.error(f"Synthetic audio not created: {sample_audio_path}")
        return f"Error: Synthetic audio not created: {sample_audio_path}"

    try:
        mock_transcription = "I have a cough and sore throat"
        logger.info(f"Mock transcription: {mock_transcription}")
        prediction, score = analyze_symptoms(mock_transcription)
        # Report classifier errors as they are instead of wrapping them in a
        # "Possible ... detected" sentence.
        if prediction.startswith("Error:"):
            logger.error(f"Symptom analysis error: {prediction}")
            return prediction

        feedback = (
            "No health condition detected, consult a doctor if symptoms persist."
            if prediction == "No health condition detected"
            else f"Possible {prediction.lower()} detected, consult a doctor."
        )
        logger.info(f"Test feedback: {feedback}, Prediction: {prediction}, Score: {score:.4f}")
        return feedback
    finally:
        # Delete only the synthetic file; a pre-existing sample.wav is kept.
        if generated:
            try:
                os.remove(sample_audio_path)
                logger.debug(f"Deleted test audio: {sample_audio_path}")
            except OSError as cleanup_error:
                logger.error(f"Failed to delete test audio: {str(cleanup_error)}")
|
|
|
|
|
|
|
|
# Gradio UI definition: one audio input, handed to analyze_voice as a file
# path, and one text box that shows the string analyze_voice returns.
iface = gr.Interface(
    title="Voice Health Analyzer",
    description=(
        "Record or upload a voice sample describing symptoms (e.g., 'I have a cough') "
        "for preliminary health assessment. Supports English only. "
        "Use clear audio (WAV, 16kHz). Do not ask for medication advice."
    ),
    fn=analyze_voice,
    inputs=gr.Audio(
        label="Record or Upload Voice (WAV, MP3, FLAC, 1+ sec)",
        type="filepath",
    ),
    outputs=gr.Textbox(label="Health Assessment Feedback"),
)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: log start-up, print the self-test result, then
    # serve the Gradio interface on all interfaces at port 7860.
    logger.info("Starting Voice Health Analyzer")
    self_test_feedback = test_with_sample_audio()
    print(self_test_feedback)
    iface.launch(server_port=7860, server_name="0.0.0.0")