asigalov61 committed: Update app.py
app.py
CHANGED
@@ -111,7 +111,7 @@ def load_model(model_selector):
     print('Model will use', dtype, 'precision...')
     print('=' * 70)
 
-    return model
+    return [model, ctx]
 
 #==================================================================================
 
@@ -207,7 +207,8 @@ def generate_music(prime,
     else:
         inputs = prime[-num_mem_tokens:]
 
-    model = model_state
+    model = model_state[0]
+    ctx = model_state[1]
 
     model.cuda()
     model.eval()
@@ -219,7 +220,6 @@ def generate_music(prime,
     inp = torch.LongTensor(inp).cuda()
 
     with ctx:
-        with torch.inference_mode():
         out = model.generate(inp,
                              num_gen_tokens,
                              #filter_logits_fn=top_p,
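For orientation, the minimal sketch below shows the pattern this commit adopts: load_model packs the model together with its precision context into a list, and generate_music unpacks that list from the shared state before running generation under the saved ctx (the nested torch.inference_mode() block is gone). Only the names taken from the diff (load_model, generate_music, model_state, ctx) are from the commit; DummyModel, the bfloat16 autocast context, and the 512-token trim are placeholders, not the Space's actual code.

import torch

class DummyModel(torch.nn.Module):
    # Hypothetical stand-in so the sketch runs without the real transformer.
    def generate(self, inp, num_gen_tokens):
        # Append num_gen_tokens dummy tokens; the real model samples here.
        pad = torch.zeros(num_gen_tokens, dtype=inp.dtype, device=inp.device)
        return torch.cat([inp, pad])

def load_model(model_selector='base'):
    model = DummyModel()
    dtype = torch.bfloat16  # assumption; the real dtype depends on model_selector
    print('Model will use', dtype, 'precision...')
    print('=' * 70)
    # Assumption: ctx is an autocast context built alongside the model so
    # generation later reuses the same precision settings.
    ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
    # The commit returns both objects so the state carries them together.
    return [model, ctx]

def generate_music(prime, num_gen_tokens, model_state):
    model = model_state[0]  # unpack exactly as the diff does
    ctx = model_state[1]
    model.cuda()            # requires a GPU, as on the Space
    model.eval()
    inp = torch.LongTensor(prime[-512:]).cuda()  # 512 stands in for num_mem_tokens
    with ctx:               # generation now runs only under the saved ctx
        out = model.generate(inp, num_gen_tokens)
    return out.tolist()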