import argparse
import shutil
from pathlib import Path

import torch
from torch import nn
from PIL import Image
from tqdm import tqdm  # progress bar for batch processing
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path("wpkklhc6")


# Image Adapter: projects SigLIP vision features into the LLM's embedding space
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x


# Load models
def load_models():
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter


# Generate caption
@torch.no_grad()
def generate_caption(input_image: Image.Image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
    torch.cuda.empty_cache()

    # Preprocess image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values
    image = image.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')

    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))

    # Construct the input sequence: <BOS> <image embeddings> <prompt>
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                       max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)

    # Trim off the prompt
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()


# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate captions for images and save them to text files.")
    parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
    parser.add_argument("output_path", type=str, help="Path to save the output images and captions")
    args = parser.parse_args()

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Determine if input is a directory or a single file
    input_path = Path(args.input_path)
    if input_path.is_dir():
        # Collect PNG/JPG/JPEG images, matching extensions case-insensitively
        image_paths = [p for p in sorted(input_path.iterdir()) if p.suffix.lower() in {".png", ".jpg", ".jpeg"}]
    else:
        image_paths = [input_path]

    # Create output directory if it doesn't exist
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Process each image
    for image_path in tqdm(image_paths, desc="Processing images"):
        try:
            # Open the input image
            input_image = Image.open(image_path)

            # Generate caption
            caption = generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter)

            # Copy image to output path
            image_name = image_path.name.replace(" ", "_")  # Replace spaces with underscores
            output_image_path = output_path / image_name
            shutil.copy(image_path, output_image_path)

            # Save caption to txt file
            txt_file_path = output_path / f"{output_image_path.stem}.txt"
            with open(txt_file_path, "w") as f:
                f.write(caption)

            print(f"Caption saved to {txt_file_path}")
        except Exception as e:
            print(f"Error processing {image_path}: {e}")


if __name__ == "__main__":
    main()
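
# Example usage (a sketch; the file name "batch_caption.py" is assumed, and the "wpkklhc6"
# adapter checkpoint directory is expected to sit in the working directory):
#
#   python batch_caption.py ./my_images ./captioned
#
# For every .png/.jpg/.jpeg file in ./my_images, this copies the image into ./captioned and
# writes a matching <image name>.txt file containing the generated caption.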