jadechoghari committed on
Commit 7e40a31 · verified · 1 Parent(s): 2879448

Update app.py

Files changed (1)
  1. app.py +121 -22
app.py CHANGED
@@ -1,55 +1,153 @@
 from typing import Tuple, Union
 import gradio as gr
-import os
-from PIL import Image
 import spaces
 
-CACHE_DIR = "gradio_cached_examples"
 
 
 def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
     cached_dir = os.path.join(CACHE_DIR, str(example_index))  # Use the example index to find the directory
     cached_image_path = os.path.join(cached_dir, "processed_image.png")
     cached_audio_path = os.path.join(cached_dir, "audio.wav")
 
     if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
         return cached_image_path, cached_audio_path
     else:
         raise FileNotFoundError(f"Cached outputs not found for example {example_index}")
 
-description_text = """# SEE-2-SOUND 🔊 Demo
 
 Official demo for *SEE-2-SOUND 🔊: Zero-Shot Spatial Environment-to-Spatial Sound*.
 """
 
 css = """
-h1 { text-align: center; }
 """
 
-@spaces.GPU
 with gr.Blocks(css=css) as demo:
     gr.Markdown(description_text)
 
     with gr.Row():
         with gr.Column():
-            image = gr.Image(label="Select an image", sources=["upload", "webcam"], type="filepath")
 
             with gr.Accordion("Advanced Settings", open=False):
-                steps = gr.Slider(label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500)
-                prompt = gr.Text(label="Prompt", max_lines=1, placeholder="Enter your prompt")
-                num_audios = gr.Slider(label="Number of Audios", minimum=1, maximum=10, step=1, value=3)
 
             submit_button = gr.Button("Submit")
 
         with gr.Column():
             processed_image = gr.Image(label="Processed Image")
-            generated_audio = gr.Audio(label="Generated Audio", show_download_button=True)
-
-
-    def on_example_click(*args, **kwargs):
-        return load_cached_example_outputs(1)  # Always load example 1 for now
-
-
 
     gr.Examples(
         examples=[["examples/1.png", 3, "A scenic mountain view", 500]],  # Example input
@@ -59,12 +157,13 @@ with gr.Blocks(css=css) as demo:
         fn=on_example_click  # Load the cached output when the example is clicked
     )
 
-
-    submit_button.click(
-        fn=on_example_click,
         inputs=[image, num_audios, prompt, steps],
-        outputs=[processed_image, generated_audio]
     )
 
 if __name__ == "__main__":
-    demo.launch()
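
One fix worth calling out: in the old revision, `@spaces.GPU` sat directly above the `with gr.Blocks(...)` statement, which Python rejects (a decorator must be followed by a `def` or `class`). The new revision applies the decorator to a dedicated inference handler instead. A minimal sketch of that pattern (hypothetical handler body, not the Space's actual code):

    import spaces
    import torch

    @spaces.GPU(duration=280)  # request a ZeroGPU slot for this call; duration taken from the new revision
    @torch.no_grad()           # inference only, no gradients needed
    def process_image(image, num_audios, prompt, steps):
        ...  # run the model here, as in the full listing below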
 
 
rishitdagli/see-2-sound · app.py · Update app.py (852e4aa, verified, about 14 hours ago) · 3.41 kB
 from typing import Tuple, Union
+
 import gradio as gr
+import numpy as np
+import see2sound
 import spaces
+import torch
+import yaml
+from huggingface_hub import snapshot_download
+
+model_id = "rishitdagli/see-2-sound"
+base_path = snapshot_download(repo_id=model_id)
 
+with open("config.yaml", "r") as file:
+    data = yaml.safe_load(file)
+data_str = yaml.dump(data)
+updated_data_str = data_str.replace("checkpoints", base_path)
+updated_data = yaml.safe_load(updated_data_str)
+with open("config.yaml", "w") as file:
+    yaml.safe_dump(updated_data, file)
 
+model = see2sound.See2Sound(config_path="config.yaml")
+model.setup()
 
+# for local cache
 def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
     cached_dir = os.path.join(CACHE_DIR, str(example_index))  # Use the example index to find the directory
     cached_image_path = os.path.join(cached_dir, "processed_image.png")
     cached_audio_path = os.path.join(cached_dir, "audio.wav")
 
+    # Ensure cached files exist
     if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
         return cached_image_path, cached_audio_path
     else:
         raise FileNotFoundError(f"Cached outputs not found for example {example_index}")
 
+# Function to handle the example click; it now accepts arbitrary arguments
+def on_example_click(*args, **kwargs):
+    return load_cached_example_outputs(1)  # Always load example 1 for now
+
+
+@spaces.GPU(duration=280)
+@torch.no_grad()
+def process_image(
+    image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
+) -> Tuple[str, str]:
+    model.run(
+        path=image,
+        output_path="audio.wav",
+        num_audios=num_audios,
+        prompt=prompt,
+        steps=steps,
+    )
+    return image, "audio.wav"
 
+
+description_text = """# SEE-2-SOUND 🔊 Demo
 Official demo for *SEE-2-SOUND 🔊: Zero-Shot Spatial Environment-to-Spatial Sound*.
+Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [github](https://github.com/see2sound/see2sound) for more details.
+> Note: You should make sure that your hardware supports spatial audio.
+This demo allows you to generate spatial audio given an image. Upload an image (with an optional text prompt in the advanced settings) to generate spatial audio to accompany the image.
 """
 
 css = """
+h1 {
+    text-align: center;
+}
 """
 
 with gr.Blocks(css=css) as demo:
     gr.Markdown(description_text)
 
     with gr.Row():
         with gr.Column():
+            image = gr.Image(
+                label="Select an image", sources=["upload", "webcam"], type="filepath"
+            )
 
             with gr.Accordion("Advanced Settings", open=False):
+                steps = gr.Slider(
+                    label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500
+                )
+                prompt = gr.Text(
+                    label="Prompt",
+                    show_label=True,
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                    container=True,
+                )
+                num_audios = gr.Slider(
+                    label="Number of Audios", minimum=1, maximum=10, step=1, value=3
+                )
 
             submit_button = gr.Button("Submit")
 
         with gr.Column():
             processed_image = gr.Image(label="Processed Image")
+            generated_audio = gr.Audio(
+                label="Generated Audio",
+                show_download_button=True,
+                show_share_button=True,
+                waveform_options=gr.WaveformOptions(
+                    waveform_color="#01C6FF",
+                    waveform_progress_color="#0066B4",
+                    show_controls=True,
+                ),
+            )
 
     gr.Examples(
         examples=[["examples/1.png", 3, "A scenic mountain view", 500]],  # Example input
@@ -59,12 +157,13 @@ with gr.Blocks(css=css) as demo:
         fn=on_example_click  # Load the cached output when the example is clicked
     )
 
+    gr.on(
+        triggers=[submit_button.click],
+        fn=process_image,
         inputs=[image, num_audios, prompt, steps],
+        outputs=[processed_image, generated_audio],
     )
 
 if __name__ == "__main__":
+    demo.launch()
+
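
`load_cached_example_outputs` is kept unchanged, but the diff shows `import os` and the `CACHE_DIR` constant being removed without a visible replacement, even though the function still uses both. Unless they are defined elsewhere in the new file, they would need to stay in scope for the cached examples to load. A minimal sketch, reusing the values from the previous revision:

    import os

    # Gradio's example cache: one subfolder per example index, as read by load_cached_example_outputs
    CACHE_DIR = "gradio_cached_examples"

    # e.g. gradio_cached_examples/1/processed_image.png and gradio_cached_examples/1/audio.wav
    image_path, audio_path = load_cached_example_outputs(1)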