Willem-BD committed
Commit 962383c · verified
1 Parent(s): 8f54436
Files changed (2):
  1. inferencer.py +35 -56
  2. modeling/bagel/bagel.py +15 -13
inferencer.py CHANGED
@@ -2,16 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from copy import deepcopy
-from typing import List, Dict, Tuple, Optional, Union, Any
-import matplotlib.pyplot as plt
+from typing import List, Optional, Union, Any
 
 from PIL import Image
 import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.nn.attention.flex_attention import create_block_mask
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
 
 from data.data_utils import pil_img2rgb
 from modeling.bagel.qwen2_navit import NaiveCache
@@ -196,17 +190,17 @@ class InterleaveInferencer:
         ropes = gen_context['ropes']
 
         generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
-        unpacked_latent = self.model.generate_text(
+        for unpacked_latent in self.model.generate_text(
             past_key_values=past_key_values,
             max_length=max_length,
             do_sample=do_sample,
             temperature=temperature,
             end_token_id=self.new_token_ids['eos_token_id'],
             **generation_input,
-        )
-        output = self.tokenizer.decode(unpacked_latent[:,0])
-        output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
-        return output
+        ):
+            output = self.tokenizer.decode(unpacked_latent)
+            if output != "<|im_end|>":
+                yield output
 
     @torch.no_grad()
     def interleave_inference(
@@ -214,10 +208,11 @@ class InterleaveInferencer:
         input_lists: List[Union[str, Image.Image]],
         think=False,
         understanding_output=False,
-
+        # for gen_text
         max_think_token_n=1000,
         do_sample=False,
         text_temperature=0.3,
+        # for gen_image
         cfg_text_scale=3.0,
         cfg_img_scale=1.5,
         cfg_interval=[0.4, 1.0],
@@ -225,23 +220,20 @@ class InterleaveInferencer:
         num_timesteps=50,
         cfg_renorm_min=0.0,
         cfg_renorm_type="global",
-        image_shapes=(1024, 1024),
-    ) -> List[Union[str, Image.Image]]:
-
-        output_list = []
+        image_shapes=(1024, 1024), # Default, can be overridden by actual input image
+    ):
         gen_context = self.init_gen_context()
         cfg_text_context = deepcopy(gen_context)
         cfg_img_context = deepcopy(gen_context)
 
         with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+
             if think:
-                if understanding_output:
-                    system_prompt = VLM_THINK_SYSTEM_PROMPT
-                else:
-                    system_prompt = GEN_THINK_SYSTEM_PROMPT
+                system_prompt = VLM_THINK_SYSTEM_PROMPT if understanding_output else GEN_THINK_SYSTEM_PROMPT
                 gen_context = self.update_context_text(system_prompt, gen_context)
+                cfg_text_context = self.update_context_text(system_prompt, cfg_text_context)
                 cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
-
+
             for input_term in input_lists:
                 if isinstance(input_term, str):
                     cfg_text_context = deepcopy(gen_context)
@@ -251,29 +243,29 @@ class InterleaveInferencer:
                 elif isinstance(input_term, Image.Image):
                     input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
                     gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
-
                     image_shapes = input_term.size[::-1]
                     cfg_text_context = deepcopy(gen_context)
 
                 else:
                     raise ValueError(f"Unsupported input type: {type(input_term)}")
-
-            if understanding_output:
-                gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
-                output_list.append(gen_text)
-
-            else:
+
+            if understanding_output: # Generate text
+                yield from self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=text_temperature)
+            else: # Generate image
                 if think:
-                    gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
-                    gen_context = self.update_context_text(gen_text, gen_context)
-                    output_list.append(gen_text)
-
+                    thought_text_parts = []
+                    for part in self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=text_temperature):
+                        yield part # Stream the thought
+                        thought_text_parts.append(part)
+                    full_thought_text = "".join(thought_text_parts)
+                    if full_thought_text: # Only update if thought was generated
+                        gen_context = self.update_context_text(full_thought_text, gen_context)
+
                 img = self.gen_image(
-                    image_shapes,
-                    gen_context,
+                    image_shape=image_shapes,
+                    gen_context=gen_context,
                     cfg_text_precontext=cfg_text_context,
                     cfg_img_precontext=cfg_img_context,
-
                     cfg_text_scale=cfg_text_scale,
                     cfg_img_scale=cfg_img_scale,
                     cfg_interval=cfg_interval,
@@ -282,34 +274,21 @@ class InterleaveInferencer:
                     cfg_renorm_min=cfg_renorm_min,
                     cfg_renorm_type=cfg_renorm_type,
                 )
+                yield img
 
-                output_list.append(img)
-
-        return output_list
-
     def __call__(
         self,
         image: Optional[Image.Image] = None,
        text: Optional[str] = None,
-        **kargs
-    ) -> Dict[str, Any]:
-        output_dict = {'image': None, 'text': None}
-
-        if image is None and text is None:
-            print('Please provide at least one input: either an image or text.')
-            return output_dict
-
+        **kargs
+    ) -> Any:
         input_list = []
         if image is not None:
             input_list.append(image)
         if text is not None:
             input_list.append(text)
+
+        if not input_list and not kargs.get('force_empty_input', False): # allow forcing for special cases if needed
+            return
 
-        output_list = self.interleave_inference(input_list, **kargs)
-
-        for i in output_list:
-            if isinstance(i, Image.Image):
-                output_dict['image'] = i
-            elif isinstance(i, str):
-                output_dict['text'] = i
-        return output_dict
+        yield from self.interleave_inference(input_list, **kargs)
 
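With this change, gen_text, interleave_inference, and __call__ become generators: instead of the previous {'image': ..., 'text': ...} dict, a caller iterates over streamed text chunks followed by the generated image. A minimal consumer sketch (not part of the commit), assuming an already constructed InterleaveInferencer instance named inferencer and an illustrative prompt and output path:

from PIL import Image

prompt = "A photo of a corgi wearing a tiny wizard hat"  # illustrative prompt

text_chunks = []
final_image = None
for item in inferencer(text=prompt, think=True, cfg_text_scale=4.0, num_timesteps=50):
    if isinstance(item, str):
        text_chunks.append(item)      # decoded text arrives chunk by chunk
        print(item, end="", flush=True)
    elif isinstance(item, Image.Image):
        final_image = item            # the image is yielded after the streamed text

if final_image is not None:
    final_image.save("output.png")    # illustrative output path

Because __call__ is now a generator function, nothing runs until the caller starts iterating over its result.
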
modeling/bagel/bagel.py CHANGED
@@ -890,16 +890,21 @@ class Bagel(PreTrainedModel):
         temperature: float = 1.0,
         end_token_id: int = None,
     ):
+        """
+        Generates text token by token in a streaming fashion.
+
+        This function is a generator that yields one token at a time. It replicates
+        the behavior of the original batch generation function, including the handling
+        of start tokens and the end-of-sequence token.
+        """
         step = 0
-        generated_sequence = []
         curr_tokens = packed_start_tokens
         while step < max_length:
-            generated_sequence.append(curr_tokens)
             packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
             query_lens = torch.ones_like(curr_tokens)
             packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
-                0, len(key_values_lens),
-                device=key_values_lens.device,
+                0, len(key_values_lens),
+                device=key_values_lens.device,
                 dtype=key_values_lens.dtype
             )
 
@@ -944,12 +949,11 @@ class Bagel(PreTrainedModel):
             packed_query_position_ids = packed_query_position_ids + 1
             step += 1
 
+            yield curr_tokens # Yield each token as it's generated
+
             if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
                 break
 
-        output_device = generated_sequence[0].device
-        return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
-
     # for evaluation
     @torch.no_grad()
     def chat(
@@ -1012,15 +1016,13 @@ class Bagel(PreTrainedModel):
             if torch.is_tensor(v):
                 generation_input[k] = v.to(device)
         with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            unpacked_latent = self.generate_text(
+            for unpacked_latent in self.generate_text(
                 past_key_values=past_key_values,
                 max_length=max_length,
                 do_sample=do_sample,
                 temperature=temperature,
                 end_token_id=new_token_ids['eos_token_id'],
                 **generation_input,
-            )
-            output = tokenizer.decode(unpacked_latent[:,0])
-            output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
-
-            return output
+            ):
+                output = tokenizer.decode(unpacked_latent[:,0])
+                yield output
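
Since generate_text now yields one token tensor per step and chat yields decoded strings, downstream code that expected a single return string has to consume and join the stream. A rough sketch of an adapted caller (not part of the commit), assuming a loaded Bagel model bound to model; the keyword arguments in the call are illustrative placeholders, not taken from the repository:

pieces = []
for piece in model.chat(
    tokenizer=tokenizer,            # illustrative wiring; chat() decodes with a tokenizer
    new_token_ids=new_token_ids,    # illustrative wiring; supplies eos_token_id
    max_length=500,
    do_sample=False,
    temperature=1.0,
    **other_chat_inputs,            # illustrative placeholder for the remaining chat() inputs
):
    print(piece, end="", flush=True)   # stream each decoded piece as it arrives
    pieces.append(piece)

full_reply = "".join(pieces)
# As written, the stream appears to include the final '<|im_end|>' piece as well,
# so a caller may want to filter it before using full_reply.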