cbensimon HF staff commited on
Commit
0d138f0
1 Parent(s): 4c663d6

Update marigold_depth_estimation_lcm.py

Browse files
Files changed (1) hide show
  1. marigold_depth_estimation_lcm.py +7 -0
marigold_depth_estimation_lcm.py CHANGED
@@ -283,6 +283,7 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
283
  """
284
  Encode text embedding for empty prompt.
285
  """
 
286
  prompt = ""
287
  text_inputs = self.tokenizer(
288
  prompt,
@@ -291,8 +292,11 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
291
  truncation=True,
292
  return_tensors="pt",
293
  )
 
294
  text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
 
295
  self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
 
296
 
297
  @torch.no_grad()
298
  def single_infer(
@@ -358,7 +362,10 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
358
 
359
  # Batched empty text embedding
360
  if self.empty_text_embed is None:
 
361
  self._encode_empty_text()
 
 
362
  batch_empty_text_embed = self.empty_text_embed.repeat(
363
  (rgb_latent.shape[0], 1, 1)
364
  ) # [B, 2, 1024]
 
283
  """
284
  Encode text embedding for empty prompt.
285
  """
286
+ print("_encode_empty_text")
287
  prompt = ""
288
  text_inputs = self.tokenizer(
289
  prompt,
 
292
  truncation=True,
293
  return_tensors="pt",
294
  )
295
+ print(f"{self.text_encoder.device=}")
296
  text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
297
+ print(f"{text_input_ids.device=}")
298
  self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
299
+ print(f"{self.empty_text_embed.device=}", f"{self.empty_text_embed.dtype=}")
300
 
301
  @torch.no_grad()
302
  def single_infer(
 
362
 
363
  # Batched empty text embedding
364
  if self.empty_text_embed is None:
365
+ print("self.empty_text_embed is None")
366
  self._encode_empty_text()
367
+ else:
368
+ print("self.empty_text_embed is not None")
369
  batch_empty_text_embed = self.empty_text_embed.repeat(
370
  (rgb_latent.shape[0], 1, 1)
371
  ) # [B, 2, 1024]