import gradio as gr
import torch
import os
import sys
import subprocess
import spaces  # Hugging Face `spaces` package, used for the ZeroGPU decorator below
from pathlib import Path
# Clone the LLaMA-Omni repository and download model assets on first run
def setup_environment():
    if not os.path.exists('LLaMA-Omni'):
        subprocess.run(['git', 'clone', 'https://github.com/ictnlp/LLaMA-Omni'])

    # Make the cloned package importable
    sys.path.append(os.path.join(os.path.dirname(__file__), 'LLaMA-Omni'))

    # Create directories for the model assets
    os.makedirs('models/speech_encoder', exist_ok=True)
    os.makedirs('vocoder', exist_ok=True)

    # Download the Whisper large-v3 speech encoder
    if not os.path.exists('models/speech_encoder/large-v3.pt'):
        import whisper
        whisper.load_model("large-v3", download_root="models/speech_encoder/")

    # Download the unit HiFi-GAN vocoder checkpoint and config
    if not os.path.exists('vocoder/g_00500000'):
        subprocess.run([
            'wget', '-q',
            'https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000',
            '-P', 'vocoder/'
        ])
        subprocess.run([
            'wget', '-q',
            'https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json',
            '-P', 'vocoder/'
        ])
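# Note: the Spaces container may not ship a `wget` binary. A minimal fallback
# sketch using only the standard library (not part of the original app):
#
#   from urllib.request import urlretrieve
#   urlretrieve(url, os.path.join('vocoder', os.path.basename(url)))
#
# where `url` is either of the vocoder URLs above.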
# Lazily-initialized globals for the model and speech generator
model = None
speech_generator = None


def load_models():
    global model, speech_generator
    if model is None:
        setup_environment()
        from omni_speech.model import OmniSpeechModel
        from omni_speech.speech_generator import SpeechGeneratorCausalFull

        # Load the LLaMA-Omni model in half precision
        model_path = "ICTNLP/Llama-3.1-8B-Omni"
        model = OmniSpeechModel.from_pretrained(model_path, torch_dtype=torch.float16)
        model = model.cuda()

        # Initialize the speech generator with the vocoder downloaded above
        speech_generator = SpeechGeneratorCausalFull(
            model=model,
            vocoder='vocoder/g_00500000',
            vocoder_cfg='vocoder/config.json'
        )
# The `@spaces.GPU` decorator requests a GPU for the duration of the call;
# this assumes the Space runs on ZeroGPU hardware (elsewhere it is a no-op).
@spaces.GPU
def process_audio(audio_path, text_input=None):
    """Process an audio input and generate text and speech responses."""
    # Load models on first call
    load_models()

    from omni_speech.conversation import conv_templates
    from omni_speech.utils import build_transform_audios

    # Load and preprocess the audio
    transform = build_transform_audios()
    audio_tensor = transform(audio_path)

    # Build the conversation prompt
    conv = conv_templates["llama_3"].copy()
    if text_input:
        conv.append_message(conv.roles[0], text_input)
    else:
        conv.append_message(conv.roles[0], "<Audio>")
    conv.append_message(conv.roles[1], None)

    with torch.no_grad():
        # Generate the text response
        text_output = model.generate(
            audio_tensor.unsqueeze(0).cuda(),
            conv.get_prompt(),
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True
        )
        # Generate the speech response
        speech_output = speech_generator.generate(
            audio_tensor.unsqueeze(0).cuda(),
            text_output
        )

    return text_output, speech_output
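# Example invocation (paths are hypothetical):
#   text, speech = process_audio("examples/example1.wav", "Please summarize the audio")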
# Create the Gradio interface
with gr.Blocks(title="LLaMA-Omni: Speech-Language Model") as demo:
    gr.Markdown("""
    # 🦙🎧 LLaMA-Omni: Seamless Speech Interaction
    Upload an audio file or record your voice to interact with LLaMA-Omni.
    The model will generate both text and speech responses.
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Speech Input"
            )
            text_input = gr.Textbox(
                label="Text Input (Optional)",
                placeholder="You can also provide text context..."
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(
                label="Text Response",
                lines=5
            )
            audio_output = gr.Audio(
                label="Speech Response",
                type="filepath"
            )

    # Wire the submit button to the inference function
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input, text_input],
        outputs=[text_output, audio_output]
    )
    # Examples; caching is disabled so the model is not run at startup,
    # before a GPU can be requested
    gr.Examples(
        examples=[
            ["examples/example1.wav", ""],
            ["examples/example2.wav", "Please explain in detail"],
        ],
        inputs=[audio_input, text_input],
        outputs=[text_output, audio_output],
        fn=process_audio,
        cache_examples=False
    )
if __name__ == "__main__":
    # share=True is ignored on Spaces, so launch with the standard host/port only
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )
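# A Space stuck at "Build error" usually points to missing dependencies rather
# than a bug in app.py. A minimal requirements.txt sketch for this app (package
# names are assumptions, not from the original; pin versions to taste):
#
#   gradio
#   torch
#   openai-whisper   # provides the `whisper` module imported above
#   spaces
#
# plus whatever LLaMA-Omni's own requirements list (e.g. fairseq for the
# unit HiFi-GAN vocoder).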