# ArabicLAWLLM / app.py
# Last commit: 8579576 "Update app.py" by ghostai1 (verified)
import functools
import logging

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, AutoModel
# Emit INFO-and-above records with a timestamp and level so that start-up
# problems are visible in the console output.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Settings shared by every pipeline stage and by the WAV writer.
_NUM_INFERENCE_STEPS = 50
_SAMPLE_RATE = 22050  # Hz (22 kHz); rate the final waveform is saved at
_OUTPUT_PATH = "output.wav"


@functools.lru_cache(maxsize=1)
def load_text_processor():
    """Load the umt5-base tokenizer and text model from ``./umt5-base``.

    The result is cached, so the weights are read from disk on the first call
    only rather than on every generation request.

    Returns:
        tuple: ``(tokenizer, text_model)``.
    """
    logger.info("Loading text processor (umt5-base)...")
    tokenizer = AutoTokenizer.from_pretrained("./umt5-base")
    text_model = AutoModel.from_pretrained(
        "./umt5-base",
        use_safetensors=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    logger.info("Text processor loaded successfully.")
    return tokenizer, text_model


def _load_pipeline(label, model_dir):
    """Load a float16 diffusers pipeline from ``./<model_dir>``.

    Args:
        label: Component name used in the log messages (e.g. "transformer").
        model_dir: Directory name of the pipeline, relative to the working dir.

    Returns:
        The loaded ``DiffusionPipeline``.
    """
    logger.info("Loading %s (%s)...", label, model_dir)
    pipeline = DiffusionPipeline.from_pretrained(
        f"./{model_dir}",
        use_safetensors=True,
        torch_dtype=torch.float16,
        # DiffusionPipeline does not accept "auto"; "balanced" is its
        # multi-device strategy. Without CUDA, keep the default placement.
        device_map="balanced" if torch.cuda.is_available() else None,
    )
    logger.info("%s loaded successfully.", label[:1].upper() + label[1:])
    return pipeline


@functools.lru_cache(maxsize=1)
def load_transformer():
    """Load (once, then cached) the phantomstep_transformer backbone pipeline."""
    return _load_pipeline("transformer", "phantomstep_transformer")


@functools.lru_cache(maxsize=1)
def load_dcae():
    """Load (once, then cached) the phantomstep_dcae audio encoder/decoder pipeline."""
    return _load_pipeline("DCAE", "phantomstep_dcae")


@functools.lru_cache(maxsize=1)
def load_vocoder():
    """Load (once, then cached) the phantomstep_vocoder synthesis pipeline."""
    return _load_pipeline("vocoder", "phantomstep_vocoder")


def _encode_prompt(tokenizer, text_model, prompt):
    """Return the prompt embedding: umt5 encoder states mean-pooled over tokens."""
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {name: tensor.to(text_model.device) for name, tensor in inputs.items()}
    # AutoModel maps umt5 to the full encoder-decoder model, whose forward pass
    # also demands decoder inputs; only the encoder states are needed here.
    encoder = text_model.get_encoder() if hasattr(text_model, "get_encoder") else text_model
    with torch.no_grad():
        hidden_states = encoder(**inputs).last_hidden_state
    return hidden_states.mean(dim=1)


def _run_stage(pipeline, stage_input, duration):
    """Run one pipeline stage on *stage_input* and return its first audio output."""
    return pipeline(
        stage_input,
        num_inference_steps=_NUM_INFERENCE_STEPS,
        audio_length_in_s=duration,
    ).audios[0]


def _to_waveform(audio):
    """Convert stage output into a float32 numpy array that soundfile can write.

    soundfile only writes float64/float32/int32/int16 data, so half-precision
    output (the pipelines are loaded in float16) must be upcast, and tensors
    must be moved to host memory first.
    """
    if torch.is_tensor(audio):
        audio = audio.detach().float().cpu().numpy()
    waveform = np.asarray(audio, dtype=np.float32)
    # soundfile expects (frames, channels). A 2-D array with fewer rows than
    # columns is assumed to be (channels, samples) -- TODO confirm per pipeline.
    if waveform.ndim == 2 and waveform.shape[0] < waveform.shape[1]:
        waveform = waveform.T
    return waveform


def generate_music(prompt, duration=20, seed=42):
    """Generate music from a text prompt and return the path of the WAV file.

    Pipeline: umt5 prompt embedding -> transformer -> DCAE -> vocoder.

    Args:
        prompt: Free-text description of the desired music.
        duration: Requested length in seconds, forwarded to every stage as
            ``audio_length_in_s``.
        seed: Seed for ``torch.manual_seed``; ``None`` (an emptied Number
            field) falls back to 42.

    Returns:
        str: ``"output.wav"``. The same file is rewritten on every call, which
        assumes events are processed one at a time -- revisit if the Gradio
        queue concurrency is raised.
    """
    seed = 42 if seed is None else int(seed)
    logger.info(
        "Generating music with prompt: %s, duration: %s, seed: %s", prompt, duration, seed
    )
    torch.manual_seed(seed)

    # The loaders are cached, so only the first request pays the load cost.
    tokenizer, text_model = load_text_processor()
    transformer = load_transformer()
    dcae = load_dcae()
    vocoder = load_vocoder()

    logger.info("Processing text prompt to embeddings...")
    embeddings = _encode_prompt(tokenizer, text_model, prompt)

    logger.info("Generating with transformer...")
    features = _run_stage(transformer, embeddings, duration)

    logger.info("Decoding with DCAE...")
    decoded = _run_stage(dcae, features, duration)

    logger.info("Synthesizing with vocoder...")
    audio = _run_stage(vocoder, decoded, duration)

    sf.write(_OUTPUT_PATH, _to_waveform(audio), _SAMPLE_RATE)
    logger.info("Music generation complete.")
    return _OUTPUT_PATH


# Build the UI and start the server. Failures here are logged with their full
# traceback and then re-raised, so the process still exits with an error.
try:
    logger.info("Setting up Gradio interface...")
    with gr.Blocks(title="PhantomStep: Text-to-Music Generation 🎵") as demo:
        gr.Markdown("# PhantomStep by GhostAI 🚀")
        gr.Markdown("Enter a text prompt to generate music! 🎶")
        prompt_input = gr.Textbox(
            label="Text Prompt", placeholder="A jazzy piano melody with a fast tempo"
        )
        duration_input = gr.Slider(
            label="Duration (seconds)", minimum=10, maximum=60, value=20, step=1
        )
        seed_input = gr.Number(label="Random Seed", value=42, precision=0)
        generate_button = gr.Button("Generate Music")
        audio_output = gr.Audio(label="Generated Music")
        generate_button.click(
            fn=generate_music,
            inputs=[prompt_input, duration_input, seed_input],
            outputs=audio_output,
        )
    logger.info("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
except Exception as e:
    logger.exception("Failed to start the application: %s", e)
    raise