masanorihirano commited on
Commit
53f0151
·
verified ·
1 Parent(s): ac70489

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -19
app.py CHANGED
@@ -7,11 +7,12 @@ from transformers import (
7
  StoppingCriteria,
8
  StoppingCriteriaList,
9
  )
10
- from threading import Thread
 
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
13
 
14
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
15
  pipeline = transformers.pipeline(
16
  "text-generation",
17
  model="pfnet/plamo-2-1b",
@@ -35,10 +36,56 @@ class StoppingCriteriaSub(StoppingCriteria):
35
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])
36
 
37
 
38
- async def respond(prompt, max_tokens):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  print(prompt)
40
 
41
- thread = Thread(
 
 
 
 
42
  target=pipeline,
43
  kwargs=dict(
44
  text_inputs=prompt,
@@ -50,24 +97,21 @@ async def respond(prompt, max_tokens):
50
  eos_token_id=tokenizer.eos_token_id,
51
  stopping_criteria=stopping_criteria,
52
  ),
53
- daemon=True,
54
  )
55
- thread.start()
56
-
57
  response = ""
58
 
59
- for output in streamer:
60
- if not output:
61
- await asyncio.sleep(0)
62
- continue
63
- print(output)
64
- response += output
65
- yield response, gr.update(interactive=False), gr.update(interactive=False),
66
- yield (
67
- response,
68
- gr.update(interactive=True),
69
- gr.update(interactive=True),
70
- )
71
 
72
 
73
  def reset_textbox():
 
7
  StoppingCriteria,
8
  StoppingCriteriaList,
9
  )
10
+ import threading
11
+
12
+ import ctypes
13
 
14
  tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
15
 
 
16
  pipeline = transformers.pipeline(
17
  "text-generation",
18
  model="pfnet/plamo-2-1b",
 
36
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])
37
 
38
 
39
+ class CancelableThread(threading.Thread):
40
+ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}):
41
+ threading.Thread.__init__(self, group=group, target=target, name=name)
42
+ self.args = args
43
+ self.kwargs = kwargs
44
+ return
45
+
46
+ def run(self):
47
+ self.id = threading.get_native_id()
48
+ self._target(*self.args, **self.kwargs)
49
+
50
+ def get_id(self):
51
+ return self.id
52
+
53
+ def raise_exception(self):
54
+ thread_id = self.get_id()
55
+ resu = ctypes.pythonapi.PyThreadState_SetAsyncExc(
56
+ ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
57
+ )
58
+ if resu > 1:
59
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
60
+ print("Failure in raising exception")
61
+
62
+
63
+ class ThreadManager:
64
+ def __init__(self, thread: CancelableThread, **kwargs):
65
+ self.thread = thread
66
+
67
+ def __enter__(self):
68
+ # スレッドを開始
69
+ self.thread.start()
70
+ return self.thread
71
+
72
+ def __exit__(self, exc_type, exc_value, traceback):
73
+ # スレッドの終了を待機
74
+ if self.thread.is_alive():
75
+ print("trying to terminate thread")
76
+ self.thread.raise_exception()
77
+ self.thread.join()
78
+ print("Thread has been successfully joined.")
79
+
80
+
81
+ def respond(prompt, max_tokens):
82
  print(prompt)
83
 
84
+ streamer = TextIteratorStreamer(
85
+ tokenizer, skip_prompt=True, skip_special_tokens=True
86
+ )
87
+
88
+ thread = CancelableThread(
89
  target=pipeline,
90
  kwargs=dict(
91
  text_inputs=prompt,
 
97
  eos_token_id=tokenizer.eos_token_id,
98
  stopping_criteria=stopping_criteria,
99
  ),
 
100
  )
 
 
101
  response = ""
102
 
103
+ with ThreadManager(thread=thread):
104
+ for output in streamer:
105
+ if not output:
106
+ continue
107
+ print(output)
108
+ response += output
109
+ yield response, gr.update(interactive=False), gr.update(interactive=False),
110
+ yield (
111
+ response,
112
+ gr.update(interactive=True),
113
+ gr.update(interactive=True),
114
+ )
115
 
116
 
117
  def reset_textbox():