handler / handler.py
walterheart's picture
Update handler.py
31fdbaa verified
import os
import io
import base64
import torch
import numpy as np
from transformers import BarkModel, BarkProcessor
from typing import Dict, List, Any
class EndpointHandler:
    """Inference-endpoint handler wrapping a Bark text-to-speech model.

    The request flow is ``__call__`` -> ``preprocess`` -> ``inference`` ->
    ``postprocess``. The model and processor are loaded lazily on first use
    (see ``setup``). The response carries base64-encoded WAV audio, or an
    ``{"error": ...}`` dict on failure.
    """

    def __init__(self, path=""):
        """
        Initialize the handler for Bark text-to-speech model.

        Args:
            path (str, optional): Path to the model directory. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        # Prefer the GPU when one is visible to torch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """
        Load the model and processor from ``self.path`` and move the model to
        ``self.device``.

        Args:
            **kwargs: Additional arguments (currently unused).
        """
        # Load model from the local directory
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        # Load processor
        self.processor = BarkProcessor.from_pretrained(self.path)
        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """
        Extract the text and optional generation parameters from a request.

        Args:
            request (Dict): Request payload of the form
                ``{"inputs": str | list[str], "parameters": {...}}``. Only the
                first element of a list input is used. Recognized parameters
                are ``speaker_id``, ``voice_preset`` (``speaker_id`` wins if
                both are given) and ``temperature``.

        Returns:
            Dict: Keys ``text``, ``speaker_id`` / ``voice_preset`` and
            ``temperature``, each present only when supplied.
        """
        if not self.initialized:
            self.setup()
        inputs = {}
        # Get text from the request
        raw_inputs = request.get("inputs")
        if isinstance(raw_inputs, str):
            # Single text input
            inputs["text"] = raw_inputs
        elif isinstance(raw_inputs, list) and raw_inputs:
            # List of text inputs: only the first one is synthesized.
            # An empty list yields no text, so inference reports it cleanly
            # instead of raising IndexError here.
            inputs["text"] = raw_inputs[0]
        # Get optional parameters; "parameters": null is treated as empty.
        params = request.get("parameters") or {}
        # Speaker ID/voice preset
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]
        # Other generation parameters
        if "temperature" in params:
            inputs["temperature"] = params["temperature"]
        return inputs

    def inference(self, inputs: Dict) -> Dict:
        """
        Run Bark generation on the processed inputs.

        Args:
            inputs (Dict): Output of ``preprocess``. ``speaker_id`` is treated
                as an alias for a Bark voice preset (e.g. "v2/en_speaker_6");
                Bark has no separate speaker-id argument.

        Returns:
            Dict: ``{"audio_array": np.ndarray (float32), "sample_rate": int}``,
            or ``{"error": str}`` when no text was provided.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}
        # Same precedence as before: a truthy speaker_id beats voice_preset.
        voice_preset = inputs.get("speaker_id") or inputs.get("voice_preset")
        temperature = inputs.get("temperature", 0.7)

        # The voice preset is consumed by the processor, which adds a
        # ``history_prompt`` entry. BarkModel.generate does not accept
        # ``voice_preset``/``speaker_id`` and would reject them as unused kwargs.
        if voice_preset:
            encoded = self.processor(text, voice_preset=voice_preset)
        else:
            encoded = self.processor(text)
        # Move every movable entry (tensors and the nested history prompt)
        # to the model's device; leave anything else untouched.
        model_inputs = {
            key: value.to(self.device) if hasattr(value, "to") else value
            for key, value in encoded.items()
        }

        # Generate speech
        with torch.no_grad():
            speech_output = self.model.generate(**model_inputs, temperature=temperature)

        # float32 keeps the array writable as WAV even if the model runs in half precision.
        audio_array = speech_output.cpu().float().numpy().squeeze()
        return {"audio_array": audio_array, "sample_rate": self.model.generation_config.sample_rate}

    def postprocess(self, inference_output: Dict) -> Dict:
        """
        Encode the generated audio as a base64 WAV payload.

        Args:
            inference_output (Dict): Output of ``inference``.

        Returns:
            Dict: ``{"audio": <base64 str>, "sample_rate": int, "format": "wav"}``
            on success, otherwise ``{"error": str}``.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}
        audio_array = inference_output.get("audio_array")
        sample_rate = inference_output.get("sample_rate", 24000)
        if audio_array is None:
            return {"error": "Error converting audio: no audio array in inference output"}
        # Convert audio array to WAV format
        try:
            import scipy.io.wavfile as wav

            audio = np.asarray(audio_array)
            # scipy.io.wavfile.write rejects float16; widen to a 32-bit float WAV.
            if audio.dtype == np.float16:
                audio = audio.astype(np.float32)
            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio)
            # Encode audio data to base64
            audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")
        except Exception as e:
            return {"error": f"Error converting audio: {str(e)}"}
        return {
            "audio": audio_base64,
            "sample_rate": sample_rate,
            "format": "wav",
        }

    def __call__(self, data: Dict) -> Dict:
        """
        Main entry point for the handler.

        Args:
            data (Dict): Request data (see ``preprocess``).

        Returns:
            Dict: Response data (see ``postprocess``). Any exception raised
            along the way is reported as ``{"error": ...}`` rather than
            propagated.
        """
        # Ensure the model is initialized
        if not self.initialized:
            self.setup()
        # Process the request; this is the request boundary, so every
        # failure is turned into an error payload.
        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            response = self.postprocess(outputs)
            return response
        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}