WYBar committed on
Commit
d835c19
·
1 Parent(s): 907ad49

.to in the @spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -362,13 +362,15 @@ def construction_all():
362
  @spaces.GPU(duration=120)
363
  def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
364
  print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
 
 
365
  json_example = inputs
366
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
367
 
368
  print("tokenizer1")
369
  inputs = tokenizer(
370
  input_intension, return_tensors="pt"
371
- ).to(model.lm.device)
372
  print("Input IDs device:", inputs["input_ids"].device)
373
  print("Attention Mask device:", inputs["attention_mask"].device)
374
  print("tokenizer2")
@@ -412,6 +414,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
412
  max_try_time = 5
413
  preddata = None
414
  while preddata is None and max_try_time > 0:
 
415
  preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
416
  max_try_time -= 1
417
  else:
 
362
  @spaces.GPU(duration=120)
363
  def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
364
  print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
365
+ model = model.to("cuda")
366
+ print(f"after evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
367
  json_example = inputs
368
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
369
 
370
  print("tokenizer1")
371
  inputs = tokenizer(
372
  input_intension, return_tensors="pt"
373
+ ).to("cuda")
374
  print("Input IDs device:", inputs["input_ids"].device)
375
  print("Attention Mask device:", inputs["attention_mask"].device)
376
  print("tokenizer2")
 
414
  max_try_time = 5
415
  preddata = None
416
  while preddata is None and max_try_time > 0:
417
+ print(f"inference {model.device} {model.lm.device} {pipeline.device}")
418
  preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
419
  max_try_time -= 1
420
  else: