svjack committed on
Commit
48b0782
·
verified ·
1 Parent(s): 13f60c5

Update run_caption.py

Browse files
Files changed (1) hide show
  1. run_caption.py +36 -18
run_caption.py CHANGED
@@ -5,12 +5,15 @@ from torch import nn
5
  from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
6
  from PIL import Image
7
  import shutil
 
8
 
 
9
  CLIP_PATH = "google/siglip-so400m-patch14-384"
10
  VLM_PROMPT = "A descriptive caption for this image:\n"
11
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
12
  CHECKPOINT_PATH = Path("wpkklhc6")
13
 
 
14
  class ImageAdapter(nn.Module):
15
  def __init__(self, input_features: int, output_features: int):
16
  super().__init__()
@@ -24,6 +27,7 @@ class ImageAdapter(nn.Module):
24
  x = self.linear2(x)
25
  return x
26
 
 
27
  def load_models():
28
  print("Loading CLIP")
29
  clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
@@ -49,6 +53,7 @@ def load_models():
49
 
50
  return clip_processor, clip_model, tokenizer, text_model, image_adapter
51
 
 
52
  @torch.no_grad()
53
  def generate_caption(input_image: Image.Image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
54
  torch.cuda.empty_cache()
@@ -97,37 +102,50 @@ def generate_caption(input_image: Image.Image, clip_processor, clip_model, token
97
 
98
  return caption.strip()
99
 
 
100
  def main():
101
  parser = argparse.ArgumentParser(description="Generate a caption for an image and save it to a text file.")
102
- parser.add_argument("input_image", type=str, help="Path to the input image")
103
- parser.add_argument("output_path", type=str, help="Path to save the output image and caption text file")
104
  args = parser.parse_args()
105
 
106
  # Load models
107
  clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
108
 
109
- # Open the input image
110
- input_image = Image.open(args.input_image)
 
 
 
 
111
 
112
- # Generate caption
113
- caption = generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter)
114
-
115
- # Process output path
116
  output_path = Path(args.output_path)
117
  output_path.mkdir(parents=True, exist_ok=True)
118
 
119
- # Copy image to output path
120
- image_name = Path(args.input_image).name
121
- image_name = image_name.replace(" ", "_") # Replace spaces with underscores
122
- output_image_path = output_path / image_name
123
- shutil.copy(args.input_image, output_image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- # Save caption to txt file
126
- txt_file_path = output_path / f"{output_image_path.stem}.txt"
127
- with open(txt_file_path, "w") as f:
128
- f.write(caption)
129
 
130
- print(f"Caption saved to {txt_file_path}")
 
131
 
132
  if __name__ == "__main__":
133
  main()
 
5
  from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
6
  from PIL import Image
7
  import shutil
8
+ from tqdm import tqdm  # tqdm provides the progress bar for batch processing
9
 
10
+ # Constants
11
  CLIP_PATH = "google/siglip-so400m-patch14-384"
12
  VLM_PROMPT = "A descriptive caption for this image:\n"
13
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
14
  CHECKPOINT_PATH = Path("wpkklhc6")
15
 
16
+ # Image Adapter
17
  class ImageAdapter(nn.Module):
18
  def __init__(self, input_features: int, output_features: int):
19
  super().__init__()
 
27
  x = self.linear2(x)
28
  return x
29
 
30
+ # Load models
31
  def load_models():
32
  print("Loading CLIP")
33
  clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
 
53
 
54
  return clip_processor, clip_model, tokenizer, text_model, image_adapter
55
 
56
+ # Generate caption
57
  @torch.no_grad()
58
  def generate_caption(input_image: Image.Image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
59
  torch.cuda.empty_cache()
 
102
 
103
  return caption.strip()
104
 
105
# Main function
def collect_image_paths(input_path: Path) -> list[Path]:
    """Return the list of image files to caption.

    If *input_path* is a directory, return its direct children whose
    extension is .png/.jpg/.jpeg in any casing (the previous character-class
    globs such as "*.[pjP][npP][gG]" silently missed common casings like
    ".JPG" and ".PNG"), sorted for a deterministic processing order.
    Otherwise treat *input_path* as a single image file.
    """
    if input_path.is_dir():
        allowed = {".png", ".jpg", ".jpeg"}
        return sorted(
            p for p in input_path.iterdir()
            if p.is_file() and p.suffix.lower() in allowed
        )
    return [input_path]


def main():
    """Caption one image or every image in a directory.

    For each image: generate a caption, copy the image into the output
    directory (spaces in the file name replaced with underscores), and
    write the caption next to it as <stem>.txt.
    """
    parser = argparse.ArgumentParser(description="Generate a caption for an image and save it to a text file.")
    parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
    parser.add_argument("output_path", type=str, help="Path to save the output images and captions")
    args = parser.parse_args()

    # Load models once; they are reused across every image in the batch.
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Resolve the set of images to process (single file or directory).
    image_paths = collect_image_paths(Path(args.input_path))

    # Create output directory if it doesn't exist.
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Process each image; failures are reported but do not abort the batch.
    for image_path in tqdm(image_paths, desc="Processing images"):
        try:
            # Context manager ensures the image file handle is released
            # even if caption generation raises.
            with Image.open(image_path) as input_image:
                caption = generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter)

            # Copy image to output path; spaces break many downstream tools.
            image_name = image_path.name.replace(" ", "_")
            output_image_path = output_path / image_name
            shutil.copy(image_path, output_image_path)

            # Save caption to a .txt file alongside the copied image.
            txt_file_path = output_path / f"{output_image_path.stem}.txt"
            with open(txt_file_path, "w") as f:
                f.write(caption)

            print(f"Caption saved to {txt_file_path}")

        except Exception as e:
            # Best-effort batch mode: report the failure and continue.
            print(f"Error processing {image_path}: {e}")

if __name__ == "__main__":
    main()