zhanghanxiao committed on
Commit
de8bf82
·
verified ·
1 Parent(s): f308a75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -1,8 +1,6 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
2
  from threading import Thread
3
  import gradio as gr
4
- import re
5
- import torch
6
 
7
  # load model and tokenizer
8
  model_name = "inclusionAI/Ling-mini-2.0"
@@ -38,18 +36,28 @@ def respond(
38
  tokenize=False,
39
  add_generation_prompt=True
40
  )
 
 
 
41
  model_inputs = tokenizer([text], return_tensors="pt", return_token_type_ids=False).to(model.device)
42
 
43
- generated_ids = model.generate(
44
- **model_inputs,
45
- max_new_tokens=512
 
 
 
46
  )
47
- generated_ids = [
48
- output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
49
- ]
 
 
 
 
50
 
51
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
52
- yield response
53
 
54
 
55
  """
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
2
  from threading import Thread
3
  import gradio as gr
 
 
4
 
5
  # load model and tokenizer
6
  model_name = "inclusionAI/Ling-mini-2.0"
 
36
  tokenize=False,
37
  add_generation_prompt=True
38
  )
39
+
40
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
41
+
42
  model_inputs = tokenizer([text], return_tensors="pt", return_token_type_ids=False).to(model.device)
43
 
44
+ model_inputs.update(dict(max_new_tokens=512,streamer=streamer))
45
+
46
+ # Start a separate thread for model generation to allow streaming output
47
+ thread = Thread(
48
+ target=model.generate,
49
+ kwargs=model_inputs,
50
  )
51
+ thread.start()
52
+
53
+ # Accumulate and yield text tokens as they are generated
54
+ acc_text = ""
55
+ for text_token in streamer:
56
+ acc_text += text_token # Append the generated token to the accumulated text
57
+ yield acc_text # Yield the accumulated text
58
 
59
+ # Ensure the generation thread completes
60
+ thread.join()
61
 
62
 
63
  """