LenDigLearn committed
Commit f4455e9
1 Parent(s): 87fc746

experiment with own streaming

Files changed (2)
  1. app.py +52 -9
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,12 +1,44 @@
+import queue
 import gradio as gr
-from huggingface_hub import InferenceClient
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-client = InferenceClient("LemiSt/SmolLM-135M-instruct-de-merged")
+checkpoint = "LemiSt/SmolLM-135M-instruct-de-merged"
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
 
 
+class CustomIterable:
+    def __init__(self):
+        self._queue = queue.Queue()  # Thread-safe queue
+
+    def put(self, item):
+        """Add an element to the internal queue."""
+        self._queue.put(item)
+
+    def end(self):
+        """Signal that no more elements will be added."""
+        self._queue.put(None)  # Sentinel value to indicate the end of the queue
+
+    def __iter__(self):
+        """Return the iterator (self in this case)."""
+        return self
+
+    def __next__(self):
+        """Return the next element from the queue, blocking if necessary."""
+        try:
+            item = self._queue.get(block=True)  # Wait for an item
+        except queue.Empty:
+            raise StopIteration
+
+        if item is None:  # Sentinel value to end the iteration
+            raise StopIteration
+
+        return item
+
 def respond(
     message,
     history: list[tuple[str, str]],
@@ -14,6 +46,7 @@ def respond(
     max_tokens,
     temperature,
     top_p,
+    repetition_penalty
 ):
     messages = [{"role": "system", "content": system_message}]
 
@@ -25,14 +58,17 @@
 
     messages.append({"role": "user", "content": message})
 
-    response = client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=False,
-        temperature=temperature,
-        top_p=top_p).choices[0].message.content
+    streamer = CustomIterable()
 
-    yield response
+    inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True)
+    outputs = model.generate(inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, streamer=streamer)
+
+    response = ""
+
+    for token in streamer:
+        decoded = tokenizer.decode(token, skip_special_tokens=True)
+        response += decoded
+        yield response
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -50,6 +86,13 @@ demo = gr.ChatInterface(
             step=0.05,
             label="Top-p (nucleus sampling)",
         ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=2.0,
+            value=1.2,
+            step=0.05,
+            label="Repetition penalty",
+        ),
     ],
 )
 
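A note on the streaming experiment: `CustomIterable` implements exactly the two methods transformers' streamer protocol calls, `put()` (invoked once with the prompt ids, then with each newly sampled token tensor) and `end()` (invoked when generation finishes). As committed, however, `model.generate(...)` runs to completion on the calling thread before the `for token in streamer:` loop begins, so the queue is only drained after generation is over. The sketch below is not part of the commit; it assumes the `CustomIterable`, `tokenizer`, and `model` defined above (the helper name `stream_reply` is hypothetical) and shows how a background thread would make this stream token by token:

from threading import Thread

def stream_reply(messages, max_tokens, temperature, top_p, repetition_penalty):
    # Hypothetical helper, not part of the commit; it reuses CustomIterable,
    # tokenizer and model exactly as defined in app.py above.
    streamer = CustomIterable()
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt", add_generation_prompt=True
    )
    # generate() calls streamer.put() with the prompt ids first, then with each
    # sampled token tensor, and streamer.end() when finished. Running it on a
    # background thread lets the loop below drain the queue concurrently.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    ))
    thread.start()

    response = ""
    first = True
    for token in streamer:
        if first:  # the first queued item is the full prompt; skip it
            first = False
            continue
        response += tokenizer.decode(token, skip_special_tokens=True)
        yield response
    thread.join()

transformers also ships `TextIteratorStreamer`, which packages the same queue-plus-sentinel pattern behind this protocol and can drop the prompt via `skip_prompt=True`.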
requirements.txt CHANGED
@@ -1,2 +1,5 @@
 huggingface_hub==0.23.2
 minijinja==2.2.0
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.4.1
+transformers==4.45.2
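The `--extra-index-url` line tells pip to also consult PyTorch's CPU-only wheel index, so the pinned `torch==2.4.1` resolves to a build without CUDA dependencies, which keeps the install small and matches a CPU-only Space. pip reads the flag straight from the file, so a plain `pip install -r requirements.txt` picks it up.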