LenDigLearn committed
Commit 031ecb9 · 1 Parent(s): f4455e9

threading fix

Files changed (1):
  app.py +10 -3
app.py CHANGED
@@ -1,6 +1,7 @@
 import queue
 import gradio as gr
 import torch
+import threading
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 """
@@ -14,10 +15,14 @@ model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloa
 class CustomIterable:
     def __init__(self):
         self._queue = queue.Queue()  # Thread-safe queue
+        self.first = True
 
     def put(self, item):
         """Add an element to the internal queue."""
-        self._queue.put(item)
+        if self.first:
+            self.first = False
+        else:
+            self._queue.put(item)
 
     def end(self):
         """Signal that no more elements will be added."""
@@ -61,8 +66,8 @@ def respond(
     streamer = CustomIterable()
 
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True)
-    outputs = model.generate(inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, streamer=streamer)
-
+    thread = threading.Thread(target=model.generate, args=([inputs]), kwargs={"max_new_tokens": max_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "repetition_penalty": repetition_penalty, "streamer": streamer})
+    thread.start()
     response = ""
 
     for token in streamer:
@@ -70,6 +75,8 @@ def respond(
         response += decoded
         yield response
 
+    thread.join()
+
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """