import gradio as gr
import spaces
import soundfile as sf
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model and Tokenizer Loading
MODEL_ID = "Qwen/Qwen-Audio-Chat"
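# trust_remote_code pulls in Qwen's custom modeling/tokenizer code, which
# provides the from_list_format() and chat() helpers used in qwen_inference.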
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

DESCRIPTION = "[Qwen-Audio-Chat Demo](https://huggingface.co/Qwen/Qwen-Audio-Chat)"

audio_extensions = (".wav", ".mp3", ".ogg", ".flac")
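# Upload formats accepted by this demo; checked before any decoding is attempted.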

def process_audio(audio_path):
    """Read an audio file and return mono samples plus the sample rate.

    Also serves to verify that the upload is decodable before the path
    is handed to the model.
    """
    audio_data, sample_rate = sf.read(audio_path)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)  # downmix multichannel to mono
    return audio_data, sample_rate

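# On Hugging Face Spaces (ZeroGPU), @spaces.GPU attaches a GPU to this call
# for its duration; on dedicated GPU hardware the decorator is a no-op.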
@spaces.GPU
def qwen_inference(audio_input, text_input=None):
    if not isinstance(audio_input, str) or not audio_input.lower().endswith(audio_extensions):
        raise ValueError("Please upload a valid audio file (WAV, MP3, OGG, or FLAC)")

    # Decode once as a sanity check; the model reloads the audio from the path.
    process_audio(audio_input)

    # Fall back to a general description prompt when no question is given.
    if text_input and text_input.strip():
        query_text = text_input.strip()
    else:
        query_text = "Please describe what you hear in this audio clip."

    # Qwen-Audio-Chat does not use the generic chat template / generate path;
    # its remote code exposes from_list_format(), which embeds the audio file
    # path into the prompt, and chat(), which handles tokenization, generation,
    # and decoding in one call.
    query = tokenizer.from_list_format([
        {"audio": audio_input},
        {"text": query_text},
    ])
    response, _history = model.chat(tokenizer, query=query, history=None)
    return response

css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Audio Input"):
        with gr.Row():
            with gr.Column():
                input_audio = gr.Audio(
                    label="Upload Audio",
                    type="filepath"
                )
                text_input = gr.Textbox(
                    label="Question (optional)",
                    placeholder="Ask a question about the audio or leave empty for general description"
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Output Text",
                    elem_id="output",  # targeted by the #output rule in the CSS above
                )

        submit_btn.click(
            qwen_inference,
            [input_audio, text_input],
            [output_text]
        )

demo.launch(debug=True)