AC2513 committed
Commit 7641a99 · 1 Parent(s): e4b23f9

added token processing

Files changed (1)
  1. src/app.py +45 -7
src/app.py CHANGED
@@ -94,13 +94,13 @@ def process_user_input(message: dict, max_images: int) -> list[dict]:
 
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
-    user_content_buffer = []
+    content_buffer = []
 
     for item in history:
         if item["role"] == "assistant":
-            if user_content_buffer:
-                messages.append({"role": "user", "content": user_content_buffer})
-                user_content_buffer = []
+            if content_buffer:
+                messages.append({"role": "user", "content": content_buffer})
+                content_buffer = []
 
             messages.append(
                 {
@@ -110,13 +110,51 @@ def process_history(history: list[dict]) -> list[dict]:
             )
         else:
             content = item["content"]
-            user_content_buffer.append(
+            content_buffer.append(
                 {"type": "text", "text": content}
                 if isinstance(content, str)
                 else {"type": "image", "url": content[0]}
             )
 
-    if user_content_buffer:
-        messages.append({"role": "user", "content": user_content_buffer})
+    if content_buffer:
+        messages.append({"role": "user", "content": content_buffer})
 
     return messages
+
+
+@spaces.GPU(duration=120)
+def run(
+    message: dict, history: list[dict], system_prompt: str, max_new_tokens: int = 512
+) -> Iterator[str]:
+
+    messages = []
+    if system_prompt:
+        messages.append(
+            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
+        )
+    messages.extend(process_history(history))
+    messages.append({"role": "user", "content": process_user_input(message)})
+
+    inputs = input_processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(device=model.device, dtype=torch.bfloat16)
+
+    streamer = TextIteratorStreamer(
+        input_processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True
+    )
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    output = ""
+    for delta in streamer:
+        output += delta
+        yield output
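
For reference, a quick sketch of what the reworked process_history returns for a multimodal chat history. The sample history below is illustrative, not taken from the Space, and the exact shape of the assistant entry depends on the unchanged lines (107-109) that this diff does not show.

# Hypothetical Gradio-style history: two consecutive user items (text, then an
# image path tuple) followed by one assistant reply. Consecutive user items are
# buffered and emitted as a single multi-part user message.
history = [
    {"role": "user", "content": "Describe this picture."},
    {"role": "user", "content": ("/tmp/cat.png",)},
    {"role": "assistant", "content": "A cat sitting on a windowsill."},
]

messages = process_history(history)  # process_history from src/app.py
# messages[0] == {
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "Describe this picture."},
#         {"type": "image", "url": "/tmp/cat.png"},
#     ],
# }
# messages[1] is the assistant message built by the branch whose body sits
# outside this diff's context lines.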
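
The new run handler follows the usual Transformers streaming recipe: model.generate runs on a background Thread while TextIteratorStreamer yields decoded text, and the generator re-yields the accumulated string so a UI can render partial output. Below is a minimal sketch of how such a generator is commonly wired into a Gradio chat UI; gr.ChatInterface and its parameters are standard Gradio, but this particular wiring and the additional-input widgets are assumptions for illustration, not part of the commit.

import gradio as gr

# Hedged sketch: `run` is the generator added in this commit; everything else
# here is an illustrative assumption about how the Space might expose it.
demo = gr.ChatInterface(
    fn=run,            # streaming works because `run` yields the growing output string
    type="messages",   # history arrives as [{"role": ..., "content": ...}, ...]
    multimodal=True,   # `message` is a dict with "text" and "files" keys
    additional_inputs=[
        gr.Textbox(label="System prompt", value=""),
        gr.Slider(label="Max new tokens", minimum=64, maximum=2048, step=1, value=512),
    ],
)

if __name__ == "__main__":
    demo.launch()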