Serhiy Stetskovych commited on
Commit
1b8633f
·
1 Parent(s): ea8d6db

Fix model select

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -47,12 +47,18 @@ texts
47
  apollo_config = get_config('configs/apollo.yaml')
48
  apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
49
 
50
- models = [
51
- ('MP3 restore', apollo_model)
 
 
 
 
52
  ]
53
 
54
  @spaces.GPU
55
- def enchance(model, audio):
 
 
56
  test_data, samplerate = load_audio(audio)
57
  C = 10 * samplerate # chunk_size seconds to samples
58
  N = 2
@@ -122,7 +128,7 @@ if __name__ == "__main__":
122
  fn=enchance,
123
  description=description,
124
  inputs=[
125
- gr.Dropdown(label="Model", choices=models, value=models[0]),
126
  gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
127
  ],
128
  outputs=[
 
47
  apollo_config = get_config('configs/apollo.yaml')
48
  apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
49
 
50
+ models = {
51
+ 'apollo': apollo_model
52
+ }
53
+
54
+ choices = [
55
+ ('MP3 restore', 'apollo')
56
  ]
57
 
58
  @spaces.GPU
59
+ def enchance(choice, audio):
60
+ print(choice)
61
+ model = models[choice]
62
  test_data, samplerate = load_audio(audio)
63
  C = 10 * samplerate # chunk_size seconds to samples
64
  N = 2
 
128
  fn=enchance,
129
  description=description,
130
  inputs=[
131
+ gr.Dropdown(label="Model", choices=choices, value=choices[0]),
132
  gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
133
  ],
134
  outputs=[