pandaphd committed on
Commit cc3773d · 1 Parent(s): c0c90ad
README.md DELETED
@@ -1,14 +0,0 @@
- ---
- title: Generative Photography
- emoji: 📈
- colorFrom: blue
- colorTo: blue
- sdk: gradio
- sdk_version: 5.20.0
- app_file: app.py
- pinned: false
- license: cc-by-nc-nd-4.0
- short_description: Demo for Generative Photography
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,11 +1,28 @@
+ import os
  import gradio as gr
  import json
  import torch
+ from huggingface_hub import snapshot_download
+
  from inference_bokehK import load_models as load_bokeh_models, run_inference as run_bokeh_inference, OmegaConf
  from inference_focal_length import load_models as load_focal_models, run_inference as run_focal_inference
  from inference_shutter_speed import load_models as load_shutter_models, run_inference as run_shutter_inference
  from inference_color_temperature import load_models as load_color_models, run_inference as run_color_inference
 
+ model_path = "ckpts"
+ os.makedirs(model_path, exist_ok=True)
+
+ print("Downloading models from Hugging Face...")
+ snapshot_download(repo_id="pandaphd/generative_photography", local_dir=model_path)
+
  torch.manual_seed(42)
 
  bokeh_cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
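The snapshot download above is expected to place the repository contents under ./ckpts/, which is the layout the updated YAML configs later in this commit point at. A minimal sanity-check sketch, not part of the commit; the listed paths are taken from the config diffs below:

import os

expected = [
    "ckpts/stable-diffusion-v1-5",
    "ckpts/weights/RealEstate10K_LoRA.ckpt",
    "ckpts/weights/v3_sd15_mm.ckpt",
    "ckpts/weights/checkpoint-bokehK.ckpt",
]
for path in expected:
    # Report whether the snapshot produced the file/directory the configs expect.
    print(path, "OK" if os.path.exists(path) else "MISSING")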
app_bokehK.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import tempfile
+ import json
+ from inference_bokehK import load_models, run_inference, OmegaConf
+ import torch
+
+ # Initialize models once at startup
+ cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
+ pipeline, device = load_models(cfg)
+
+ def generate_video(base_scene, bokehK_list):
+     try:
+         # Validate input
+         if len(json.loads(bokehK_list)) != 5:
+             raise ValueError("Exactly 5 Bokeh K values required")
+
+         # Run inference
+         video_path = run_inference(
+             pipeline=pipeline,
+             tokenizer=pipeline.tokenizer,
+             text_encoder=pipeline.text_encoder,
+             base_scene=base_scene,
+             bokehK_list=bokehK_list,
+             device=device
+         )
+         return video_path
+
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # Example inputs
+ examples = [
+     [
+         "A young boy wearing an orange jacket is standing on a crosswalk, waiting to cross the street.",
+         "[2.5, 6.3, 10.1, 17.2, 24.0]"
+     ],
+     [
+         "A display of frozen desserts, including cupcakes and donuts, is arranged in a row on a counter.",
+         "[20.0, 18.5, 15.0, 10.5, 5.0]"
+     ]
+ ]
+
+ with gr.Blocks(title="Bokeh Effect Generator") as demo:
+     gr.Markdown("# Dynamic Bokeh Effect Generation")
+
+     with gr.Row():
+         with gr.Column():
+             scene_input = gr.Textbox(
+                 label="Scene Description",
+                 placeholder="Describe the scene you want to generate..."
+             )
+             bokeh_input = gr.Textbox(
+                 label="Bokeh Blur Values",
+                 placeholder="Enter 5 comma-separated values from 1-30 (e.g., [2.44, 8.3, 10.1, 17.2, 24.0])"
+             )
+             submit_btn = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video")
+             error_output = gr.Textbox(label="Error Messages", visible=False)
+
+     gr.Examples(
+         examples=examples,
+         inputs=[scene_input, bokeh_input],
+         outputs=[video_output],
+         fn=generate_video,
+         cache_examples=True
+     )
+
+     submit_btn.click(
+         fn=generate_video,
+         inputs=[scene_input, bokeh_input],
+         outputs=[video_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
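For reference, the Bokeh value field is parsed with json.loads and must contain exactly five numbers, matching the validation in generate_video above. A minimal standalone illustration, not part of the Space code:

import json

bokehK_list = "[2.5, 6.3, 10.1, 17.2, 24.0]"
values = json.loads(bokehK_list)
if len(values) != 5:
    # Same contract that generate_video enforces before running inference.
    raise ValueError("Exactly 5 Bokeh K values required")
print(values)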
app_color_temperature.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import tempfile
+ import json
+ from inference_color_temperature import load_models, run_inference, OmegaConf
+ import torch
+
+ # Initialize models once at startup
+ cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_color_temperature.yaml")
+ pipeline, device = load_models(cfg)
+
+ def generate_video(base_scene, color_temperature_list):
+     try:
+         # Validate input
+         if len(json.loads(color_temperature_list)) != 5:
+             raise ValueError("Exactly 5 color_temperature values required")
+
+         # Run inference
+         video_path = run_inference(
+             pipeline=pipeline,
+             tokenizer=pipeline.tokenizer,
+             text_encoder=pipeline.text_encoder,
+             base_scene=base_scene,
+             color_temperature_list=color_temperature_list,
+             device=device
+         )
+         return video_path
+
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # Example inputs
+ examples = [
+     [
+         "A beautiful blue sky with a mountain range in the background.",
+         "[5455.0, 5155.0, 5555.0, 6555.0, 7555.0]"
+     ],
+     [
+         "A red couch is situated in front of a window, which is filled with a variety of potted plants.",
+         "[3500.0, 5500.0, 6500.0, 7500.0, 8500.0]"
+     ]
+ ]
+
+ with gr.Blocks(title="Color Temperature Effect Generator") as demo:
+     gr.Markdown("# Dynamic Color Temperature Effect Generation")
+
+     with gr.Row():
+         with gr.Column():
+             scene_input = gr.Textbox(
+                 label="Scene Description",
+                 placeholder="Describe the scene you want to generate..."
+             )
+             color_temperature_input = gr.Textbox(
+                 label="Color Temperature Values",
+                 placeholder="Enter 5 comma-separated values from 2000-10000 (e.g., [3001.3, 4000.2, 4400.34, 5488.23, 8888.82])"
+             )
+             submit_btn = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video")
+             error_output = gr.Textbox(label="Error Messages", visible=False)
+
+     gr.Examples(
+         examples=examples,
+         inputs=[scene_input, color_temperature_input],
+         outputs=[video_output],
+         fn=generate_video,
+         cache_examples=True
+     )
+
+     submit_btn.click(
+         fn=generate_video,
+         inputs=[scene_input, color_temperature_input],
+         outputs=[video_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
app_focal_length.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import tempfile
+ import json
+ from inference_focal_length import load_models, run_inference, OmegaConf
+ import torch
+
+ # Initialize models once at startup
+ cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_focal_length.yaml")
+ pipeline, device = load_models(cfg)
+
+ def generate_video(base_scene, focal_length_list):
+     try:
+         # Validate input
+         if len(json.loads(focal_length_list)) != 5:
+             raise ValueError("Exactly 5 focal_length values required")
+
+         # Run inference
+         video_path = run_inference(
+             pipeline=pipeline,
+             tokenizer=pipeline.tokenizer,
+             text_encoder=pipeline.text_encoder,
+             base_scene=base_scene,
+             focal_length_list=focal_length_list,
+             device=device
+         )
+         return video_path
+
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # Example inputs
+ examples = [
+     [
+         "A small office cubicle with a desk, computer, and chair.",
+         "[25.1, 36.1, 47.1, 58.1, 69.1]"
+     ],
+     [
+         "A large, white couch is placed in a living room, with a mirror above it. The couch is covered with various items, including a blue box, a pink towel, and a pair of shoes.",
+         "[55.0, 46.0, 37.0, 28.0, 25.0]"
+     ]
+ ]
+
+ with gr.Blocks(title="Focal Length Effect Generator") as demo:
+     gr.Markdown("# Dynamic Focal Length Effect Generation")
+
+     with gr.Row():
+         with gr.Column():
+             scene_input = gr.Textbox(
+                 label="Scene Description",
+                 placeholder="Describe the scene you want to generate..."
+             )
+             focal_length_input = gr.Textbox(
+                 label="Focal Length Values",
+                 placeholder="Enter 5 comma-separated values from 24-70 (e.g., [25.1, 30.2, 33.3, 40.8, 54.0])"
+             )
+             submit_btn = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video")
+             error_output = gr.Textbox(label="Error Messages", visible=False)
+
+     gr.Examples(
+         examples=examples,
+         inputs=[scene_input, focal_length_input],
+         outputs=[video_output],
+         fn=generate_video,
+         cache_examples=True
+     )
+
+     submit_btn.click(
+         fn=generate_video,
+         inputs=[scene_input, focal_length_input],
+         outputs=[video_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
app_shutter_speed.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import tempfile
+ import json
+ from inference_shutter_speed import load_models, run_inference, OmegaConf
+ import torch
+
+ # Initialize models once at startup
+ cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml")
+ pipeline, device = load_models(cfg)
+
+ def generate_video(base_scene, shutter_speed_list):
+     try:
+         # Validate input
+         if len(json.loads(shutter_speed_list)) != 5:
+             raise ValueError("Exactly 5 shutter_speed values required")
+
+         # Run inference
+         video_path = run_inference(
+             pipeline=pipeline,
+             tokenizer=pipeline.tokenizer,
+             text_encoder=pipeline.text_encoder,
+             base_scene=base_scene,
+             shutter_speed_list=shutter_speed_list,
+             device=device
+         )
+         return video_path
+
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # Example inputs
+ examples = [
+     [
+         "A brown and orange leather handbag with a paw print on it sits next to a book.",
+         "[0.11, 0.22, 0.33, 0.44, 0.55]"
+     ],
+     [
+         "A variety of potted plants are displayed on a windowsill, with some of them placed in yellow and white bowls.",
+         "[0.29, 0.49, 0.69, 0.79, 0.89]"
+     ]
+ ]
+
+ with gr.Blocks(title="Shutter Speed Effect Generator") as demo:
+     gr.Markdown("# Dynamic Shutter Speed Effect Generation")
+
+     with gr.Row():
+         with gr.Column():
+             scene_input = gr.Textbox(
+                 label="Scene Description",
+                 placeholder="Describe the scene you want to generate..."
+             )
+             shutter_speed_input = gr.Textbox(
+                 label="Shutter Speed Values",
+                 placeholder="Enter 5 comma-separated values from 0.1-1.0 (e.g., [0.15, 0.32, 0.53, 0.62, 0.82])"
+             )
+             submit_btn = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video")
+             error_output = gr.Textbox(label="Error Messages", visible=False)
+
+     gr.Examples(
+         examples=examples,
+         inputs=[scene_input, shutter_speed_input],
+         outputs=[video_output],
+         fn=generate_video,
+         cache_examples=True
+     )
+
+     submit_btn.click(
+         fn=generate_video,
+         inputs=[scene_input, shutter_speed_input],
+         outputs=[video_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml CHANGED
@@ -1,13 +1,11 @@
- output_dir: "inference_output/genphoto_bokehK"
 
- pretrained_model_repo: "pandaphd/generative_photography"
- pretrained_model_path: "stable-diffusion-v1-5"
+ pretrained_model_path: "./ckpts/stable-diffusion-v1-5/"
 unet_subfolder: "unet_merged"
+ camera_adaptor_ckpt: "./ckpts/weights/checkpoint-bokehK.ckpt"
+ lora_ckpt: "./ckpts/weights/RealEstate10K_LoRA.ckpt"
+ motion_module_ckpt: "./ckpts/weights/v3_sd15_mm.ckpt"
 
- camera_adaptor_ckpt: "weights/checkpoint-bokehK.ckpt"
- lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
- motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
 
 lora_rank: 2
 lora_scale: 1.0
@@ -43,7 +41,6 @@ camera_encoder_kwargs:
   attention_block_types: ["Temporal_Self", ]
   temporal_position_encoding: true
   temporal_position_encoding_max_len: 16
-
 attention_processor_kwargs:
   add_spatial: false
   spatial_attn_names: 'attn1'
@@ -53,7 +50,6 @@ attention_processor_kwargs:
   query_condition: true
   key_value_condition: true
   scale: 1.0
-
 noise_scheduler_kwargs:
   num_train_timesteps: 1000
   beta_start: 0.00085
@@ -62,5 +58,6 @@ noise_scheduler_kwargs:
   steps_offset: 1
   clip_sample: false
 
+
 num_workers: 8
 global_seed: 42
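After this change the inference scripts resolve every model and checkpoint path from the YAML itself instead of calling hf_hub_download. A short sketch of how the relocated values are read, assuming omegaconf is installed and the ckpts/ snapshot from app.py is in place:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
# These attributes mirror the keys added in the config diff above.
print(cfg.pretrained_model_path)   # ./ckpts/stable-diffusion-v1-5/
print(cfg.camera_adaptor_ckpt)     # ./ckpts/weights/checkpoint-bokehK.ckpt
print(cfg.lora_ckpt)               # ./ckpts/weights/RealEstate10K_LoRA.ckpt
print(cfg.motion_module_ckpt)      # ./ckpts/weights/v3_sd15_mm.ckpt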
configs/inference_genphoto/adv3_256_384_genphoto_relora_color_temperature.yaml CHANGED
@@ -1,16 +1,13 @@
 output_dir: "inference_output/genphoto_color_temperature"
-
- pretrained_model_repo: "pandaphd/generative_photography"
- pretrained_model_path: "stable-diffusion-v1-5"
-
+ pretrained_model_path: "./ckpts/stable-diffusion-v1-5/"
 unet_subfolder: "unet_merged"
 
- camera_adaptor_ckpt: "weights/checkpoint-color_temperature.ckpt"
- lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
- motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
+ camera_adaptor_ckpt: "./ckpts/weights/checkpoint-color_temperature.ckpt"
 
 lora_rank: 2
 lora_scale: 1.0
+ lora_ckpt: "./ckpts/weights/RealEstate10K_LoRA.ckpt"
+ motion_module_ckpt: "./ckpts/weights/v3_sd15_mm.ckpt"
 motion_lora_rank: 0
 motion_lora_scale: 1.0
 
configs/inference_genphoto/adv3_256_384_genphoto_relora_focal_length.yaml CHANGED
@@ -1,16 +1,14 @@
 output_dir: "inference_output/genphoto_focal_length"
-
- pretrained_model_repo: "pandaphd/generative_photography"
- pretrained_model_path: "stable-diffusion-v1-5"
-
+ pretrained_model_path: "./ckpts/stable-diffusion-v1-5/"
 unet_subfolder: "unet_merged"
 
- camera_adaptor_ckpt: "weights/checkpoint-focal_length.ckpt"
- lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
- motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
+ camera_adaptor_ckpt: "./ckpts/weights/checkpoint-focal_length.ckpt"
+
 
 lora_rank: 2
 lora_scale: 1.0
+ lora_ckpt: "./ckpts/weights/RealEstate10K_LoRA.ckpt"
+ motion_module_ckpt: "./ckpts/weights/v3_sd15_mm.ckpt"
 motion_lora_rank: 0
 motion_lora_scale: 1.0
 
configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml CHANGED
@@ -1,16 +1,13 @@
 output_dir: "inference_output/genphoto_shutter_speed"
-
- pretrained_model_repo: "pandaphd/generative_photography"
- pretrained_model_path: "stable-diffusion-v1-5"
-
+ pretrained_model_path: "./ckpts/stable-diffusion-v1-5/"
 unet_subfolder: "unet_merged"
 
- camera_adaptor_ckpt: "weights/checkpoint-shutter_speed.ckpt"
- lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
- motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
+ camera_adaptor_ckpt: "./ckpts/weights/checkpoint-shutter_speed.ckpt"
 
 lora_rank: 2
 lora_scale: 1.0
+ lora_ckpt: "./ckpts/weights/RealEstate10K_LoRA.ckpt"
+ motion_module_ckpt: "./ckpts/weights/v3_sd15_mm.ckpt"
 motion_lora_rank: 0
 motion_lora_scale: 1.0
 
inference_bokehK.py CHANGED
@@ -22,11 +22,6 @@ from genphoto.utils.util import save_videos_grid
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-
- from huggingface_hub import hf_hub_download
-
-
-
 def create_bokehK_embedding(bokehK_values, target_height, target_width):
     f = bokehK_values.shape[0]
     bokehK_embedding = torch.zeros((f, 3, target_height, target_width), dtype=bokehK_values.dtype)
@@ -94,24 +89,18 @@ class Camera_Embedding(Dataset):
         camera_embedding = torch.cat((bokehK_embedding, ccl_embedding), dim=1)
         return camera_embedding
 
+
 def load_models(cfg):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-     pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
-     lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
-     motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
-     camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-bokehK.ckpt")
-
     noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
-     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
     vae.requires_grad_(False)
-
-     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+     tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
     text_encoder.requires_grad_(False)
-
     unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
-         pretrained_model_path,
+         cfg.pretrained_model_path,
         subfolder=cfg.unet_subfolder,
         unet_additional_kwargs=cfg.unet_additional_kwargs
     ).to(device)
@@ -132,26 +121,26 @@ def load_models(cfg):
     )
 
     if cfg.lora_ckpt is not None:
-         lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
+         lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
         if 'lora_state_dict' in lora_checkpoints.keys():
             lora_checkpoints = lora_checkpoints['lora_state_dict']
         _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
         assert len(lora_u) == 0
 
     if cfg.motion_module_ckpt is not None:
-         mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
+         mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
         _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
         assert len(mm_u) == 0
 
     if cfg.camera_adaptor_ckpt is not None:
-         camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
+         camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
         camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
         attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
         camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
         assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
         _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
         assert len(attention_processor_u) == 0
+
     pipeline = GenPhotoPipeline(
         vae=vae,
         text_encoder=text_encoder,
@@ -160,10 +149,12 @@ def load_models(cfg):
         scheduler=noise_scheduler,
         camera_encoder=camera_encoder
     ).to(device)
-
     pipeline.enable_vae_slicing()
+
     return pipeline, device
 
+
+
 def run_inference(pipeline, tokenizer, text_encoder, base_scene, bokehK_list, device, video_length=5, height=256, width=384):
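With load_models now driven entirely by the config, the pipeline can be exercised outside Gradio in the same way app_bokehK.py does. A hedged sketch, assuming the ckpts/ snapshot and the configs from this commit are present:

from omegaconf import OmegaConf
from inference_bokehK import load_models, run_inference

cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
pipeline, device = load_models(cfg)

# Same call pattern as the Gradio handler in app_bokehK.py.
video_path = run_inference(
    pipeline=pipeline,
    tokenizer=pipeline.tokenizer,
    text_encoder=pipeline.text_encoder,
    base_scene="A young boy wearing an orange jacket is standing on a crosswalk.",
    bokehK_list="[2.5, 6.3, 10.1, 17.2, 24.0]",
    device=device,
)
print(video_path)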
inference_color_temperature.py CHANGED
@@ -22,7 +22,6 @@ from genphoto.utils.util import save_videos_grid
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
- from huggingface_hub import hf_hub_download
 
 
 def kelvin_to_rgb(kelvin):
@@ -132,104 +131,19 @@ class Camera_Embedding(Dataset):
         camera_embedding = torch.cat((color_temperature_embedding, ccl_embedding), dim=1)
         return camera_embedding
 
- # def load_models(cfg):
- #     ...
-
 
 def load_models(cfg):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-     pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
-     lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
-     motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
-     camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-color_temperature.ckpt")
-
     noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
-     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
     vae.requires_grad_(False)
-
-     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+     tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
     text_encoder.requires_grad_(False)
-
     unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
-         pretrained_model_path,
+         cfg.pretrained_model_path,
         subfolder=cfg.unet_subfolder,
         unet_additional_kwargs=cfg.unet_additional_kwargs
     ).to(device)
@@ -241,6 +155,7 @@ def load_models(cfg):
     camera_adaptor.requires_grad_(False)
     camera_adaptor.to(device)
 
+     logger.info("Setting the attention processors")
     unet.set_all_attn_processor(
         add_spatial_lora=cfg.lora_ckpt is not None,
         add_motion_lora=cfg.motion_lora_rank > 0,
@@ -250,25 +165,36 @@ def load_models(cfg):
     )
 
     if cfg.lora_ckpt is not None:
-         lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
+         print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
+         lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
         if 'lora_state_dict' in lora_checkpoints.keys():
             lora_checkpoints = lora_checkpoints['lora_state_dict']
         _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
         assert len(lora_u) == 0
+         print(f'Loading done')
 
     if cfg.motion_module_ckpt is not None:
-         mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
+         print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
+         mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
         _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
         assert len(mm_u) == 0
+         print("Loading done")
 
     if cfg.camera_adaptor_ckpt is not None:
-         camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
+         logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
+         camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
         camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
         attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
         camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
         assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
         _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
         assert len(attention_processor_u) == 0
+
+         logger.info("Camera Adaptor loading done")
+     else:
+         logger.info("No Camera Adaptor checkpoint used")
 
     pipeline = GenPhotoPipeline(
         vae=vae,
@@ -280,9 +206,8 @@ def load_models(cfg):
     ).to(device)
 
     pipeline.enable_vae_slicing()
-     return pipeline, device
-
 
+     return pipeline, device
 
 
 def run_inference(pipeline, tokenizer, text_encoder, base_scene, color_temperature_list, device, video_length=5, height=256, width=384):
inference_focal_length.py CHANGED
@@ -24,9 +24,6 @@ logger = logging.getLogger(__name__)
 
 
 
- from huggingface_hub import hf_hub_download
-
-
 
 def create_focal_length_embedding(focal_length_values, target_height, target_width, base_focal_length=24.0, sensor_height=24.0, sensor_width=36.0):
     device = 'cpu'
@@ -137,101 +134,19 @@ class Camera_Embedding(Dataset):
         camera_embedding = torch.cat((focal_length_embedding, ccl_embedding), dim=1)
         return camera_embedding
 
- # def load_models(cfg):
- #     ...
-
 
 def load_models(cfg):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-     pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
-     lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
-     motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
-     camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-focal_length.ckpt")
-
     noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
-     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
     vae.requires_grad_(False)
-
-     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+     tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
     text_encoder.requires_grad_(False)
-
     unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
-         pretrained_model_path,
+         cfg.pretrained_model_path,
         subfolder=cfg.unet_subfolder,
         unet_additional_kwargs=cfg.unet_additional_kwargs
     ).to(device)
@@ -243,6 +158,7 @@ def load_models(cfg):
     camera_adaptor.requires_grad_(False)
     camera_adaptor.to(device)
 
+     logger.info("Setting the attention processors")
     unet.set_all_attn_processor(
         add_spatial_lora=cfg.lora_ckpt is not None,
         add_motion_lora=cfg.motion_lora_rank > 0,
@@ -252,25 +168,35 @@ def load_models(cfg):
     )
 
     if cfg.lora_ckpt is not None:
-         lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
+         print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
+         lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
         if 'lora_state_dict' in lora_checkpoints.keys():
             lora_checkpoints = lora_checkpoints['lora_state_dict']
         _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
         assert len(lora_u) == 0
+         print(f'Loading done')
 
     if cfg.motion_module_ckpt is not None:
-         mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
+         print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
+         mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
         _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
         assert len(mm_u) == 0
+         print("Loading done")
 
     if cfg.camera_adaptor_ckpt is not None:
-         camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
+         logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
+         camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
         camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
         attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
         camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
+
         assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
         _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
         assert len(attention_processor_u) == 0
+
+         logger.info("Camera Adaptor loading done")
+     else:
+         logger.info("No Camera Adaptor checkpoint used")
 
     pipeline = GenPhotoPipeline(
         vae=vae,
@@ -280,10 +206,11 @@ def load_models(cfg):
         scheduler=noise_scheduler,
         camera_encoder=camera_encoder
     ).to(device)
-
     pipeline.enable_vae_slicing()
+
     return pipeline, device
 
+
 def run_inference(pipeline, tokenizer, text_encoder, base_scene, focal_length_list, device, video_length=5, height=256, width=384):
 
     focal_length_values = json.loads(focal_length_list)
inference_shutter_speed.py CHANGED
@@ -22,11 +22,6 @@ from genphoto.utils.util import save_videos_grid
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-
- from huggingface_hub import hf_hub_download
-
-
-
 def create_shutter_speed_embedding(shutter_speed_values, target_height, target_width, base_exposure=0.5):
     """
     Create a shutter_speed (Exposure Value or shutter speed) embedding tensor using a constant fwc value.
@@ -119,115 +114,32 @@ class Camera_Embedding(Dataset):
         return camera_embedding
 
 
- # def load_models(cfg):
- #     ...
-
 def load_models(cfg):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-     pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
-     lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
-     motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
-     camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-shutter_speed.ckpt")
-
     noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
-     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
     vae.requires_grad_(False)
-
-     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+     tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
     text_encoder.requires_grad_(False)
 
     unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
-         pretrained_model_path,
+         cfg.pretrained_model_path,
         subfolder=cfg.unet_subfolder,
         unet_additional_kwargs=cfg.unet_additional_kwargs
     ).to(device)
     unet.requires_grad_(False)
 
+
     camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
     camera_encoder.requires_grad_(False)
     camera_adaptor = CameraAdaptor(unet, camera_encoder)
     camera_adaptor.requires_grad_(False)
    camera_adaptor.to(device)
 
+     logger.info("Setting the attention processors")
     unet.set_all_attn_processor(
         add_spatial_lora=cfg.lora_ckpt is not None,
         add_motion_lora=cfg.motion_lora_rank > 0,
@@ -237,25 +149,40 @@ def load_models(cfg):
     )
 
     if cfg.lora_ckpt is not None:
-         lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
+         print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
+         lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
         if 'lora_state_dict' in lora_checkpoints.keys():
             lora_checkpoints = lora_checkpoints['lora_state_dict']
         _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
         assert len(lora_u) == 0
+         print(f'Loading done')
 
     if cfg.motion_module_ckpt is not None:
-         mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
+         print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
+         mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
         _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
         assert len(mm_u) == 0
+         print("Loading done")
 
+     # Load the Camera Adaptor checkpoint
     if cfg.camera_adaptor_ckpt is not None:
-         camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
+         logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
+         camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
+
+         # Load the Camera Encoder state
         camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
         attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
+
         camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
+
         assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
         _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
         assert len(attention_processor_u) == 0
+
+         logger.info("Camera Adaptor loading done")
+     else:
+         logger.info("No Camera Adaptor checkpoint used")
 
     pipeline = GenPhotoPipeline(
         vae=vae,
@@ -265,10 +192,9 @@ def load_models(cfg):
         scheduler=noise_scheduler,
         camera_encoder=camera_encoder
     ).to(device)
-
     pipeline.enable_vae_slicing()
-     return pipeline, device
 
+     return pipeline, device
 
 
 def run_inference(pipeline, tokenizer, text_encoder, base_scene, shutter_speed_list, device, video_length=5, height=256, width=384):
requirements.txt CHANGED
@@ -2,18 +2,18 @@
 torch==2.1.1
 torchvision==0.16.1
 torchaudio==2.1.1
- diffusers
+ diffusers==0.24.0
 imageio==2.36.0
 imageio-ffmpeg
- transformers
- accelerate
+ transformers==4.45.2
+ accelerate==1.0.1
 opencv-python
 gdown
 einops
 decord
 omegaconf
 safetensors
- gradio
+ gradio==5.1.0
 wandb
 triton
- huggingface_hub
+ huggingface_hub==0.25.2
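The newly pinned versions can be spot-checked inside the Space container with a short script such as the following (illustrative only; importlib.metadata is part of the Python standard library from 3.8 on):

from importlib.metadata import version

# Print the installed version of each package pinned in this commit.
for pkg in ("diffusers", "transformers", "accelerate", "gradio", "huggingface_hub"):
    print(pkg, version(pkg))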