Fabrice-TIERCELIN commited on
Commit
b4a8d6d
·
verified ·
1 Parent(s): a823da9

" rather than '

Browse files
Files changed (1) hide show
  1. sample_video.py +4 -4
sample_video.py CHANGED
@@ -17,7 +17,7 @@ def main():
17
  raise ValueError(f"`models_root` not exists: {models_root_path}")
18
 
19
  # Create save folder to save the samples
20
- save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
21
  if not os.path.exists(args.save_path):
22
  os.makedirs(save_path, exist_ok=True)
23
 
@@ -43,16 +43,16 @@ def main():
43
  batch_size=args.batch_size,
44
  embedded_guidance_scale=args.embedded_cfg_scale
45
  )
46
- samples = outputs['samples']
47
 
48
  # Save samples
49
- if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
50
  for i, sample in enumerate(samples):
51
  sample = samples[i].unsqueeze(0)
52
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53
  save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54
  save_videos_grid(sample, save_path, fps=24)
55
- logger.info(f'Sample save to: {save_path}')
56
 
57
  if __name__ == "__main__":
58
  main()
 
17
  raise ValueError(f"`models_root` not exists: {models_root_path}")
18
 
19
  # Create save folder to save the samples
20
+ save_path = args.save_path if args.save_path_suffix=="" else f"{args.save_path}_{args.save_path_suffix}"
21
  if not os.path.exists(args.save_path):
22
  os.makedirs(save_path, exist_ok=True)
23
 
 
43
  batch_size=args.batch_size,
44
  embedded_guidance_scale=args.embedded_cfg_scale
45
  )
46
+ samples = outputs["samples"]
47
 
48
  # Save samples
49
+ if "LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0:
50
  for i, sample in enumerate(samples):
51
  sample = samples[i].unsqueeze(0)
52
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53
  save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54
  save_videos_grid(sample, save_path, fps=24)
55
+ logger.info(f"Sample save to: {save_path}")
56
 
57
  if __name__ == "__main__":
58
  main()