Spaces · Runtime error
Commit e77fc2d · Parent(s): 686c4ae
Update app.py
app.py CHANGED

@@ -20,7 +20,6 @@ import torchvision.transforms as transforms
 import av
 import subprocess
 import librosa
-import re
 
 args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
         "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -34,6 +33,8 @@ class dotdict(dict):
 
 args = dotdict(args)
 
+generated_audio_files = []
+
 llama_type = args.llama_type
 llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
 llama_tokenzier_path = args.llama_dir
@@ -117,6 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
 
 
 def save_audio_to_local(audio, sec):
+    global generated_audio_files
     if not os.path.exists('temp'):
         os.mkdir('temp')
     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
@@ -124,6 +126,7 @@ def save_audio_to_local(audio, sec):
         scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
     else:
         scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
+    generated_audio_files.append(filename)
     return filename
 
 
@@ -159,10 +162,14 @@ def reset_user_input():
 
 
 def reset_dialog():
+    global generated_audio_files
+    generated_audio_files = []
     return [], []
 
 
 def reset_state():
+    global generated_audio_files
+    generated_audio_files = []
     return None, None, None, None, [], [], []
 
 
@@ -209,12 +216,6 @@ def get_video_length(filename):
 def get_audio_length(filename):
     return int(round(librosa.get_duration(path=filename)))
 
-def get_last_audio():
-    for hist in history[::-1]:
-        print(hist)
-        if "<audio controls playsinline>" in hist[1]:
-            return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
-    return None
 
 def predict(
         prompt_input,
@@ -227,6 +228,7 @@ def predict(
         history,
         modality_cache,
         audio_length_in_s):
+    global generated_audio_files
     prompts = [llama.format_prompt(prompt_input)]
     prompts = [model.tokenizer(x).input_ids for x in prompts]
     print(image_path, audio_path, video_path)
@@ -244,11 +246,11 @@ def predict(
         container = av.open(video_path)
         indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
         video = read_video_pyav(container=container, indices=indices)
-
-    if
-        audio_length_in_s = get_audio_length(
+
+    if len(generated_audio_files) != 0:
+        audio_length_in_s = get_audio_length(generated_audio_files[-1])
         sample_rate = 24000
-        waveform, sr = torchaudio.load(
+        waveform, sr = torchaudio.load(generated_audio_files[-1])
         if sample_rate != sr:
             waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
         audio = torch.mean(waveform, 0)
@@ -259,6 +261,7 @@ def predict(
         print(f"Video Length: {audio_length_in_s}")
     if audio_path is not None:
         audio_length_in_s = get_audio_length(audio_path)
+        generated_audio_files.append(audio_path)
         print(f"Audio Length: {audio_length_in_s}")
 
     print(image, video, audio)
@@ -350,4 +353,4 @@ with gr.Blocks() as demo:
     ], show_progress=True)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
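
The substance of the commit: instead of recovering the most recent audio file by regex-matching the <audio> tag markup embedded in the chat history (the removed get_last_audio helper, and with it the re import), the app now records every saved or uploaded audio path in a module-level generated_audio_files list, appending in save_audio_to_local and when an audio upload reaches predict, and clearing it in reset_dialog and reset_state. A minimal sketch of that pattern with the model and Gradio plumbing stubbed out; generated_audio_files, save_audio_to_local, and reset_state are names from the diff, while last_audio and the bytes-based writer are illustrative:

import os
import tempfile

generated_audio_files = []  # paths of generated/uploaded audio, newest last


def save_audio_to_local(audio_bytes):
    # Write audio to a temp .wav and record its path; the real app writes
    # model output with scipy.io.wavfile.write, a plain byte dump stands in here.
    global generated_audio_files
    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
    with open(filename, 'wb') as f:
        f.write(audio_bytes)
    generated_audio_files.append(filename)
    return filename


def reset_state():
    # Clearing the list on reset keeps stale audio out of later turns.
    global generated_audio_files
    generated_audio_files = []


def last_audio():
    # Illustrative accessor: the constant-time replacement for the removed
    # regex-based get_last_audio().
    return generated_audio_files[-1] if generated_audio_files else None

Reading generated_audio_files[-1] is a constant-time lookup and drops the brittle coupling to the exact HTML the demo renders into the chat transcript.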
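
For context, the predict branch that now consumes generated_audio_files[-1] loads the file, resamples it to a fixed 24 kHz, and downmixes it to mono (those surrounding lines are unchanged context in the diff). A self-contained sketch of just that step, assuming standard torchaudio calls; load_audio_24k_mono is an illustrative name, not one from the app:

import torch
import torchaudio


def load_audio_24k_mono(path, target_sr=24000):
    # torchaudio.load returns a (channels, samples) tensor and the
    # file's native sample rate.
    waveform, sr = torchaudio.load(path)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
    # Average across channels to downmix to mono, as predict() does.
    return torch.mean(waveform, 0)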