annabeth97c committed (verified)
Commit db52c84 · Parent: 6733da5

Add caption chaining feature: caption the audio in 10-second chunks and chain the chunk captions with OpenAI when the input is longer than 10 seconds

Files changed (1): app.py (+82, -5)
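In short, the diff below splits any upload longer than 10 seconds into 10-second chunks, captions each chunk with the local SonicVerse model, and then chains the chunk captions into one full-song description through the OpenAI API. A minimal sketch of that flow, using the function names from the diff (the `sample.wav` path is hypothetical):

```python
# Sketch of the pipeline this commit adds; see the app.py diff below.
chunks = split_audio("sample.wav", 10)             # one file per 10 s of audio
captions = [generate_caption(c) for c in chunks]   # local SonicVerse captions
full = summarize_song("sample", captions) if len(captions) > 1 else captions[0]
```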
app.py CHANGED

@@ -18,11 +18,17 @@ import torch
 import transformers
 import torchaudio
 
+from openai import OpenAI
+client = OpenAI()
+MODEL = "gpt-4"
+SLEEP_BETWEEN_CALLS = 1.0
+
 from multi_token.model_utils import MultiTaskType
 from multi_token.training import ModelArguments
 from multi_token.inference import load_trained_lora_model
 from multi_token.data_tools import encode_chat
 
+CHUNK_LENGTH = 10  # seconds of audio per captioned chunk
 
 @dataclass
 class ServeArguments(ModelArguments):
@@ -31,7 +37,6 @@ class ServeArguments(ModelArguments):
     temperature: float = field(default=0.01)
 
 
-# Load arguments and model
 logging.getLogger().setLevel(logging.INFO)
 
 parser = transformers.HfArgumentParser((ServeArguments,))
@@ -45,10 +50,82 @@ model, tokenizer = load_trained_lora_model(
     tasks_config=serve_args.tasks_config
 )
 
+def caption_audio(audio_file):
+    chunk_audio_files = split_audio(audio_file, CHUNK_LENGTH)
+    chunk_captions = []
+    for audio_chunk in chunk_audio_files:
+        chunk_captions.append(generate_caption(audio_chunk))
+
+    if len(chunk_captions) > 1:
+        audio_name = os.path.splitext(os.path.basename(audio_file))[0]
+        long_caption = summarize_song(audio_name, chunk_captions)
+
+        delete_files(chunk_audio_files)  # remove temporary chunk files
+
+        return long_caption
+
+    else:
+        if len(chunk_captions) == 1:
+            return chunk_captions[0]
+        else:
+            return ""
+
+def summarize_song(song_name, chunks):
+    prompt = f"""
+    You are a music critic. Given the following chronological 10-second chunk descriptions of a single piece, write one flowing, detailed description of the entire song: its structure, instrumentation, and standout moments. Mention transition points in terms of time stamps. If the description of certain chunks does not seem to fit with those for the chunks before and after, treat those as bad descriptions with lower accuracy and do not incorporate the information. Retain concrete musical attributes such as key, chords, tempo.
+
+    Chunks for "{song_name}":
+    """
+    for i, c in enumerate(chunks, 1):
+        prompt += f"\n {(i - 1) * 10} to {i * 10} seconds. {c.strip()}"
+    prompt += "\n\nFull song description:"
+
+    resp = client.chat.completions.create(model=MODEL,
+        messages=[
+            {"role": "system", "content": "You are an expert music writer."},
+            {"role": "user", "content": prompt}
+        ],
+        temperature=0.0,
+        max_tokens=1000)
+    return resp.choices[0].message.content.strip()
+
+def delete_files(file_paths):
+    for path in file_paths:
+        try:
+            if os.path.isfile(path):
+                os.remove(path)
+                print(f"Deleted: {path}")
+            else:
+                print(f"Skipped (not a file or doesn't exist): {path}")
+        except Exception as e:
+            print(f"Error deleting {path}: {e}")
+
+def split_audio(input_path, chunk_length_seconds):
+    # Slice the waveform into fixed-length chunks; short inputs pass through.
+    waveform, sample_rate = torchaudio.load(input_path)
+    num_channels, total_samples = waveform.shape
+    chunk_samples = int(chunk_length_seconds * sample_rate)
+
+    num_chunks = (total_samples + chunk_samples - 1) // chunk_samples  # ceil
+
+    base, ext = os.path.splitext(input_path)
+    output_paths = []
+
+    if num_chunks <= 1:
+        return [input_path]
+
+    for i in range(num_chunks):
+        start = i * chunk_samples
+        end = min((i + 1) * chunk_samples, total_samples)
+        chunk_waveform = waveform[:, start:end]
+
+        output_file = f"{base}_{i+1:03d}{ext}"
+        torchaudio.save(output_file, chunk_waveform, sample_rate)
+        output_paths.append(output_file)
+
+    return output_paths
 
 def generate_caption(audio_file):
-    # waveform, sample_rate = torchaudio.load(audio_file)
-
     req_json = {
         "messages": [
             {"role": "user", "content": "Describe the music. <sound>"}
@@ -79,7 +156,7 @@ def generate_caption(audio_file):
 
 
 demo = gr.Interface(
-    fn=generate_caption,
+    fn=caption_audio,
     inputs=gr.Audio(type="filepath", label="Upload an audio file"),
     outputs=gr.Textbox(label="Generated Caption"),
     title="SonicVerse",
@@ -87,4 +164,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
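Note that `client = OpenAI()` reads its credentials from the environment, so the Space needs an `OPENAI_API_KEY` secret (or environment variable) set before `summarize_song` can succeed. A minimal local smoke test might look like the following, where `example.mp3` is a hypothetical input file:

```python
# Hypothetical smoke test for the caption chaining path, run in app.py's
# environment; the key placeholder must be replaced with a real secret.
import os
os.environ.setdefault("OPENAI_API_KEY", "<your-key>")
print(caption_audio("example.mp3"))  # inputs over 10 s exercise the chaining
```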