# llama-omni/predict.py: Cog predictor for LLaMA-Omni (placeholder implementation)
import os
import time
import subprocess
import whisper
from cog import BasePredictor, Input, Path
import torch
import tempfile

class Predictor(BasePredictor):
    def setup(self):
        """Load the models into memory to make inference faster."""
        print("Loading models...")

        # Load Whisper for audio transcription
        print("Loading Whisper model...")
        self.whisper_model = whisper.load_model(
            "large-v3", download_root="models/speech_encoder/"
        )

        # In a real implementation, the LLaMA-Omni model would be loaded here
        print("Note: in a real deployment, the LLaMA-Omni model would be loaded here")

        # Start the controller that coordinates model workers
        print("Starting controller...")
        self.controller_process = subprocess.Popen([
            "python", "-m", "omni_speech.serve.controller",
            "--host", "0.0.0.0",
            "--port", "10000",
        ])
        time.sleep(5)  # Wait for the controller to start

        # Start the model worker that serves LLaMA-Omni in speech-to-speech mode
        print("Starting model worker...")
        self.model_worker_process = subprocess.Popen([
            "python", "-m", "omni_speech.serve.model_worker",
            "--host", "0.0.0.0",
            "--controller", "http://localhost:10000",
            "--port", "40000",
            "--worker", "http://localhost:40000",
            "--model-path", "Llama-3.1-8B-Omni",
            "--model-name", "Llama-3.1-8B-Omni",
            "--s2s",
        ])
        time.sleep(10)  # Wait for the model worker to register with the controller

        print("Setup complete")

    def predict(
        self,
        audio: Path = Input(description="Audio file for speech input", default=None),
        text: str = Input(description="Text input (used if no audio is provided)", default=None),
    ) -> str:
        """Run inference on the model."""
        if audio is None and not text:
            return "Error: Please provide either an audio file or text input."

        if audio is not None:
            # Process audio input: transcribe it with Whisper
            print(f"Transcribing audio from {audio}...")
            result = self.whisper_model.transcribe(str(audio))
            transcription = result["text"]
            print(f"Transcription: {transcription}")

            # In a real implementation, the transcription would be processed through
            # LLaMA-Omni; this placeholder returns it with a simulated response.
            return (
                f"Transcription: {transcription}\n\n"
                "Response: This is a simulated response to your audio. In a real "
                "deployment, this would be processed through the LLaMA-Omni model."
            )
        else:
            # Process text input
            print(f"Processing text: {text}")

            # In a real implementation, the text would be processed through
            # LLaMA-Omni; this placeholder returns it with a simulated response.
            return (
                f"Input: {text}\n\n"
                "Response: This is a simulated response to your text. In a real "
                "deployment, this would be processed through the LLaMA-Omni model."
            )
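
    # Sketch only: how a transcription (or prompt) might be forwarded to the running
    # model worker over HTTP. The route and payload below are assumptions in the
    # style of FastChat/LLaVA-derived workers, not a confirmed part of the
    # omni_speech API; check omni_speech.serve.model_worker for the actual routes
    # and request schema before relying on this.
    def _query_worker(self, prompt: str, worker_url: str = "http://localhost:40000") -> str:
        """Send `prompt` to the model worker and return its raw response body."""
        import json
        import urllib.request

        payload = json.dumps({
            "model": "Llama-3.1-8B-Omni",  # assumed to match --model-name in setup()
            "prompt": prompt,
            "temperature": 0.0,
            "max_new_tokens": 256,
        }).encode("utf-8")
        request = urllib.request.Request(
            f"{worker_url}/worker_generate_stream",  # hypothetical route
            data=payload,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(request, timeout=120) as response:
            return response.read().decode("utf-8", errors="replace")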

    def __del__(self):
        """Clean up the background serving processes on shutdown."""
        if hasattr(self, "controller_process"):
            self.controller_process.terminate()
        if hasattr(self, "model_worker_process"):
            self.model_worker_process.terminate()
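
# Usage sketch (standard Cog workflow; the file name and text below are examples only):
#   cog predict -i audio=@input.wav
#   cog predict -i text="Hello, LLaMA-Omni"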