geethareddy commited on
Commit
578b499
·
verified ·
1 Parent(s): bca924e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -64
app.py CHANGED
@@ -9,9 +9,14 @@ import soundfile
9
  import torch
10
  from tenacity import retry, stop_after_attempt, wait_fixed
11
  import logging
 
12
 
13
  # Set up logging
14
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
 
 
15
  logger = logging.getLogger(__name__)
16
 
17
  # Initialize local models with retry logic
@@ -24,7 +29,7 @@ def load_whisper_model():
24
  device=-1, # CPU; use device=0 for GPU if available
25
  model_kwargs={"use_safetensors": True}
26
  )
27
- logger.info("Whisper model loaded successfully.")
28
  return model
29
  except Exception as e:
30
  logger.error(f"Failed to load Whisper model: {str(e)}")
@@ -39,18 +44,17 @@ def load_symptom_model():
39
  device=-1, # CPU
40
  model_kwargs={"use_safetensors": True}
41
  )
42
- logger.info("Symptom-2-Disease model loaded successfully.")
43
  return model
44
  except Exception as e:
45
  logger.error(f"Failed to load Symptom-2-Disease model: {str(e)}")
46
- # Fallback to a generic model
47
  try:
48
  model = pipeline(
49
  "text-classification",
50
  model="distilbert-base-uncased",
51
  device=-1
52
  )
53
- logger.warning("Fallback to distilbert-base-uncased model.")
54
  return model
55
  except Exception as fallback_e:
56
  logger.error(f"Fallback model failed: {str(fallback_e)}")
@@ -63,131 +67,147 @@ is_fallback_model = False
63
  try:
64
  whisper = load_whisper_model()
65
  except Exception as e:
66
- logger.error(f"Whisper model initialization failed after retries: {str(e)}")
67
 
68
  try:
69
  symptom_classifier = load_symptom_model()
70
  except Exception as e:
71
- logger.error(f"Symptom model initialization failed after retries: {str(e)}")
72
  symptom_classifier = None
73
  is_fallback_model = True
74
 
75
  def compute_file_hash(file_path):
76
- """Compute MD5 hash of a file to check uniqueness."""
77
- hash_md5 = hashlib.md5()
78
- with open(file_path, "rb") as f:
79
- for chunk in iter(lambda: f.read(4096), b""):
80
- hash_md5.update(chunk)
81
- return hash_md5.hexdigest()
 
 
 
 
82
 
83
  def transcribe_audio(audio_file):
84
- """Transcribe audio using local Whisper model."""
85
  if not whisper:
86
- return "Error: Whisper model not loaded. Check logs for details."
 
87
  try:
88
- # Load and validate audio
89
  audio, sr = librosa.load(audio_file, sr=16000)
90
- if len(audio) < 1600: # Less than 0.1s
91
- return "Error: Audio too short. Provide at least 1 second."
92
- if np.max(np.abs(audio)) < 1e-4: # Too quiet
93
- return "Error: Audio too quiet. Provide clear audio describing symptoms."
 
 
94
 
95
- # Save as WAV for Whisper
96
- temp_wav = f"/tmp/{datetime.now().strftime('%Y%m%d%H%M%S%f')}_temp.wav"
97
- soundfile.write(audio, sr, temp_wav)
 
98
 
99
- # Transcribe with beam search
100
  with torch.no_grad():
101
- result = whisper(temp_wav, generate_kwargs={"num_beams": 5})
102
  transcription = result.get("text", "").strip()
103
  logger.info(f"Transcription: {transcription}")
104
 
105
- # Clean up temp file
106
  try:
107
- os.remove(temp_wav)
108
- logger.info(f"Deleted temp file: {temp_wav}")
109
  except Exception as e:
110
- logger.error(f"Failed to delete temp file: {str(e)}")
111
 
112
  if not transcription:
113
- return "Error: Transcription empty. Provide clear audio describing symptoms."
114
- # Check for repetitive transcription
115
  words = transcription.split()
116
  if len(words) > 5 and len(set(words)) < len(words) / 2:
117
- return "Error: Transcription repetitive. Provide clear, non-repetitive audio."
 
118
  return transcription
119
  except Exception as e:
120
- logger.error(f"Error transcribing audio: {str(e)}")
121
  return f"Error: {str(e)}"
122
 
123
  def analyze_symptoms(text):
124
- """Analyze symptoms using local Symptom-2-Disease model."""
125
  if not symptom_classifier:
126
- return "Error: Symptom-2-Disease model not loaded.", 0.0
 
127
  try:
128
  if not text or "Error" in text:
129
- return "Error: No valid transcription for analysis.", 0.0
 
130
  with torch.no_grad():
131
  result = symptom_classifier(text)
132
  if result and isinstance(result, list) and len(result) > 0:
133
  prediction = result[0]["label"]
134
  score = result[0]["score"]
135
  if is_fallback_model:
136
- logger.warning("Using fallback model; results may be less accurate.")
137
- prediction = f"{prediction} (fallback model)"
138
  logger.info(f"Prediction: {prediction}, Score: {score:.4f}")
139
  return prediction, score
 
140
  return "No health condition detected", 0.0
141
  except Exception as e:
142
- logger.error(f"Error analyzing symptoms: {str(e)}")
143
  return f"Error: {str(e)}", 0.0
144
 
145
  def analyze_voice(audio_file):
146
  """Analyze voice for health indicators."""
147
  try:
148
- # Ensure unique file name
149
- unique_path = f"/tmp/gradio/{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}"
 
 
 
 
 
 
 
150
  os.rename(audio_file, unique_path)
151
  audio_file = unique_path
 
152
 
153
- # Log audio file info
154
  file_hash = compute_file_hash(audio_file)
155
- logger.info(f"Processing audio: {audio_file}, Hash: {file_hash}")
156
 
157
- # Load audio to verify
158
  audio, sr = librosa.load(audio_file, sr=16000)
159
- logger.info(f"Audio shape: {audio.shape}, SR: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")
160
 
161
- # Transcribe audio
162
  transcription = transcribe_audio(audio_file)
163
  if "Error" in transcription:
 
164
  return transcription
165
 
166
- # Check for medication queries
167
  if any(keyword in transcription.lower() for keyword in ["medicine", "treatment"]):
168
- return "Error: This tool does not provide medication or treatment advice."
 
169
 
170
- # Analyze symptoms
171
  prediction, score = analyze_symptoms(transcription)
172
  if "Error" in prediction:
 
173
  return prediction
174
 
175
- # Generate one-line feedback
176
- feedback = "No health condition detected, consult a doctor if symptoms persist." if prediction == "No health condition detected" else f"Possible {prediction.lower()} detected, consult a doctor."
177
-
178
- # Log debug info
179
- logger.info(f"Feedback: {feedback}, Transcription: {transcription}, Prediction: {prediction}, Confidence: {score:.4f}, Hash: {file_hash}")
 
180
 
181
- # Clean up audio file
182
  try:
183
  os.remove(audio_file)
184
- logger.info(f"Deleted audio file: {audio_file}")
185
  except Exception as e:
186
  logger.error(f"Failed to delete audio file: {str(e)}")
187
 
188
  return feedback
189
  except Exception as e:
190
- logger.error(f"Error processing audio: {str(e)}")
191
  return f"Error: {str(e)}"
192
 
193
  def test_with_sample_audio():
@@ -195,28 +215,35 @@ def test_with_sample_audio():
195
  sample_audio_path = "audio_samples/sample.wav"
196
  if not os.path.exists(sample_audio_path):
197
  logger.warning("Sample audio not found; generating synthetic audio")
198
- # Generate synthetic audio (sine wave to simulate voice)
199
  sr = 16000
200
  t = np.linspace(0, 2, 2 * sr)
201
  freq_mod = 440 + 10 * np.sin(2 * np.pi * 0.5 * t)
202
  amplitude_mod = 0.5 + 0.1 * np.sin(2 * np.pi * 0.3 * t)
203
  noise = 0.01 * np.random.normal(0, 1, len(t))
204
  dummy_audio = amplitude_mod * np.sin(2 * np.pi * freq_mod * t) + noise
205
- sample_audio_path = "audio_samples/dummy_test.wav"
206
- os.makedirs("audio_samples", exist_ok=True)
207
  try:
208
  soundfile.write(dummy_audio, sr, sample_audio_path)
209
- logger.info(f"Generated synthetic audio at: {sample_audio_path}")
210
  except Exception as e:
211
  logger.error(f"Failed to write synthetic audio: {str(e)}")
212
  return f"Error: Failed to generate synthetic audio: {str(e)}"
213
 
214
- # Mock transcription for synthetic audio
215
  mock_transcription = "I have a cough and sore throat"
216
- logger.info(f"Mock transcription for synthetic audio: {mock_transcription}")
217
  prediction, score = analyze_symptoms(mock_transcription)
218
- feedback = "No health condition detected, consult a doctor if symptoms persist." if prediction == "No health condition detected" else f"Possible {prediction.lower()} detected, consult a doctor."
 
 
 
 
219
  logger.info(f"Test feedback: {feedback}, Prediction: {prediction}, Score: {score:.4f}")
 
 
 
 
 
220
  return feedback
221
 
222
  # Gradio interface
 
9
  import torch
10
  from tenacity import retry, stop_after_attempt, wait_fixed
11
  import logging
12
+ import tempfile
13
 
14
  # Set up logging
15
+ logging.basicConfig(
16
+ level=logging.DEBUG,
17
+ format="%(asctime)s - %(levelname)s - %(message)s",
18
+ handlers=[logging.FileHandler("voice_analyzer.log"), logging.StreamHandler()]
19
+ )
20
  logger = logging.getLogger(__name__)
21
 
22
  # Initialize local models with retry logic
 
29
  device=-1, # CPU; use device=0 for GPU if available
30
  model_kwargs={"use_safetensors": True}
31
  )
32
+ logger.info("Whisper model loaded successfully")
33
  return model
34
  except Exception as e:
35
  logger.error(f"Failed to load Whisper model: {str(e)}")
 
44
  device=-1, # CPU
45
  model_kwargs={"use_safetensors": True}
46
  )
47
+ logger.info("Symptom-2-Disease model loaded successfully")
48
  return model
49
  except Exception as e:
50
  logger.error(f"Failed to load Symptom-2-Disease model: {str(e)}")
 
51
  try:
52
  model = pipeline(
53
  "text-classification",
54
  model="distilbert-base-uncased",
55
  device=-1
56
  )
57
+ logger.warning("Fallback to distilbert-base-uncased model")
58
  return model
59
  except Exception as fallback_e:
60
  logger.error(f"Fallback model failed: {str(fallback_e)}")
 
67
  try:
68
  whisper = load_whisper_model()
69
  except Exception as e:
70
+ logger.error(f"Whisper model initialization failed: {str(e)}")
71
 
72
  try:
73
  symptom_classifier = load_symptom_model()
74
  except Exception as e:
75
+ logger.error(f"Symptom model initialization failed: {str(e)}")
76
  symptom_classifier = None
77
  is_fallback_model = True
78
 
79
  def compute_file_hash(file_path):
80
+ """Compute MD5 hash of a file."""
81
+ try:
82
+ hash_md5 = hashlib.md5()
83
+ with open(file_path, "rb") as f:
84
+ for chunk in iter(lambda: f.read(4096), b""):
85
+ hash_md5.update(chunk)
86
+ return hash_md5.hexdigest()
87
+ except Exception as e:
88
+ logger.error(f"Failed to compute file hash: {str(e)}")
89
+ return "unknown"
90
 
91
  def transcribe_audio(audio_file):
92
+ """Transcribe audio using Whisper model."""
93
  if not whisper:
94
+ logger.error("Whisper model not loaded")
95
+ return "Error: Whisper model not loaded"
96
  try:
97
+ logger.debug(f"Transcribing audio: {audio_file}")
98
  audio, sr = librosa.load(audio_file, sr=16000)
99
+ if len(audio) < 1600:
100
+ logger.error("Audio too short")
101
+ return "Error: Audio too short (<0.1s)"
102
+ if np.max(np.abs(audio)) < 1e-4:
103
+ logger.error("Audio too quiet")
104
+ return "Error: Audio too quiet"
105
 
106
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
107
+ temp_path = temp_wav.name
108
+ soundfile.write(audio, sr, temp_path)
109
+ logger.debug(f"Saved temp WAV: {temp_path}")
110
 
 
111
  with torch.no_grad():
112
+ result = whisper(temp_path, generate_kwargs={"num_beams": 5})
113
  transcription = result.get("text", "").strip()
114
  logger.info(f"Transcription: {transcription}")
115
 
 
116
  try:
117
+ os.remove(temp_path)
118
+ logger.debug(f"Deleted temp WAV: {temp_path}")
119
  except Exception as e:
120
+ logger.error(f"Failed to delete temp WAV: {str(e)}")
121
 
122
  if not transcription:
123
+ logger.error("Transcription empty")
124
+ return "Error: Transcription empty"
125
  words = transcription.split()
126
  if len(words) > 5 and len(set(words)) < len(words) / 2:
127
+ logger.error("Transcription repetitive")
128
+ return "Error: Transcription repetitive"
129
  return transcription
130
  except Exception as e:
131
+ logger.error(f"Transcription failed: {str(e)}")
132
  return f"Error: {str(e)}"
133
 
134
  def analyze_symptoms(text):
135
+ """Analyze symptoms using Symptom-2-Disease model."""
136
  if not symptom_classifier:
137
+ logger.error("Symptom-2-Disease model not loaded")
138
+ return "Error: Symptom-2-Disease model not loaded", 0.0
139
  try:
140
  if not text or "Error" in text:
141
+ logger.error(f"Invalid transcription: {text}")
142
+ return "Error: No valid transcription", 0.0
143
  with torch.no_grad():
144
  result = symptom_classifier(text)
145
  if result and isinstance(result, list) and len(result) > 0:
146
  prediction = result[0]["label"]
147
  score = result[0]["score"]
148
  if is_fallback_model:
149
+ logger.warning("Using fallback model")
150
+ prediction = f"{prediction} (fallback)"
151
  logger.info(f"Prediction: {prediction}, Score: {score:.4f}")
152
  return prediction, score
153
+ logger.warning("No prediction returned")
154
  return "No health condition detected", 0.0
155
  except Exception as e:
156
+ logger.error(f"Symptom analysis failed: {str(e)}")
157
  return f"Error: {str(e)}", 0.0
158
 
159
  def analyze_voice(audio_file):
160
  """Analyze voice for health indicators."""
161
  try:
162
+ logger.debug(f"Starting analysis for: {audio_file}")
163
+ if not os.path.exists(audio_file):
164
+ logger.error(f"Audio file not found: {audio_file}")
165
+ return "Error: Audio file not found"
166
+
167
+ unique_path = os.path.join(
168
+ tempfile.gettempdir(),
169
+ f"gradio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}"
170
+ )
171
  os.rename(audio_file, unique_path)
172
  audio_file = unique_path
173
+ logger.debug(f"Renamed to: {audio_file}")
174
 
 
175
  file_hash = compute_file_hash(audio_file)
176
+ logger.info(f"Processing audio, Hash: {file_hash}")
177
 
 
178
  audio, sr = librosa.load(audio_file, sr=16000)
179
+ logger.info(f"Audio loaded: shape={audio.shape}, SR={sr}, Duration={len(audio)/sr:.2f}s")
180
 
 
181
  transcription = transcribe_audio(audio_file)
182
  if "Error" in transcription:
183
+ logger.error(f"Transcription error: {transcription}")
184
  return transcription
185
 
 
186
  if any(keyword in transcription.lower() for keyword in ["medicine", "treatment"]):
187
+ logger.warning("Medication query detected")
188
+ return "Error: This tool does not provide medication advice"
189
 
 
190
  prediction, score = analyze_symptoms(transcription)
191
  if "Error" in prediction:
192
+ logger.error(f"Symptom analysis error: {prediction}")
193
  return prediction
194
 
195
+ feedback = (
196
+ "No health condition detected, consult a doctor if symptoms persist."
197
+ if prediction == "No health condition detected"
198
+ else f"Possible {prediction.lower()} detected, consult a doctor."
199
+ )
200
+ logger.info(f"Feedback: {feedback}, Transcription: {transcription}, Prediction: {prediction}, Score: {score:.4f}")
201
 
 
202
  try:
203
  os.remove(audio_file)
204
+ logger.debug(f"Deleted audio file: {audio_file}")
205
  except Exception as e:
206
  logger.error(f"Failed to delete audio file: {str(e)}")
207
 
208
  return feedback
209
  except Exception as e:
210
+ logger.error(f"Voice analysis failed: {str(e)}")
211
  return f"Error: {str(e)}"
212
 
213
  def test_with_sample_audio():
 
215
  sample_audio_path = "audio_samples/sample.wav"
216
  if not os.path.exists(sample_audio_path):
217
  logger.warning("Sample audio not found; generating synthetic audio")
 
218
  sr = 16000
219
  t = np.linspace(0, 2, 2 * sr)
220
  freq_mod = 440 + 10 * np.sin(2 * np.pi * 0.5 * t)
221
  amplitude_mod = 0.5 + 0.1 * np.sin(2 * np.pi * 0.3 * t)
222
  noise = 0.01 * np.random.normal(0, 1, len(t))
223
  dummy_audio = amplitude_mod * np.sin(2 * np.pi * freq_mod * t) + noise
224
+ sample_audio_path = os.path.join(tempfile.gettempdir(), "dummy_test.wav")
225
+ os.makedirs(os.path.dirname(sample_audio_path), exist_ok=True)
226
  try:
227
  soundfile.write(dummy_audio, sr, sample_audio_path)
228
+ logger.info(f"Generated synthetic audio: {sample_audio_path}")
229
  except Exception as e:
230
  logger.error(f"Failed to write synthetic audio: {str(e)}")
231
  return f"Error: Failed to generate synthetic audio: {str(e)}"
232
 
 
233
  mock_transcription = "I have a cough and sore throat"
234
+ logger.info(f"Mock transcription: {mock_transcription}")
235
  prediction, score = analyze_symptoms(mock_transcription)
236
+ feedback = (
237
+ "No health condition detected, consult a doctor if symptoms persist."
238
+ if prediction == "No health condition detected"
239
+ else f"Possible {prediction.lower()} detected, consult a doctor."
240
+ )
241
  logger.info(f"Test feedback: {feedback}, Prediction: {prediction}, Score: {score:.4f}")
242
+ try:
243
+ os.remove(sample_audio_path)
244
+ logger.debug(f"Deleted test audio: {sample_audio_path}")
245
+ except Exception:
246
+ pass
247
  return feedback
248
 
249
  # Gradio interface