Mizukiluke
commited on
Fix the function of chat to support the gradio demo
Browse files- modeling_mplugowl3.py +20 -31
modeling_mplugowl3.py
CHANGED
@@ -142,7 +142,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
142 |
media_offset=None,
|
143 |
attention_mask=None,
|
144 |
tokenizer=None,
|
145 |
-
return_vision_hidden_states=False,
|
146 |
stream=False,
|
147 |
decode_text=False,
|
148 |
**kwargs
|
@@ -156,9 +155,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
156 |
result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
|
157 |
else:
|
158 |
result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
|
159 |
-
|
160 |
-
if return_vision_hidden_states:
|
161 |
-
return result, image_embeds
|
162 |
|
163 |
return result
|
164 |
|
@@ -166,10 +162,9 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
166 |
self,
|
167 |
images,
|
168 |
videos,
|
169 |
-
|
170 |
tokenizer,
|
171 |
processor=None,
|
172 |
-
vision_hidden_states=None,
|
173 |
max_new_tokens=2048,
|
174 |
min_new_tokens=0,
|
175 |
sampling=True,
|
@@ -180,21 +175,23 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
180 |
use_image_id=None,
|
181 |
**kwargs
|
182 |
):
|
183 |
-
print(
|
|
|
|
|
|
|
|
|
184 |
if processor is None:
|
185 |
if self.processor is None:
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
inputs
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
max_length=max_inp_length
|
197 |
-
).to(self.device)
|
198 |
|
199 |
if sampling:
|
200 |
generation_config = {
|
@@ -202,12 +199,12 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
202 |
"top_k": 100,
|
203 |
"temperature": 0.7,
|
204 |
"do_sample": True,
|
205 |
-
"repetition_penalty": 1.05
|
206 |
}
|
207 |
else:
|
208 |
generation_config = {
|
209 |
"num_beams": 3,
|
210 |
-
"repetition_penalty": 1.2,
|
211 |
}
|
212 |
|
213 |
if min_new_tokens > 0:
|
@@ -216,14 +213,10 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
216 |
generation_config.update(
|
217 |
(k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
|
218 |
)
|
219 |
-
|
220 |
-
inputs.pop("image_sizes")
|
221 |
with torch.inference_mode():
|
222 |
res = self.generate(
|
223 |
**inputs,
|
224 |
-
tokenizer=tokenizer,
|
225 |
-
max_new_tokens=max_new_tokens,
|
226 |
-
vision_hidden_states=vision_hidden_states,
|
227 |
stream=stream,
|
228 |
decode_text=True,
|
229 |
**generation_config
|
@@ -238,9 +231,5 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
|
|
238 |
return stream_gen()
|
239 |
|
240 |
else:
|
241 |
-
|
242 |
-
answer = res
|
243 |
-
else:
|
244 |
-
answer = res[0]
|
245 |
return answer
|
246 |
-
|
|
|
142 |
media_offset=None,
|
143 |
attention_mask=None,
|
144 |
tokenizer=None,
|
|
|
145 |
stream=False,
|
146 |
decode_text=False,
|
147 |
**kwargs
|
|
|
155 |
result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
|
156 |
else:
|
157 |
result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
|
|
|
|
|
|
|
158 |
|
159 |
return result
|
160 |
|
|
|
162 |
self,
|
163 |
images,
|
164 |
videos,
|
165 |
+
messages,
|
166 |
tokenizer,
|
167 |
processor=None,
|
|
|
168 |
max_new_tokens=2048,
|
169 |
min_new_tokens=0,
|
170 |
sampling=True,
|
|
|
175 |
use_image_id=None,
|
176 |
**kwargs
|
177 |
):
|
178 |
+
print(messages)
|
179 |
+
if len(images)>1:
|
180 |
+
cut_flag=False
|
181 |
+
else:
|
182 |
+
cut_flag=True
|
183 |
if processor is None:
|
184 |
if self.processor is None:
|
185 |
+
processor = self.init_processor(tokenizer)
|
186 |
+
else:
|
187 |
+
processor = self.processor
|
188 |
+
inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag)
|
189 |
+
inputs.to('cuda')
|
190 |
+
inputs.update({
|
191 |
+
'tokenizer': tokenizer,
|
192 |
+
'max_new_tokens': max_new_tokens,
|
193 |
+
# 'stream':True,
|
194 |
+
})
|
|
|
|
|
195 |
|
196 |
if sampling:
|
197 |
generation_config = {
|
|
|
199 |
"top_k": 100,
|
200 |
"temperature": 0.7,
|
201 |
"do_sample": True,
|
202 |
+
# "repetition_penalty": 1.05
|
203 |
}
|
204 |
else:
|
205 |
generation_config = {
|
206 |
"num_beams": 3,
|
207 |
+
# "repetition_penalty": 1.2,
|
208 |
}
|
209 |
|
210 |
if min_new_tokens > 0:
|
|
|
213 |
generation_config.update(
|
214 |
(k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
|
215 |
)
|
216 |
+
print(inputs)
|
|
|
217 |
with torch.inference_mode():
|
218 |
res = self.generate(
|
219 |
**inputs,
|
|
|
|
|
|
|
220 |
stream=stream,
|
221 |
decode_text=True,
|
222 |
**generation_config
|
|
|
231 |
return stream_gen()
|
232 |
|
233 |
else:
|
234 |
+
answer = res[0]
|
|
|
|
|
|
|
235 |
return answer
|
|