asigalov61 commited on
Commit
3eaad6d
·
verified ·
1 Parent(s): 20bc1d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -111,7 +111,7 @@ def load_model(model_selector):
111
  print('Model will use', dtype, 'precision...')
112
  print('=' * 70)
113
 
114
- return model
115
 
116
  #==================================================================================
117
 
@@ -207,7 +207,8 @@ def generate_music(prime,
207
  else:
208
  inputs = prime[-num_mem_tokens:]
209
 
210
- model = model_state
 
211
 
212
  model.cuda()
213
  model.eval()
@@ -219,7 +220,6 @@ def generate_music(prime,
219
  inp = torch.LongTensor(inp).cuda()
220
 
221
  with ctx:
222
- with torch.inference_mode():
223
  out = model.generate(inp,
224
  num_gen_tokens,
225
  #filter_logits_fn=top_p,
 
111
  print('Model will use', dtype, 'precision...')
112
  print('=' * 70)
113
 
114
+ return [model, ctx]
115
 
116
  #==================================================================================
117
 
 
207
  else:
208
  inputs = prime[-num_mem_tokens:]
209
 
210
+ model = model_state[0]
211
+ ctx = model_state[1]
212
 
213
  model.cuda()
214
  model.eval()
 
220
  inp = torch.LongTensor(inp).cuda()
221
 
222
  with ctx:
 
223
  out = model.generate(inp,
224
  num_gen_tokens,
225
  #filter_logits_fn=top_p,