svjack commited on
Commit
2efc095
·
verified ·
1 Parent(s): f85df78

Create colorflow_cli.py

Browse files
Files changed (1) hide show
  1. colorflow_cli.py +77 -0
colorflow_cli.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python colorflow_cli.py \
3
+ --input_image ./input.jpg \
4
+ --reference_images ./ref1.jpg ./ref2.jpg \
5
+ --output_dir ./results \
6
+ --input_style Sketch \
7
+ --resolution 640x640 \
8
+ --seed 123 \
9
+ --num_inference_steps 20
10
+ '''
11
+
12
+ # colorflow_cli.py
13
+ from app_func import *
14
+ import argparse
15
+ import torch
16
+ from PIL import Image
17
+ import os
18
+ import logging
19
+
20
+ # 原文件中的必要导入和函数定义(需保留原文件中的核心逻辑)
21
+ # ... [保留原文件中的模型加载、extract_line_image、colorize_image等函数] ...
22
+
23
+ def parse_args():
24
+ parser = argparse.ArgumentParser(description="ColorFlow命令行图像上色工具")
25
+ parser.add_argument("--input_image", type=str, required=True, help="输入图像路径")
26
+ parser.add_argument("--reference_images", type=str, nargs='+', required=True, help="参考图像路径列表")
27
+ parser.add_argument("--output_dir", type=str, default="./output", help="输出目录")
28
+ parser.add_argument("--input_style", type=str, default="GrayImage(ScreenStyle)",
29
+ choices=["GrayImage(ScreenStyle)", "Sketch"], help="输入样式类型")
30
+ parser.add_argument("--resolution", type=str, default="640x640",
31
+ choices=["640x640", "512x800", "800x512"], help="分辨率设置")
32
+ parser.add_argument("--seed", type=int, default=0, help="随机种子")
33
+ parser.add_argument("--num_inference_steps", type=int, default=10, help="推理步数")
34
+ return parser.parse_args()
35
+
36
+ def save_image(image: Image.Image, path: str, format: str = "PNG") -> None:
37
+ """安全保存图像并处理异常"""
38
+ try:
39
+ image.save(path, format=format)
40
+ logging.info(f"成功保存图像至: {path}")
41
+ except Exception as e:
42
+ logging.error(f"保存图像失败: {str(e)}")
43
+ raise
44
+
45
+ def main():
46
+ args = parse_args()
47
+ os.makedirs(args.output_dir, exist_ok=True)
48
+
49
+ # 初始化模型
50
+ global cur_input_style, pipeline, MultiResNetModel
51
+ cur_input_style = None
52
+ load_ckpt(args.input_style)
53
+
54
+ # 预处理输入图像
55
+ input_img = Image.open(args.input_image).convert("RGB")
56
+ input_context, extracted_line, _ = extract_line_image(input_img, args.input_style, args.resolution)
57
+
58
+ # 执行颜色化并获取全部结果
59
+ high_res_img, up_img, raw_output, preprocessed_bw = colorize_image(
60
+ VAE_input=extracted_line,
61
+ input_context=input_context,
62
+ reference_images=args.reference_images,
63
+ resolution=args.resolution,
64
+ seed=args.seed,
65
+ input_style=args.input_style,
66
+ num_inference_steps=args.num_inference_steps
67
+ )
68
+
69
+ # 保存所有结果
70
+ save_image(high_res_img, os.path.join(args.output_dir, "colorized_result.png"))
71
+ save_image(up_img, os.path.join(args.output_dir, "upsampled_intermediate.png"))
72
+ save_image(raw_output, os.path.join(args.output_dir, "raw_generated_output.png"))
73
+ save_image(preprocessed_bw, os.path.join(args.output_dir, "preprocessed_bw.png"))
74
+
75
+ if __name__ == "__main__":
76
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
77
+ main()