import gradio as gr
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan,
    WhisperProcessor, WhisperForConditionalGeneration
)
from datasets import load_dataset
import os
import spaces
import tempfile
import soundfile as sf
import librosa
import yaml

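# Hugging Face model IDs and text-generation hyperparameters.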
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
STT_MODEL_ID = "openai/whisper-small"

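# Lazily-populated globals: all models are loaded on the first request via
# load_models(), so importing this module stays cheap.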
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None
whisper_model = None
first_load = True


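# Helpers for the "Model Info" panel rendered at the bottom of the page.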
def generate_pretty_html(data):
    html = """
    <div class="font-sans max-w-xl mx-auto bg-gray-800 text-white rounded-lg p-6 shadow-md">
        <h2 class="text-xl font-semibold text-white border-b border-gray-600 pb-2 mb-4">Model Info</h2>
    """
    for key, value in data.items():
        html += f"""
        <div class="mb-3">
            <strong class="text-blue-400 inline-block w-40">{key}:</strong>
            <span class="text-gray-300">{value}</span>
        </div>
        """
    html += "</div>"
    return html


def load_config():
    with open("config.yaml", "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def render_modern_info():
    try:
        config = load_config()
        return generate_pretty_html(config)
    except Exception as e:
        return f"<div style='color: red;'>Error loading config: {str(e)}</div>"


def load_readme():
    # Currently unused; kept as a helper for surfacing README.md content in the UI.
    with open("README.md", "r", encoding="utf-8") as f:
        return f.read()


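# SpeechT5 only handles a limited amount of text per call (the processor below
# truncates at 512 tokens), so long replies are split into sentence-sized chunks
# and the resulting audio segments are concatenated afterwards.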
def split_text_into_chunks(text, max_chars=400):
    sentences = text.replace("...", ".").split(". ")
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) + 2 < max_chars:
            current_chunk += ". " + sentence if current_chunk else sentence
        else:
            chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    # Re-terminate each chunk with a single period (the split above strips them).
    return [f"{chunk.rstrip('.')}." for chunk in chunks if chunk.strip()]


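# Model loading happens inside a @spaces.GPU-decorated function so that a GPU is
# attached while the weights are initialized (e.g. on ZeroGPU Spaces).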
@spaces.GPU
def load_models():
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings, whisper_processor, whisper_model
    hf_token = os.environ.get("HF_TOKEN")

    # Chat LLM
    if tokenizer is None or llm_model is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            llm_model = AutoModelForCausalLM.from_pretrained(
                HUGGINGFACE_MODEL_ID,
                torch_dtype=TORCH_DTYPE,
                device_map="auto",
                token=hf_token
            ).eval()
            print("LLM loaded successfully.")
        except Exception as e:
            print(f"Error loading LLM: {e}")

    # SpeechT5 text-to-speech model, HiFi-GAN vocoder, and a speaker embedding
    if tts_processor is None or tts_model is None or tts_vocoder is None:
        try:
            tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
            # xvector 7306 is the US English voice commonly used in SpeechT5 examples.
            embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
            speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
            device = llm_model.device if llm_model is not None else "cpu"
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)
            print("TTS models loaded.")
        except Exception as e:
            print(f"Error loading TTS: {e}")

    # Whisper speech-to-text model
    if whisper_processor is None or whisper_model is None:
        try:
            whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
            whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
            device = llm_model.device if llm_model is not None else "cpu"
            whisper_model.to(device)
            print("Whisper loaded.")
        except Exception as e:
            print(f"Error loading Whisper: {e}")


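# Chat handler: generate an LLM reply for the conversation, then synthesize it to audio.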
@spaces.GPU
def generate_response_and_audio(message, history):
    global first_load
    if first_load:
        load_models()
        first_load = False

    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is None or llm_model is None:
        return history + [{"role": "assistant", "content": "Error: LLM not loaded."}], None

    messages = history.copy()
    messages.append({"role": "user", "content": message})

    # Prefer the model's chat template; fall back to a plain transcript prompt.
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        input_text = ""
        for item in history:
            input_text += f"{item['role'].capitalize()}: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
        output_ids = llm_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens, not the prompt.
        generated_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"LLM error: {e}")
        return messages + [{"role": "assistant", "content": "I had an issue generating a response."}], None

    # Synthesize the reply chunk by chunk and write a single 16 kHz WAV file.
    audio_path = None
    if None not in [tts_processor, tts_model, tts_vocoder, speaker_embeddings]:
        try:
            device = llm_model.device
            text_chunks = split_text_into_chunks(generated_text)

            full_speech = []
            for chunk in text_chunks:
                tts_inputs = tts_processor(text=chunk, return_tensors="pt", max_length=512, truncation=True).to(device)
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
                full_speech.append(speech.cpu())

            full_speech_tensor = torch.cat(full_speech, dim=0)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                sf.write(audio_path, full_speech_tensor.numpy(), samplerate=16000)

        except Exception as e:
            print(f"TTS error: {e}")

    # Return the updated conversation (including the user turn) and the audio file path.
    return messages + [{"role": "assistant", "content": generated_text}], audio_path


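# Speech-to-text handler for the "Transcribe" tab.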
@spaces.GPU
def transcribe_audio(filepath):
    global first_load
    if first_load:
        load_models()
        first_load = False

    global whisper_processor, whisper_model
    if whisper_model is None:
        return "Whisper model not loaded."

    try:
        # Whisper expects 16 kHz mono audio.
        audio, sr = librosa.load(filepath, sr=16000)
        inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(whisper_model.device)
        outputs = whisper_model.generate(inputs)
        return whisper_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        return f"Transcription failed: {e}"


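# Gradio UI: a chat tab with spoken responses, a transcription tab, and a model-info panel.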
with gr.Blocks(head="""
    <script src="https://cdn.tailwindcss.com"></script>
""") as demo:
    gr.Markdown("""
    <div class="bg-gray-900 text-white p-4 rounded-lg shadow-md mb-6">
        <h1 class="text-2xl font-bold">Qwen2.5 Chatbot with Voice Input/Output</h1>
        <p class="text-gray-300">Powered by Gradio + TailwindCSS</p>
    </div>
    """)

    with gr.Tab("Chat"):
        gr.HTML("""
        <div class="bg-gray-800 p-4 rounded-lg mb-4">
            <label class="block text-gray-300 font-medium mb-2">Chat Interface</label>
        </div>
        """)
        chatbot = gr.Chatbot(type='messages', elem_classes=["bg-gray-800", "text-white"])
        text_input = gr.Textbox(
            placeholder="Type your message...",
            label="User Input",
            elem_classes=["bg-gray-700", "text-white", "border-gray-600"]
        )
        audio_output = gr.Audio(label="Response Audio", autoplay=True)
        text_input.submit(generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output])

    with gr.Tab("Transcribe"):
        gr.HTML("""
        <div class="bg-gray-800 p-4 rounded-lg mb-4">
            <label class="block text-gray-300 font-medium mb-2">Audio Transcription</label>
        </div>
        """)
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        transcribed = gr.Textbox(
            label="Transcription",
            elem_classes=["bg-gray-700", "text-white", "border-gray-600"]
        )
        audio_input.upload(transcribe_audio, audio_input, transcribed)

    clear_btn = gr.Button("Clear All", elem_classes=["bg-gray-600", "hover:bg-gray-500", "text-white", "mt-4"])
    clear_btn.click(lambda: ([], "", None), None, [chatbot, text_input, audio_output])

    html_output = gr.HTML("""
    <div class="bg-gray-800 text-white p-4 rounded-lg mt-6 text-center">
        Loading model info...
    </div>
    """)
    demo.load(fn=render_modern_info, outputs=html_output)


demo.queue().launch()