from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import soundfile as sf
import numpy as np
import os
from io import BytesIO
import base64
import spaces
# Model and tokenizer loading
MODEL_ID = "Qwen/Qwen-Audio-Chat"

# Cache the model and tokenizer at module level so the weights are loaded
# once per process instead of on every request.
_model = None
_tokenizer = None

def load_model():
    global _model, _tokenizer
    if _model is None:
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        # Qwen-Audio-Chat ships its own dialogue formatting in its remote
        # code, so no custom chat template needs to be attached here.
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    return _model, _tokenizer
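# Optional sanity check, a sketch assuming enough GPU memory for the fp16
# checkpoint (the exact class name comes from the repo's remote code):
#   model, tokenizer = load_model()
#   print(type(model).__name__)  # Qwen-specific class, not a generic CausalLM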
def process_audio(audio_path):
    """Validate and normalize an audio file; returns a base64 WAV payload
    plus its sample rate, or None if the file cannot be read."""
    try:
        audio_data, sample_rate = sf.read(audio_path)
        # Mix multi-channel audio down to mono.
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        audio_data = audio_data.astype(np.float32)
        # Re-encode as WAV in memory and base64-encode the bytes.
        audio_buffer = BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        audio_buffer.seek(0)
        audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
        return {
            "audio": audio_base64,
            "sampling_rate": sample_rate
        }
    except Exception:
        # Unreadable or unsupported file; the caller reports the failure.
        return None
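# Example of the returned payload (file name "clip.wav" is illustrative):
#   info = process_audio("clip.wav")
#   info["sampling_rate"]   # e.g. 44100, as read from the file header
#   len(info["audio"])      # size of the base64-encoded WAV string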
@spaces.GPU  # on ZeroGPU Spaces, attaches a GPU for the duration of the call
def analyze_audio(audio_path: str, question: str = "") -> str:
    if audio_path is None or not isinstance(audio_path, str):
        return "Please provide a valid audio file."
    if not os.path.exists(audio_path):
        return f"Audio file not found: {audio_path}"
    # Check that the file is actually readable audio before invoking the model.
    audio_data = process_audio(audio_path)
    if not audio_data or "audio" not in audio_data or "sampling_rate" not in audio_data:
        return "Failed to process the audio file. Please ensure it's a valid audio format."
    try:
        model, tokenizer = load_model()
        query_text = question if question else "Please describe what you hear in this audio clip."
        # Build a query that actually contains the audio. Per the
        # Qwen-Audio-Chat model card, tokenizer.from_list_format() interleaves
        # audio and text, and model.chat() runs one dialogue turn; the remote
        # code loads the audio itself from the given path.
        query = tokenizer.from_list_format([
            {"audio": audio_path},
            {"text": query_text},
        ])
        with torch.no_grad():
            response, _history = model.chat(tokenizer, query=query, history=None)
        if not response:
            return "The model failed to generate a response. Please try again."
        return response
    except Exception:
        return "An error occurred while processing. Please check your inputs and try again."
demo = gr.Interface(
    fn=analyze_audio,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Audio Input",
            sources=["upload", "microphone"],
            format="mp3"
        ),
        gr.Textbox(
            label="Question",
            placeholder="Optional: Ask a specific question about the audio",
            value=""
        )
    ],
    outputs=gr.Textbox(label="Analysis"),
    title="Qwen Audio Analysis Tool",
    description="Upload an audio file or record from the microphone to get AI-powered analysis from the Qwen-Audio-Chat model.",
    examples=[
        ["example1.wav", "What instruments do you hear?"]
    ],
    cache_examples=False
)
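# Optional: demo.queue() would serialize concurrent requests, which helps
# when a single GPU-backed model instance serves every user.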
if __name__ == "__main__":
    demo.launch()
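# When running locally, demo.launch(share=True) exposes a temporary public
# URL; on Hugging Face Spaces, the plain launch() above is sufficient.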