ArabicLAWLLM / app.py
ghostai1's picture
Update app.py
8579576 verified
raw
history blame
5.06 kB
import gradio as gr
import torch
import logging
from transformers import AutoTokenizer, AutoModel
from diffusers import DiffusionPipeline
import soundfile as sf
import numpy as np
# Configure root logging at INFO level with timestamped records (useful for
# debugging startup problems), then create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Everything from here down to the matching `except` (loader definitions, UI
# construction and server launch) runs inside this try, so any exception raised
# during startup reaches the handler at the end of the file.
try:
# Text side of the pipeline: tokenizer + embedding model from ./umt5-base
def load_text_processor():
    """Load the tokenizer and the text model stored under ``./umt5-base``.

    The model is requested in float16 from safetensors weights with
    ``device_map="auto"``.

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    checkpoint_dir = "./umt5-base"
    model_options = {
        "use_safetensors": True,
        "torch_dtype": torch.float16,
        "device_map": "auto",
    }

    logger.info("Loading text processor (umt5-base)...")
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    processor = (
        AutoTokenizer.from_pretrained(checkpoint_dir),
        AutoModel.from_pretrained(checkpoint_dir, **model_options),
    )
    logger.info("Text processor loaded successfully.")
    return processor
# Load the transformer backbone (phantomstep_transformer)
def load_transformer():
    """Load the pipeline stored under ``./phantomstep_transformer``.

    The weights are read from safetensors files in float16. The pipeline is
    then moved to CUDA when ``torch.cuda.is_available()`` is true, otherwise
    to the CPU.

    Returns:
        The loaded pipeline object, already placed on the selected device.
    """
    logger.info("Loading transformer (phantomstep_transformer)...")
    # DiffusionPipeline.from_pretrained() does not accept device_map="auto"
    # (current diffusers releases only support "balanced" for pipelines and
    # raise NotImplementedError otherwise), so load first and place the
    # pipeline explicitly afterwards.
    transformer = DiffusionPipeline.from_pretrained(
        "./phantomstep_transformer",
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    transformer = transformer.to(target_device)
    logger.info("Transformer loaded successfully.")
    return transformer
# Load the DCAE for audio encoding/decoding (phantomstep_dcae)
def load_dcae():
    """Load the pipeline stored under ``./phantomstep_dcae``.

    The weights are read from safetensors files in float16. The pipeline is
    then moved to CUDA when ``torch.cuda.is_available()`` is true, otherwise
    to the CPU.

    Returns:
        The loaded pipeline object, already placed on the selected device.
    """
    logger.info("Loading DCAE (phantomstep_dcae)...")
    # DiffusionPipeline.from_pretrained() does not accept device_map="auto"
    # (current diffusers releases only support "balanced" for pipelines and
    # raise NotImplementedError otherwise), so load first and place the
    # pipeline explicitly afterwards.
    dcae = DiffusionPipeline.from_pretrained(
        "./phantomstep_dcae",
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    dcae = dcae.to(target_device)
    logger.info("DCAE loaded successfully.")
    return dcae
# Load the vocoder for audio synthesis (phantomstep_vocoder)
def load_vocoder():
    """Load the pipeline stored under ``./phantomstep_vocoder``.

    The weights are read from safetensors files in float16. The pipeline is
    then moved to CUDA when ``torch.cuda.is_available()`` is true, otherwise
    to the CPU.

    Returns:
        The loaded pipeline object, already placed on the selected device.
    """
    logger.info("Loading vocoder (phantomstep_vocoder)...")
    # DiffusionPipeline.from_pretrained() does not accept device_map="auto"
    # (current diffusers releases only support "balanced" for pipelines and
    # raise NotImplementedError otherwise), so load first and place the
    # pipeline explicitly afterwards.
    vocoder = DiffusionPipeline.from_pretrained(
        "./phantomstep_vocoder",
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    vocoder = vocoder.to(target_device)
    logger.info("Vocoder loaded successfully.")
    return vocoder
# Loaded model components, filled on the first generate_music() call and
# reused afterwards so the checkpoints are not re-read on every request.
_PIPELINE_COMPONENTS = {}


def _get_components():
    """Return the model components, loading them only on first use."""
    if not _PIPELINE_COMPONENTS:
        tokenizer, text_model = load_text_processor()
        loaded = {
            "tokenizer": tokenizer,
            "text_model": text_model,
            "transformer": load_transformer(),
            "dcae": load_dcae(),
            "vocoder": load_vocoder(),
        }
        # Publish only after every loader succeeded, so a failed load is
        # retried from scratch on the next call instead of leaving a
        # half-filled cache behind.
        _PIPELINE_COMPONENTS.update(loaded)
    return _PIPELINE_COMPONENTS


# Generate music from a text prompt
def generate_music(prompt, duration=20, seed=42):
    """Generate audio for a text prompt and write it to a WAV file.

    The prompt is tokenized and embedded with the text model, and the
    embedding is passed through the transformer, DCAE and vocoder stages in
    that order. The final audio is written at 22050 Hz.

    Args:
        prompt: Text description of the music to generate.
        duration: Value forwarded as ``audio_length_in_s`` to every stage.
        seed: Seed for ``torch.manual_seed``. Coerced to ``int``; ``None``
            (for example a cleared UI field) falls back to 42.

    Returns:
        Path of the WAV file that was written. Every call gets its own
        temporary file so concurrent requests cannot overwrite each other.
    """
    import tempfile  # local import: only needed here for the output file

    logger.info(f"Generating music with prompt: {prompt}, duration: {duration}, seed: {seed}")
    torch.manual_seed(42 if seed is None else int(seed))

    # Load all components (cached after the first call)
    components = _get_components()
    tokenizer = components["tokenizer"]
    text_model = components["text_model"]
    transformer = components["transformer"]
    dcae = components["dcae"]
    vocoder = components["vocoder"]

    # Step 1: Process text prompt to embeddings
    logger.info("Processing text prompt to embeddings...")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(text_model.device) for k, v in inputs.items()}
    # AutoModel resolves a UMT5 checkpoint to the full encoder-decoder model,
    # whose forward() requires decoder inputs. Only encoder hidden states are
    # used here, so run the encoder alone when the model exposes one.
    if hasattr(text_model, "get_encoder"):
        encoder = text_model.get_encoder()
    else:
        encoder = text_model
    with torch.no_grad():
        # Average over dimension 1 (presumably the token axis — TODO confirm).
        embeddings = encoder(**inputs).last_hidden_state.mean(dim=1)

    # Step 2: Pass embeddings through transformer
    logger.info("Generating with transformer...")
    transformer_output = transformer(
        embeddings,
        num_inference_steps=50,
        audio_length_in_s=duration,
    ).audios[0]

    # Step 3: Decode audio features with DCAE
    logger.info("Decoding with DCAE...")
    dcae_output = dcae(
        transformer_output,
        num_inference_steps=50,
        audio_length_in_s=duration,
    ).audios[0]

    # Step 4: Synthesize final audio with vocoder
    logger.info("Synthesizing with vocoder...")
    audio = vocoder(
        dcae_output,
        num_inference_steps=50,
        audio_length_in_s=duration,
    ).audios[0]

    # Save audio to a per-request file. delete=False keeps the file on disk
    # after the handle is closed so the caller can read it by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        output_path = handle.name
    sf.write(output_path, audio, 22050)  # 22kHz sample rate
    logger.info("Music generation complete.")
    return output_path
# Gradio interface: a prompt box, duration slider and seed field feed
# generate_music(); its returned file path is shown in an audio player.
# The emoji in the title and headings were mis-decoded UTF-8 in the original
# strings and are restored here.
logger.info("Setting up Gradio interface...")
with gr.Blocks(title="PhantomStep: Text-to-Music Generation 🎵") as demo:
    gr.Markdown("# PhantomStep by GhostAI 🚀")
    gr.Markdown("Enter a text prompt to generate music! 🎶")

    # Inputs, in the positional order generate_music(prompt, duration, seed) expects.
    prompt_input = gr.Textbox(
        label="Text Prompt",
        placeholder="A jazzy piano melody with a fast tempo",
    )
    duration_input = gr.Slider(
        label="Duration (seconds)",
        minimum=10,
        maximum=60,
        value=20,
        step=1,
    )
    # precision=0 asks Gradio to round the entered number to an integer.
    seed_input = gr.Number(label="Random Seed", value=42, precision=0)

    generate_button = gr.Button("Generate Music")
    audio_output = gr.Audio(label="Generated Music")

    generate_button.click(
        fn=generate_music,
        inputs=[prompt_input, duration_input, seed_input],
        outputs=audio_output,
    )

# Serve on all network interfaces on port 7860, with server-side rendering
# switched off.
logger.info("Launching Gradio app...")
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
# Handler for the try opened near the top of the file: log the message of any
# exception raised during startup, then re-raise it unchanged.
except Exception as e:
logger.error(f"Failed to start the application: {str(e)}")
raise