import gradio as gr
import torch
import numpy as np
import librosa
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    HubertModel,
    pipeline
)
from transformers.pipelines.pt_utils import KeyDataset
from datasets import Dataset
import whisper
from scipy.spatial.distance import cosine
from phonemizer import phonemize
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Create directories for storing user data
os.makedirs("user_data", exist_ok=True)
os.makedirs("user_data/audio", exist_ok=True)
os.makedirs("user_data/plots", exist_ok=True)

# ===== MODEL INITIALIZATION =====
# Automatically use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Option 1: Fine-tuned LLM for feedback generation
model_name = "BeastGokul/Nika-1.5B"
llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
llm_model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

# Option 2: OpenAI Whisper for speech recognition
whisper_processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3").to(device).eval()

# Option 3: Wav2Vec2 and HuBERT for phoneme-level analysis and embeddings
# === ASR model: Wav2Vec2 Large (character-level CTC transcription) ===
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
wav2vec_model.to(device).eval()  # Set to evaluation mode

# === Embedding model: HuBERT Large (pronunciation embeddings) ===
hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
hubert_model.to(device).eval()  # Set to evaluation mode

# System prompt for the LLM
SYSTEM_PROMPT = """You are a specialized pronunciation assistant for non-native English speakers.
Your job is to provide targeted, actionable feedback based on the user's speech or description.
When analyzing pronunciation:
1. Identify at most 2 specific phonemes or pronunciation patterns that need improvement
2. Explain how the sound is correctly formed (tongue position, lip movement, etc.)
3. Suggest one simple, targeted exercise for practice
4. Be encouraging and note any improvements from previous sessions
5. Use simple language appropriate for language learners
When provided with phonetic analysis data, incorporate this information into your feedback.
"""

# ===== PRONUNCIATION TRACKING FUNCTIONS =====
# Data management
def get_user_data_path(user_id="default"):
    return f"user_data/{user_id}_data.json"

def load_user_data(user_id="default"):
    file_path = get_user_data_path(user_id)
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            return json.load(f)
    return {
        "profile": {
            "native_language": "",
            "challenge_sounds": [],
            "practice_count": 0,
            "joined_date": datetime.now().strftime("%Y-%m-%d")
        },
        "practice_sessions": [],
        "phoneme_progress": {},
        "word_progress": {},
        "goals": []
    }

def save_user_data(data, user_id="default"):
    with open(get_user_data_path(user_id), "w") as f:
        json.dump(data, f, indent=2)

def save_audio(audio, user_id="default"):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = f"user_data/audio/{user_id}_{timestamp}.wav"
    if isinstance(audio, tuple):
        # Gradio "numpy" audio arrives as (sample_rate, samples)
        sr, y = audio
        # Convert to mono if needed
        if len(y.shape) > 1:
            y = y.mean(axis=1)
        # Write with soundfile (librosa no longer ships an audio writer)
        import soundfile as sf
        sf.write(file_path, y, sr)
    else:
        # Otherwise audio is a file path (the "filepath" type used by the UI)
        import shutil
        shutil.copy(audio, file_path)
    return file_path

# Audio processing and phonetic analysis
def transcribe_with_whisper(audio_path):
    """Transcribe audio with the transformers Whisper model loaded above"""
    audio, sr = librosa.load(audio_path, sr=16000)
    input_features = whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(device)
    with torch.no_grad():
        predicted_ids = whisper_model.generate(input_features)
    return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

def extract_phonemes(text):
    """Convert text to phonemes"""
    return phonemize(text, language='en-us', backend='espeak', strip=True)
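
# For example, extract_phonemes("I think so") returns an IPA-like string,
# roughly "aɪ θɪŋk soʊ" with the espeak backend; the exact symbols depend on
# the installed espeak-ng version, so treat this as illustrative only.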

def analyze_audio_phonetically(audio_path, reference_text=None):
    """Perform phonetic analysis of the audio compared to reference text"""
    # Process audio
    audio, sr = librosa.load(audio_path, sr=16000)
    inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = wav2vec_model(inputs.input_values.to(device)).logits
    # Decode the CTC output. Note: this checkpoint emits a character-level
    # transcription, which the prototype uses as a stand-in for phonemes.
    predicted_ids = torch.argmax(logits, dim=-1)
    phoneme_sequence = wav2vec_processor.batch_decode(predicted_ids)[0]
    result = {
        "detected_phonemes": phoneme_sequence,
    }
    # If reference text is provided, compare with expected phonemes
    if reference_text:
        reference_phonemes = extract_phonemes(reference_text)
        # A full implementation would align the two sequences with dynamic time
        # warping (DTW) or edit-distance alignment before comparing them; the
        # prototype only records both sequences (see the sketch below).
        result["reference_phonemes"] = reference_phonemes
        result["analysis"] = "Phoneme comparison would be performed here"
    return result
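
# --- Hedged sketch (not wired into the app): one way the comparison above
# could be fleshed out. It treats both sequences as plain strings and uses
# difflib's edit-distance alignment as a cheap stand-in for DTW; the helper
# name is hypothetical and nothing else in this file calls it.
def sketch_phoneme_similarity(detected, reference):
    """Return a 0-1 similarity ratio plus the per-segment alignment opcodes."""
    import difflib
    matcher = difflib.SequenceMatcher(None, detected.lower(), reference.lower())
    # get_opcodes() yields (tag, i1, i2, j1, j2) spans tagged equal/replace/
    # insert/delete, which could be mapped back to the specific sounds missed
    return matcher.ratio(), matcher.get_opcodes()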

def extract_pronunciation_embedding(audio_path):
    """Extract pronunciation embedding for comparison purposes"""
    # Process audio (HuBERT model and processor are loaded at module level)
    audio, sr = librosa.load(audio_path, sr=16000)
    inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = hubert_model(inputs.input_values.to(device))
    # Extract embedding (mean over time dimension)
    embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding

def compare_with_native(user_embedding, native_embedding):
    """Compare user pronunciation embedding with a native speaker embedding"""
    # cosine() (imported at the top) returns a distance, so 1 - distance is a similarity
    similarity = 1 - cosine(user_embedding.flatten(), native_embedding.flatten())
    return similarity
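
# --- Hedged sketch (not wired into the UI): how the two helpers above could be
# combined if native-speaker reference clips were bundled with the Space. The
# reference path and function name below are hypothetical.
def sketch_score_against_native(user_audio_path,
                                native_audio_path="native_refs/fox_sentence.wav"):
    """Return a cosine-similarity score between a user clip and a reference clip."""
    user_emb = extract_pronunciation_embedding(user_audio_path)
    native_emb = extract_pronunciation_embedding(native_audio_path)
    return compare_with_native(user_emb, native_emb)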

# ===== LLM FEEDBACK FUNCTIONS =====
def get_llm_feedback(audio=None, text=None, reference_text=None, user_id="default"):
    """Get LLM feedback based on audio or text input"""
    user_data = load_user_data(user_id)
    # Process audio if provided
    if audio:
        audio_path = save_audio(audio, user_id)
        # Transcribe if no text was provided
        if not text:
            text = transcribe_with_whisper(audio_path)
        # Get phonetic analysis
        phonetic_analysis = analyze_audio_phonetically(audio_path, reference_text)
        phonetic_info = f"""
Phonetic analysis:
- Detected phonemes: {phonetic_analysis['detected_phonemes']}
"""
        if reference_text:
            phonetic_info += f"- Reference phonemes: {phonetic_analysis.get('reference_phonemes', 'N/A')}\n"
    else:
        audio_path = None
        phonetic_info = ""
    # Get user history context
    history_context = ""
    if user_data["practice_sessions"]:
        # Find the most frequently practiced (i.e. most challenging) phonemes
        phoneme_counts = {p: data["practice_count"] for p, data in user_data["phoneme_progress"].items()}
        challenging = sorted(phoneme_counts.items(), key=lambda x: x[1], reverse=True)[:3]
        history_context = f"""
User has practiced {len(user_data['practice_sessions'])} times before.
Common challenging phonemes: {', '.join([p for p, _ in challenging])}.
"""
    # Build prompt for the LLM
    if text:
        user_input = f"I said: '{text}'"
        if reference_text and reference_text != text:
            user_input += f". I was trying to say: '{reference_text}'"
    else:
        user_input = "Please analyze my pronunciation."
    full_prompt = f"""{SYSTEM_PROMPT}
User history:
{history_context}
{phonetic_info}
User: {user_input}
Assistant: """
    # Get LLM response
    inputs = llm_tokenizer(full_prompt, return_tensors="pt").to(llm_model.device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    response = llm_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()
    # Track the session if audio was provided
    if audio_path:
        track_practice_session(user_id, audio_path, text, reference_text, response)
    return response, text

def track_practice_session(user_id, audio_path, text, reference_text, feedback):
    """Track a practice session and update user progress"""
    user_data = load_user_data(user_id)
    # Get phonetic analysis
    phonetic_analysis = analyze_audio_phonetically(audio_path, reference_text)
    # Extract embedding for future comparison
    try:
        embedding = extract_pronunciation_embedding(audio_path)
        embedding_path = f"user_data/{user_id}_embedding_{len(user_data['practice_sessions'])}.npy"
        np.save(embedding_path, embedding)
    except Exception as e:
        embedding_path = None
        print(f"Error extracting embedding: {e}")
    # Extract phonemes from the text
    phonemes = extract_phonemes(text)
    # Update session data
    session = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "text": text,
        "reference_text": reference_text if reference_text else text,
        "audio_path": audio_path,
        "embedding_path": embedding_path,
        "phonetic_analysis": phonetic_analysis,
        "feedback": feedback
    }
    user_data["practice_sessions"].append(session)
    # Update phoneme progress (character-level pass over the IPA string, skipping spaces)
    for phoneme in set(phonemes) - {" "}:
        if phoneme not in user_data["phoneme_progress"]:
            user_data["phoneme_progress"][phoneme] = {
                "practice_count": 0,
                "first_practiced": datetime.now().strftime("%Y-%m-%d"),
                "confidence_scores": []
            }
        user_data["phoneme_progress"][phoneme]["practice_count"] += 1
        user_data["phoneme_progress"][phoneme]["last_practiced"] = datetime.now().strftime("%Y-%m-%d")
        # In a real implementation, we would compute a confidence score for this
        # phoneme; for now, use a random score that generally improves over time
        # (see the scoring sketch after this function).
        prev_scores = user_data["phoneme_progress"][phoneme]["confidence_scores"]
        last_score = prev_scores[-1] if prev_scores else 0.5
        new_score = min(0.95, last_score + np.random.uniform(-0.1, 0.2))
        user_data["phoneme_progress"][phoneme]["confidence_scores"].append(float(new_score))
    # Update profile stats
    user_data["profile"]["practice_count"] += 1
    # Save updated data
    save_user_data(user_data, user_id)
    return session
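
# --- Hedged sketch (not called anywhere): one way the random confidence score
# above could be replaced. It reuses sketch_phoneme_similarity (defined earlier)
# to turn the detected-vs-reference similarity into a single per-session score.
# Purely illustrative; the function name is hypothetical.
def sketch_session_confidence(phonetic_analysis):
    """Map the detected/reference similarity ratio to a 0-1 confidence score."""
    reference = phonetic_analysis.get("reference_phonemes")
    if not reference:
        return None  # nothing to compare against
    ratio, _ = sketch_phoneme_similarity(phonetic_analysis["detected_phonemes"], reference)
    return float(ratio)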

# ===== PROGRESS REPORTING =====
def generate_progress_report(user_id="default"):
    """Generate a comprehensive progress report"""
    user_data = load_user_data(user_id)
    if not user_data["practice_sessions"]:
        return "No practice sessions recorded yet. Start practicing to see your progress!"
    # Basic stats
    total_sessions = len(user_data["practice_sessions"])
    practice_dates = [session["date"].split()[0] for session in user_data["practice_sessions"]]
    practice_frequency = len(set(practice_dates))
    # Phoneme progress analysis
    improving_phonemes = []
    challenging_phonemes = []
    for phoneme, data in user_data["phoneme_progress"].items():
        if len(data["confidence_scores"]) >= 3:
            early_avg = sum(data["confidence_scores"][:2]) / 2
            recent_avg = sum(data["confidence_scores"][-2:]) / 2
            if recent_avg - early_avg > 0.15:
                improving_phonemes.append((phoneme, recent_avg - early_avg))
            elif recent_avg < 0.6:
                challenging_phonemes.append((phoneme, recent_avg))
    # Sort lists
    improving_phonemes.sort(key=lambda x: x[1], reverse=True)
    challenging_phonemes.sort(key=lambda x: x[1])
    # Generate plots (saved to disk; not embedded in the markdown report)
    if total_sessions >= 3:
        plot_path = generate_progress_plots(user_id)
    else:
        plot_path = None
    # Format report
    report = f"""# Pronunciation Progress Report
## Overview
- Total practice sessions: {total_sessions}
- Days practiced: {practice_frequency}
- Practice streak: {calculate_streak(practice_dates)} days

## Progress Highlights
"""
    if improving_phonemes:
        report += "### Most Improved Sounds\n"
        for phoneme, improvement in improving_phonemes[:3]:
            report += f"- {phoneme}: {improvement:.2f} improvement\n"
    if challenging_phonemes:
        report += "\n### Sounds to Focus On\n"
        for phoneme, score in challenging_phonemes[:3]:
            report += f"- {phoneme}: current score {score:.2f}\n"
    # Recent sessions summary
    report += "\n## Recent Sessions\n"
    for session in user_data["practice_sessions"][-3:]:
        report += f"- {session['date']}: \"{session['text']}\"\n"
    return report

def calculate_streak(date_strings):
    """Calculate the current practice streak in days"""
    if not date_strings:
        return 0
    # Convert to datetime objects and find unique dates
    dates = sorted(set([datetime.strptime(d, "%Y-%m-%d") for d in date_strings]))
    # Check if the most recent date is today or yesterday
    today = datetime.now().date()
    most_recent = dates[-1].date()
    if (today - most_recent).days > 1:
        return 0  # Streak broken
    # Count consecutive days backward
    streak = 1
    for i in range(len(dates) - 2, -1, -1):
        if (dates[i + 1].date() - dates[i].date()).days == 1:
            streak += 1
        else:
            break
    return streak
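
# Illustrative example (dates invented for this comment): with practice dates
# 2024-03-01, 2024-03-03 and 2024-03-04, and "today" being 2024-03-04,
# calculate_streak returns 2 because the gap on 2024-03-02 breaks the earlier run.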

def generate_progress_plots(user_id="default"):
    """Generate visualization plots of user progress"""
    user_data = load_user_data(user_id)
    # Create a dataframe for easier plotting
    phoneme_data = []
    for phoneme, data in user_data["phoneme_progress"].items():
        for i, score in enumerate(data["confidence_scores"]):
            phoneme_data.append({
                "phoneme": phoneme,
                "session": i + 1,
                "score": score
            })
    if not phoneme_data:
        return None
    df = pd.DataFrame(phoneme_data)
    # Plot 1: Overall progress for most practiced phonemes
    plt.figure(figsize=(10, 6))
    top_phonemes = df["phoneme"].value_counts().head(5).index.tolist()
    for phoneme in top_phonemes:
        phoneme_df = df[df["phoneme"] == phoneme]
        plt.plot(phoneme_df["session"], phoneme_df["score"], marker='o', label=phoneme)
    plt.title("Pronunciation Progress for Top Phonemes")
    plt.xlabel("Practice Session")
    plt.ylabel("Confidence Score")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plot_path = f"user_data/plots/{user_id}_progress.png"
    plt.savefig(plot_path)
    plt.close()
    return plot_path

# ===== GRADIO UI =====
# Define practice exercises
PRACTICE_EXERCISES = [
    {"title": "Basic Vowels", "text": "The cat sat on the mat."},
    {"title": "R Sound", "text": "The red robin ran around the river."},
    {"title": "TH Sounds", "text": "I think these three things are worth it."},
    {"title": "L vs R", "text": "The light rain falls along the lake."},
    {"title": "V vs W", "text": "We very much want to visit the west village."},
    {"title": "Short Phrases", "text": "Excuse me. Thank you. I'm sorry. Nice to meet you."}
]

# Create Gradio app
with gr.Blocks(title="ESL Pronunciation Coach - Advanced") as demo:
    user_id = gr.State("default")
    gr.Markdown("# 🗣️ Advanced Pronunciation Coach")

    with gr.Tab("Practice"):
        with gr.Row():
            with gr.Column(scale=2):
                # Practice options
                exercise_dropdown = gr.Dropdown(
                    choices=[ex["title"] for ex in PRACTICE_EXERCISES],
                    label="Select Practice Exercise",
                    value=PRACTICE_EXERCISES[0]["title"]
                )
                reference_text = gr.Textbox(
                    label="Practice Text (Read This Aloud)",
                    value=PRACTICE_EXERCISES[0]["text"],
                    lines=2
                )

                # Update reference text when dropdown changes
                def update_reference_text(exercise_title):
                    for ex in PRACTICE_EXERCISES:
                        if ex["title"] == exercise_title:
                            return ex["text"]
                    return ""

                exercise_dropdown.change(update_reference_text, exercise_dropdown, reference_text)

                # Audio input
                audio_input = gr.Audio(label="Record your pronunciation", type="filepath", format="wav", show_label=True)
                submit_btn = gr.Button("Get Feedback", variant="primary")

            with gr.Column(scale=3):
                # Results area
                transcription_output = gr.Textbox(label="Your Speech (Transcribed)", lines=2)
                feedback_output = gr.Textbox(label="Pronunciation Feedback", lines=6)

                # Pronunciation tracker
                with gr.Accordion("Track Your Progress", open=False):
                    difficulty_slider = gr.Slider(
                        minimum=1, maximum=5, value=3, step=1,
                        label="How difficult was this for you? (1: Easy, 5: Very Difficult)"
                    )
                    notes_input = gr.Textbox(
                        label="Your Notes (optional)",
                        placeholder="Note any specific challenges you faced..."
                    )
                    track_btn = gr.Button("Save to Progress Tracker")

    with gr.Tab("Progress Tracker"):
        progress_btn = gr.Button("Generate Progress Report")
        progress_output = gr.Markdown(label="Your Progress")

    with gr.Tab("Self Assessment"):
        gr.Markdown("""
        ## Self-Assessment Tool
        Record yourself saying the following text, then compare with a native speaker model.
        """)
        assessment_text = gr.Textbox(
            label="Assessment Text",
            value="The quick brown fox jumps over the lazy dog.",
            lines=2
        )
        assessment_audio = gr.Audio(type="filepath", label="Record your pronunciation", format="wav")
        assess_btn = gr.Button("Analyze Pronunciation")
        assessment_output = gr.Textbox(label="Pronunciation Analysis", lines=8)

    with gr.Tab("Settings"):
        native_language = gr.Dropdown(
            choices=["English", "Spanish", "Chinese", "Arabic", "Russian", "Hindi", "Japanese", "Korean", "French", "Other"],
            label="Your Native Language",
            value="Other"
        )
        focus_areas = gr.CheckboxGroup(
            choices=["Vowel sounds", "Consonant sounds", "Word stress", "Sentence rhythm", "Intonation"],
            label="Areas to Focus On"
        )
        save_settings_btn = gr.Button("Save Settings")
        settings_output = gr.Textbox(label="Status")

    # Connect functions
    def process_audio(audio, ref_text):
        if not audio:
            return "No audio recorded", "Please record your pronunciation first."
        feedback, transcription = get_llm_feedback(audio, None, ref_text)
        return transcription, feedback

    submit_btn.click(
        process_audio,
        inputs=[audio_input, reference_text],
        outputs=[transcription_output, feedback_output]
    )

    progress_btn.click(
        generate_progress_report,
        inputs=[],
        outputs=[progress_output]
    )

    def save_user_settings(language, areas):
        user_data = load_user_data()
        user_data["profile"]["native_language"] = language
        user_data["profile"]["focus_areas"] = areas
        save_user_data(user_data)
        return "Settings saved successfully!"

    save_settings_btn.click(
        save_user_settings,
        inputs=[native_language, focus_areas],
        outputs=[settings_output]
    )

    def analyze_pronunciation(audio, text):
        if not audio:
            return "No audio recorded. Please record your pronunciation first."
        # In a real implementation, this would compare with native speaker models
        # For this prototype, we'll use the LLM for detailed feedback
        feedback, _ = get_llm_feedback(audio, None, text)
        return feedback

    assess_btn.click(
        analyze_pronunciation,
        inputs=[assessment_audio, assessment_text],
        outputs=[assessment_output]
    )

demo.launch()