import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from datasets import load_dataset
import os
import spaces
import tempfile
import soundfile as sf
import librosa
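
# NOTE: the `spaces` import above provides the @spaces.GPU decorator used below.
# It only has an effect on Hugging Face Spaces ZeroGPU hardware; on other setups
# it should simply pass the decorated function through unchanged.

# --- LLM configuration ---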
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
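
# --- Text-to-speech (SpeechT5) configuration ---
# SpeechT5 generates spectrograms; the separate HiFi-GAN vocoder turns them
# into an audible waveform.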
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
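
# --- Speech-to-text (Whisper) configuration ---
# whisper-tiny is the smallest Whisper checkpoint; a larger one such as
# "openai/whisper-base" trades speed for better transcription accuracy.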
STT_MODEL_ID = "openai/whisper-tiny"
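
# Global model/processor handles. They start as None and are populated once by
# load_models(); every entry point checks them and lazily reloads if needed.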
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None
whisper_model = None

@spaces.GPU
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from the Hugging Face Hub.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    # Skip reloading if everything is already in memory.
    if (tokenizer is not None and llm_model is not None and tts_model is not None and
            whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return

    hf_token = os.environ.get("HF_TOKEN")

    # --- LLM (chat model) ---
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token
        )
        llm_model.eval()
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # --- TTS (SpeechT5 acoustic model + HiFi-GAN vocoder) ---
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        # SpeechT5 needs an x-vector speaker embedding to define the voice.
        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        device = llm_model.device if llm_model is not None else "cpu"
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # --- STT (Whisper) ---
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model is not None else "cpu"
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")

@spaces.GPU
def generate_response_and_audio(
    message: str,
    history: list
) -> tuple:
    """
    Generates a text response from the loaded LLM, then converts it to speech
    with the loaded TTS model. Returns the updated chat history and the path
    to the generated audio file (or None if TTS is unavailable).
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # The Chatbot component uses type="messages", so history is already a list
    # of {"role": ..., "content": ...} dicts; append the new user turn to it.
    messages = history
    messages.append({"role": "user", "content": message})

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback: build a plain-text prompt. The new user turn is already in
        # `history`, so only the generation cue needs to be appended at the end.
        input_text = ""
        for item in history:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += "Assistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    with torch.no_grad():
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # Convert the response to speech if all TTS components are available.
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            device = llm_model.device if llm_model is not None else "cpu"
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            # SpeechT5 outputs 16 kHz audio; write it to a temporary WAV file
            # that the gr.Audio output component can play back.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not loaded. Skipping audio generation.")

    history.append({"role": "assistant", "content": generated_text})

    return history, audio_path

@spaces.GPU
def transcribe_audio(audio_filepath):
    """
    Transcribes an audio file using the loaded Whisper model.
    Note: Whisper's feature extractor pads or truncates its input to 30-second
    windows, so only roughly the first 30 seconds of audio are transcribed here.
    """
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models()

    if whisper_processor is None or whisper_model is None:
        return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load the recording and resample it to the 16 kHz rate Whisper expects.
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(whisper_model.device)

        # Generate token IDs and decode them into text.
        predicted_ids = whisper_model.generate(input_features)
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription

    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"

with gr.Blocks() as demo:
    gr.Markdown(
        """
# HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd Chatbot with Voice Input & Output
Type your message or speak into the microphone to chat with the model.
The chatbot's response is spoken aloud, and your audio input can be transcribed.
"""
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4
            )
            submit_button = gr.Button("Send", scale=1)

        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False
        )

        # Both the Send button and pressing Enter in the textbox trigger generation.
        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            format="wav"
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False
        )
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True
        )

    # Reset every input and output component across both tabs.
    def clear_all():
        return [], "", None, None, ""

    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
    )

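
# Load all models once at startup so the first user request does not have to
# wait for downloads and initialization.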
load_models()

demo.queue().launch()
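
# A plausible requirements.txt for this Space (assumed; not part of the original
# file): gradio, torch, transformers, datasets, soundfile, librosa, spaces,
# and sentencepiece (needed by the SpeechT5 tokenizer).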