Add caption chaining feature, captioning 10-second chunks and chaining them with OpenAI if the input is longer than 10 seconds
app.py CHANGED
@@ -18,11 +18,17 @@ import torch
 import transformers
 import torchaudio
 
+from openai import OpenAI
+client = OpenAI()
+MODEL = "gpt-4"
+SLEEP_BETWEEN_CALLS = 1.0
+
 from multi_token.model_utils import MultiTaskType
 from multi_token.training import ModelArguments
 from multi_token.inference import load_trained_lora_model
 from multi_token.data_tools import encode_chat
 
+CHUNK_LENGTH = 10
 
 @dataclass
 class ServeArguments(ModelArguments):
@@ -31,7 +37,6 @@ class ServeArguments(ModelArguments):
     temperature: float = field(default=0.01)
 
 
-# Load arguments and model
 logging.getLogger().setLevel(logging.INFO)
 
 parser = transformers.HfArgumentParser((ServeArguments,))
@@ -45,10 +50,82 @@ model, tokenizer = load_trained_lora_model(
     tasks_config=serve_args.tasks_config
 )
 
+def caption_audio(audio_file):
+    chunk_audio_files = split_audio(audio_file, CHUNK_LENGTH)
+    chunk_captions = []
+    for audio_chunk in chunk_audio_files:
+        chunk_captions.append(generate_caption(audio_chunk))
+
+    if len(chunk_captions) > 1:
+        audio_name = os.path.splitext(os.path.basename(audio_file))[0]
+        long_caption = summarize_song(audio_name, chunk_captions)
+
+        delete_files(chunk_audio_files)
+
+        return long_caption
+
+    else:
+        if len(chunk_captions) == 1:
+            return chunk_captions[0]
+        else:
+            return ""
+
+def summarize_song(song_name, chunks):
+    prompt = f"""
+    You are a music critic. Given the following chronological 10-second chunk descriptions of a single piece, write one flowing, detailed description of the entire song—its structure, instrumentation, and standout moments. Mention transition points in terms of time stamps. If the description of certain chunks does not seem to fit with those for the chunks before and after, treat those as bad descriptions with lower accuracy and do not incorporate the information. Retain concrete musical attributes such as key, chords, tempo.
+
+    Chunks for "{song_name}":
+    """
+    for i, c in enumerate(chunks, 1):
+        prompt += f"\n {(i - 1)*10} to {i*10} seconds. {c.strip()}"
+    prompt += "\n\nFull song description:"
+
+    resp = client.chat.completions.create(model=MODEL,
+        messages=[
+            {"role": "system", "content": "You are an expert music writer."},
+            {"role": "user", "content": prompt}
+        ],
+        temperature=0.0,
+        max_tokens=1000)
+    return resp.choices[0].message.content.strip()
+
+def delete_files(file_paths):
+    for path in file_paths:
+        try:
+            if os.path.isfile(path):
+                os.remove(path)
+                print(f"Deleted: {path}")
+            else:
+                print(f"Skipped (not a file or doesn't exist): {path}")
+        except Exception as e:
+            print(f"Error deleting {path}: {e}")
+
+def split_audio(input_path, chunk_length_seconds):
+
+    waveform, sample_rate = torchaudio.load(input_path)
+    num_channels, total_samples = waveform.shape
+    chunk_samples = int(chunk_length_seconds * sample_rate)
+
+    num_chunks = (total_samples + chunk_samples - 1) // chunk_samples
+
+    base, ext = os.path.splitext(input_path)
+    output_paths = []
+
+    if num_chunks <= 1:
+        return [input_path]
+
+    for i in range(num_chunks):
+        start = i * chunk_samples
+        end = min((i + 1) * chunk_samples, total_samples)
+        chunk_waveform = waveform[:, start:end]
+
+        output_file = f"{base}_{i+1:03d}{ext}"
+        torchaudio.save(output_file, chunk_waveform, sample_rate)
+        output_paths.append(output_file)
+
+    return output_paths
 
 def generate_caption(audio_file):
-    # waveform, sample_rate = torchaudio.load(audio_file)
-
     req_json = {
         "messages": [
             {"role": "user", "content": "Describe the music. <sound>"}
@@ -79,7 +156,7 @@ def generate_caption(audio_file):
 
 
 demo = gr.Interface(
-    fn=
+    fn=caption_audio,
     inputs=gr.Audio(type="filepath", label="Upload an audio file"),
     outputs=gr.Textbox(label="Generated Caption"),
     title="SonicVerse",
@@ -87,4 +164,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()