ParulPandey committed
Commit 2ec3bec · verified · 1 Parent(s): d8c8d0c

Update agent.py

Files changed (1)
  1. agent.py +83 -152
agent.py CHANGED
@@ -13,8 +13,9 @@ import json
  import time
  import numpy as np
  from pathlib import Path
- from typing import Dict, List, Optional, Tuple
  from dotenv import load_dotenv
 
  # Load environment variables
  load_dotenv()
@@ -42,27 +43,27 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  str: Generated story text.
  """
  # Extract grade number and determine age/reading level
- grade_num = int(''.join(filter(str.isdigit, grade)) or "3")
  age = grade_num + 5 # Grade 1 = ~6 years old, Grade 6 = ~11 years old
 
  # Dynamically determine story parameters based on grade
  if grade_num <= 2:
  # Grades 1-2: Very simple stories
- story_length = "2-3 short sentences"
  vocabulary_level = "very simple words (mostly 1-2 syllables)"
  sentence_structure = "short, simple sentences"
  complexity = "basic concepts"
  reading_level = "beginner"
  elif grade_num <= 4:
  # Grades 3-4: Intermediate stories
- story_length = "1-2 short paragraphs"
  vocabulary_level = "age-appropriate words with some longer words"
  sentence_structure = "mix of simple and compound sentences"
  complexity = "intermediate concepts with some detail"
  reading_level = "intermediate"
  else:
- # Grades 5-6: More advanced stories
- story_length = "2-3 paragraphs"
  vocabulary_level = "varied vocabulary including descriptive words"
  sentence_structure = "complex sentences with descriptive language"
  complexity = "detailed concepts and explanations"
@@ -79,15 +80,17 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  - Vocabulary: Use {vocabulary_level}
  - Sentence structure: {sentence_structure}
  - Complexity: {complexity}
- - Include {name} as the main character
  - Teach something interesting about {topic}
  - End with a positive, encouraging message
  - Make it engaging and fun to read aloud
 
  Additional Guidelines:
  - For younger students (Grades 1-2): Focus on simple actions, basic emotions, and clear cause-and-effect
- - For middle students (Grades 3-4): Include some problem-solving, friendship themes, and basic science/nature facts
- - For older students (Grades 5-6): Add character development, more detailed explanations, and encourage curiosity
 
  The story should be perfectly suited for a {grade} student's reading ability and attention span.
 
@@ -95,7 +98,7 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  """
 
  # Use Google Gemini
- model = genai.GenerativeModel('gemini-1.5-flash')
 
  # Adjust generation parameters based on grade level
  max_tokens = 300 if grade_num <= 2 else 600 if grade_num <= 4 else 1000
@@ -163,151 +166,69 @@ def text_to_speech(text: str) -> str:
  traceback.print_exc()
  return None
 
- @tool
 
 
- def transcribe_audio(audio_input: str) -> str:
  """
- Transcribe the student's audio into text via Whisper STT service.
- Using abidlabs/whisper-large-v2 Hugging Face Space API.
 
  Args:
- audio_input: Either a file path (str) or tuple (sample_rate, numpy_array) from Gradio
 
  Returns:
- str: Transcribed speech text.
  """
  try:
- print(f"Received audio input: {type(audio_input)}")
-
- # Handle different input formats
- if isinstance(audio_input, tuple) and len(audio_input) == 2:
- # Gradio microphone format: (sample_rate, numpy_array)
- sample_rate, audio_data = audio_input
- print(f"Audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape}")
- # Pass the tuple directly to the STT service
- audio_for_stt = audio_input
- elif isinstance(audio_input, (str, Path)):
- audio_for_stt = str(audio_input)
- else:
- print(f"Unsupported audio input type: {type(audio_input)}")
- return "Error: Unsupported audio format. Please try recording again."
-
- if isinstance(audio_for_stt, Path):
- audio_for_stt = str(audio_for_stt)
-
- # Initialize client with error handling
- print("Initializing Gradio client for STT...")
- try:
- client = Client("abidlabs/whisper-large-v2")
- except Exception as client_error:
- print(f"Failed to initialize client: {client_error}")
- # Try alternative approach
- try:
- print("Trying direct API approach...")
- return "Error: STT service initialization failed. Please try again."
- except Exception as fallback_error:
- print(f"Fallback also failed: {fallback_error}")
- return "Error: Speech recognition service unavailable. Please try again later."
-
- print("Sending audio for transcription...")
-
- # Make the API call with timeout and error handling
- try:
- if isinstance(audio_for_stt, tuple):
- result = client.predict(audio_for_stt, api_name="/predict")
- else:
- result = client.predict(audio_for_stt, api_name="/predict")
- except Exception as api_error:
- print(f"API call failed: {api_error}")
- if "extra_headers" in str(api_error):
- return "Error: Connection protocol mismatch. Please try recording again."
- elif "connection" in str(api_error).lower():
- return "Error: Network connection issue. Please check your internet and try again."
- else:
- return "Error: Transcription service temporarily unavailable. Please try again."
-
- print(f"Raw transcription result: {result}")
- print(f"Result type: {type(result)}")
-
- # Handle different result types more robustly
- if result is None:
- return "Error: No transcription result. Please try speaking more clearly and loudly."
-
- # Extract text from result
- transcribed_text = ""
-
- if isinstance(result, str):
- transcribed_text = result.strip()
- elif isinstance(result, (list, tuple)):
- if len(result) > 0:
- # Try to find the text in the result structure
- transcribed_text = str(result[0]).strip()
- print(f"Extracted from list/tuple: {transcribed_text}")
- else:
- return "Error: Empty transcription result. Please try again."
- elif isinstance(result, dict):
- # Handle dictionary results - try common keys
- transcribed_text = result.get('text', result.get('transcription', str(result))).strip()
- print(f"Extracted from dict: {transcribed_text}")
- else:
- transcribed_text = str(result).strip()
- print(f"Converted to string: {transcribed_text}")
-
- # Clean up common API artifacts
- transcribed_text = transcribed_text.replace('```', '').replace('json', '').replace('{', '').replace('}', '')
-
- # Validate the transcription
- if not transcribed_text or (isinstance(transcribed_text, str) and transcribed_text.lower() in ['', 'none', 'null', 'error', 'undefined']):
- return "I couldn't hear any speech clearly. Please try recording again and speak more loudly."
-
- # Ensure transcribed_text is a string before further processing
- if not isinstance(transcribed_text, str):
- return "I couldn't hear any speech clearly. Please try recording again and speak more loudly."
-
- # Check for common error messages from the API
- error_indicators = ['error', 'failed', 'could not', 'unable to', 'timeout']
- if any(indicator in transcribed_text.lower() for indicator in error_indicators):
- return "Transcription service had an issue. Please try recording again."
-
- # Clean up the transcribed text
- transcribed_text = transcribed_text.replace('\n', ' ').replace('\t', ' ')
- # Remove extra whitespace
- transcribed_text = ' '.join(transcribed_text.split())
-
- if len(transcribed_text) < 3:
- return "The recording was too short or unclear. Please try reading more slowly and clearly."
-
- print(f"Final transcribed text: {transcribed_text}")
- return transcribed_text
-
- except ImportError as e:
- print(f"Import error: {str(e)}")
- return "Error: Missing required libraries. Please check your installation."
-
- except ConnectionError as e:
- print(f"Connection error: {str(e)}")
- return "Network connection error. Please check your internet connection and try again."
-
- except TimeoutError as e:
- print(f"Timeout error: {str(e)}")
- return "Transcription service is taking too long. Please try again with a shorter recording."
-
  except Exception as e:
- print(f"Unexpected transcription error: {str(e)}")
- error_msg = str(e).lower()
-
- # Provide helpful error messages based on the error type
- if "timeout" in error_msg or "connection" in error_msg:
- return "Network timeout. Please check your internet connection and try again."
- elif "file" in error_msg or "path" in error_msg:
- return "Audio file error. Please try recording again."
- elif "api" in error_msg or "client" in error_msg or "gradio" in error_msg:
- return "Transcription service temporarily unavailable. Please try again in a moment."
- elif "memory" in error_msg or "size" in error_msg:
- return "Audio file is too large or complex. Please try with a shorter recording."
- else:
- return f"Transcription failed. Please try recording again. If the problem persists, try speaking more clearly."
 
  def compare_texts_for_feedback(original: str, spoken: str) -> str:
  """
@@ -327,7 +248,7 @@ def compare_texts_for_feedback(original: str, spoken: str) -> str:
 
  # Calculate accuracy using sequence matching
  matcher = SequenceMatcher(None, orig_words, spoken_words, autojunk=False)
- accuracy = matcher.ratio() * 100
 
  # Identify different types of errors
  missed_words = set(orig_words) - set(spoken_words)
@@ -361,7 +282,7 @@ def find_similar_words(original_words: list, spoken_words: list) -> list:
 
  return mispronounced[:5]
 
- def generate_adaptive_feedback(accuracy: float, missed_words: set, extra_words: set,
  mispronounced: list, total_words: int) -> str:
  """
  Generate age-appropriate, encouraging feedback with specific learning guidance.
@@ -522,10 +443,13 @@ def generate_targeted_story(previous_feedback: str, name: str, grade: str, misse
  age = grade_num + 5
 
  # Extract difficulty level from previous feedback
- if "AMAZING" in previous_feedback or "accuracy: 9" in previous_feedback:
  difficulty_adjustment = "slightly more challenging"
  focus_area = "new vocabulary and longer sentences"
- elif "GOOD" in previous_feedback or "accuracy: 8" in previous_feedback:
  difficulty_adjustment = "similar level with some new words"
  focus_area = "reinforcing current skills"
  else:
@@ -561,7 +485,7 @@ def generate_targeted_story(previous_feedback: str, name: str, grade: str, misse
  """
 
  # Generate targeted story
- model = genai.GenerativeModel('gemini-1.5-flash')
  max_tokens = 300 if grade_num <= 2 else 600 if grade_num <= 4 else 1000
 
  generation_config = {
@@ -683,8 +607,15 @@ class ReadingCoachAgent:
  name = self.student_info["name"]
  grade = self.student_info["grade"]
 
- # Generate a new practice story using the targeted story function
- practice_story = generate_targeted_story("", name, grade)
  self.current_story = practice_story
 
  return practice_story
 
@@ -13,8 +13,9 @@ import json
  import time
  import numpy as np
  from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
  from dotenv import load_dotenv
+ import base64
 
  # Load environment variables
  load_dotenv()
 
@@ -42,27 +43,27 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  str: Generated story text.
  """
  # Extract grade number and determine age/reading level
+ grade_num = int(''.join(filter(str.isdigit, grade)) or "1")
  age = grade_num + 5 # Grade 1 = ~6 years old, Grade 6 = ~11 years old
 
  # Dynamically determine story parameters based on grade
  if grade_num <= 2:
  # Grades 1-2: Very simple stories
+ story_length = "5 short sentences"
  vocabulary_level = "very simple words (mostly 1-2 syllables)"
  sentence_structure = "short, simple sentences"
  complexity = "basic concepts"
  reading_level = "beginner"
  elif grade_num <= 4:
  # Grades 3-4: Intermediate stories
+ story_length = "1 short paragraphs"
  vocabulary_level = "age-appropriate words with some longer words"
  sentence_structure = "mix of simple and compound sentences"
  complexity = "intermediate concepts with some detail"
  reading_level = "intermediate"
  else:
+ # Grades 5-10: More advanced stories
+ story_length = "2 paragraphs"
  vocabulary_level = "varied vocabulary including descriptive words"
  sentence_structure = "complex sentences with descriptive language"
  complexity = "detailed concepts and explanations"
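For context, here is a small standalone sketch of how the grade string is reduced to a number and an approximate age in the logic above; the sample grade strings are made-up examples, not values from this commit.

```python
# Illustration of the grade parsing used above (sample inputs are arbitrary).
def parse_grade(grade: str) -> tuple:
    # Keep only the digits from strings like "Grade 3" or "5th grade";
    # fall back to "1" when no digit is present.
    grade_num = int(''.join(filter(str.isdigit, grade)) or "1")
    age = grade_num + 5  # Grade 1 -> ~6 years old, Grade 6 -> ~11 years old
    return grade_num, age

print(parse_grade("Grade 3"))       # (3, 8)
print(parse_grade("5th grade"))     # (5, 10)
print(parse_grade("Kindergarten"))  # (1, 6) - no digits, defaults to 1
```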
 
@@ -79,15 +80,17 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  - Vocabulary: Use {vocabulary_level}
  - Sentence structure: {sentence_structure}
  - Complexity: {complexity}
+
  - Teach something interesting about {topic}
  - End with a positive, encouraging message
  - Make it engaging and fun to read aloud
+ - start directly with the story, no preamble or introduction
+
 
  Additional Guidelines:
  - For younger students (Grades 1-2): Focus on simple actions, basic emotions, and clear cause-and-effect
+ - For middle students (Grades 3-5): Include some problem-solving, friendship themes, and basic science/nature facts
+ - For older students (Grades 6-10): Add character development, more detailed explanations, and encourage curiosity
 
  The story should be perfectly suited for a {grade} student's reading ability and attention span.
 
 
@@ -95,7 +98,7 @@ def generate_story(name: str, grade: str, topic: str) -> str:
  """
 
  # Use Google Gemini
+ model = genai.GenerativeModel('gemini-2.0-flash')
 
  # Adjust generation parameters based on grade level
  max_tokens = 300 if grade_num <= 2 else 600 if grade_num <= 4 else 1000
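As background for the model swap above, this is a minimal sketch of how a google-generativeai call with a grade-dependent generation config is typically wired up. The environment variable name, prompt, and temperature are assumptions for illustration, not values taken from this commit.

```python
# Minimal sketch (not from this commit): calling Gemini with a grade-dependent config.
import os
import google.generativeai as genai

genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))  # env var name is an assumption
model = genai.GenerativeModel("gemini-2.0-flash")

grade_num = 3
max_tokens = 300 if grade_num <= 2 else 600 if grade_num <= 4 else 1000

response = model.generate_content(
    "Write a short, friendly reading-practice story about space.",
    generation_config={
        "temperature": 0.7,               # example value
        "max_output_tokens": max_tokens,  # scales with grade level
    },
)
print(response.text)
```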
 
@@ -163,151 +166,69 @@ def text_to_speech(text: str) -> str:
  traceback.print_exc()
  return None
 
 
+ @tool
+ def transcribe_audio(audio_path: str) -> str:
  """
+ Transcribe the student's audio into text using Hugging Face Whisper Space.
 
  Args:
+ audio_path (str): Path to the recorded .wav audio file
 
  Returns:
+ str: Transcribed text from the audio
  """
+ import base64
+ import requests
+ from pathlib import Path
+
  try:
+ print(f"Received audio input: {type(audio_path)} - {str(audio_path)[:100]}...")
+
+ # Make sure it's a valid file path
+ path = Path(audio_path)
+ if not path.exists():
+ return "Audio file not found. Please try recording again."
+
+ # Encode audio to base64
+ with open(path, "rb") as f:
+ encoded = base64.b64encode(f.read()).decode("utf-8")
+
+ # Prepare payload for HF Space
+ payload = {
+ "data": [
+ {
+ "name": path.name,
+ "data": f"data:audio/wav;base64,{encoded}"
+ },
+ None
+ ]
+ }
+
+ print("Sending audio to HF STT...")
+ response = requests.post(
+ "https://abidlabs-whisper-large-v2.hf.space/run/predict",
+ json=payload,
+ timeout=60
+ )
+ response.raise_for_status()
+
+ result = response.json().get("data", [None])[0]
+ print(f"HF response: {result}")
+
+ if not result or not isinstance(result, str) or len(result.strip()) == 0:
+ return "Could not transcribe audio. Please speak more clearly and try again."
+
+ return result.strip()
+
+ except requests.exceptions.HTTPError as e:
+ print(f"HTTP error: {e}")
+ return "Transcription service returned an error. Please try again later."
 
  except Exception as e:
+ print(f"Unexpected error: {e}")
+ return "Something went wrong during transcription. Please try again."
+
 
  def compare_texts_for_feedback(original: str, spoken: str) -> str:
  """
 
@@ -327,7 +248,7 @@ def compare_texts_for_feedback(original: str, spoken: str) -> str:
 
  # Calculate accuracy using sequence matching
  matcher = SequenceMatcher(None, orig_words, spoken_words, autojunk=False)
+ accuracy = min(round(matcher.quick_ratio() * 100 + 60), 100)
 
  # Identify different types of errors
  missed_words = set(orig_words) - set(spoken_words)
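For reference, a small standalone demo of how SequenceMatcher behaves on word lists; the example sentences are invented, and the +60 offset above is specific to this commit rather than anything difflib does.

```python
# Standalone demo of difflib.SequenceMatcher on word lists (example data only).
from difflib import SequenceMatcher

orig_words = "the cat sat on the mat".split()
spoken_words = "the cat sat on a mat".split()

matcher = SequenceMatcher(None, orig_words, spoken_words, autojunk=False)

# ratio() measures similarity in [0, 1]; quick_ratio() is a cheaper upper bound on it.
print(matcher.ratio())        # word-level similarity, roughly 0.83 here
print(matcher.quick_ratio())  # always >= ratio(), based on word counts only
```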
 
@@ -361,7 +282,7 @@ def find_similar_words(original_words: list, spoken_words: list) -> list:
 
  return mispronounced[:5]
 
+ def generate_adaptive_feedback(accuracy: int, missed_words: set, extra_words: set,
  mispronounced: list, total_words: int) -> str:
  """
  Generate age-appropriate, encouraging feedback with specific learning guidance.
 
@@ -522,10 +443,13 @@ def generate_targeted_story(previous_feedback: str, name: str, grade: str, misse
  age = grade_num + 5
 
  # Extract difficulty level from previous feedback
+ if "AMAZING" in previous_feedback or "accuracy: 9" in previous_feedback or "🌟 AMAZING" in previous_feedback:
+ difficulty_adjustment = "more challenging with advanced vocabulary"
+ focus_area = "new vocabulary, longer sentences, and complex concepts"
+ elif "GREAT JOB" in previous_feedback or "accuracy: 8" in previous_feedback or "🎉 GREAT JOB" in previous_feedback:
  difficulty_adjustment = "slightly more challenging"
  focus_area = "new vocabulary and longer sentences"
+ elif "GOOD" in previous_feedback or "accuracy: 7" in previous_feedback or "👍 GOOD WORK" in previous_feedback:
  difficulty_adjustment = "similar level with some new words"
  focus_area = "reinforcing current skills"
  else:
 
@@ -561,7 +485,7 @@ def generate_targeted_story(previous_feedback: str, name: str, grade: str, misse
  """
 
  # Generate targeted story
+ model = genai.GenerativeModel('gemini-2.0-flash')
  max_tokens = 300 if grade_num <= 2 else 600 if grade_num <= 4 else 1000
 
  generation_config = {
 
@@ -683,8 +607,15 @@ class ReadingCoachAgent:
  name = self.student_info["name"]
  grade = self.student_info["grade"]
 
+ # Get the last feedback from session if available
+ last_feedback = ""
+ if self.current_session and self.current_session in self.session_manager.sessions:
+ session_data = self.session_manager.sessions[self.current_session]
+ if session_data.get("feedback_history"):
+ last_feedback = session_data["feedback_history"][-1].get("feedback", "")
+
+ # Generate a new practice story using the targeted story function with feedback context
+ practice_story = generate_targeted_story(last_feedback, name, grade)
  self.current_story = practice_story
 
  return practice_story
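To make the new feedback lookup above concrete, here is a hypothetical shape for session_manager.sessions that would satisfy it; the session key and feedback strings are invented for illustration only.

```python
# Hypothetical session store shape matching the lookup above (values invented).
sessions = {
    "session-001": {
        "feedback_history": [
            {"feedback": "🎉 GREAT JOB! You read 8 out of 10 words correctly."},
            {"feedback": "🌟 AMAZING! Almost every word was perfect."},
        ],
    },
}

current_session = "session-001"
last_feedback = ""
if current_session and current_session in sessions:
    session_data = sessions[current_session]
    if session_data.get("feedback_history"):
        # Take the most recent feedback entry, as the agent method does.
        last_feedback = session_data["feedback_history"][-1].get("feedback", "")

print(last_feedback)  # this string is what gets passed to generate_targeted_story()
```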