import argparse
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from PIL import Image
import shutil
from tqdm import tqdm  # tqdm provides the progress bar shown during batch processing

# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
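# Directory holding the trained image-adapter weights (expects image_adapter.pt inside)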
CHECKPOINT_PATH = Path("wpkklhc6")

# Image Adapter
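# A small two-layer MLP (Linear -> GELU -> Linear) that projects SigLIP vision
# features into the LLM's token-embedding space.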
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
    
    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x

# Load models
def load_models():
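    """Load the SigLIP vision tower, the Llama tokenizer and LLM, and the trained image adapter."""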
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter

# Generate caption
@torch.no_grad()
def generate_caption(input_image: Image.Image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
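    """Caption one image: encode with SigLIP, project into the LLM's embedding space, and sample text."""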
    torch.cuda.empty_cache()

    # Preprocess image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values
    image = image.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image
    with torch.amp.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
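        # Take the penultimate hidden layer (-2) as the image features,
        # matching what the adapter checkpoint expects.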
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')
    
    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))

    # Construct the full input sequence: <bos> embedding, projected image embeddings, then the text prompt
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

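    # input_ids mirrors inputs_embeds position-for-position so the prompt can be
    # trimmed from the generated sequence; zeros are placeholders for the
    # image-embedding slots.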
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

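    # Sample with top-k 10 and temperature 0.5 for varied but stable captions.
    # inputs_embeds drives the first forward pass, while input_ids lets
    # generate() return a sequence we can slice below (exact behavior may
    # differ across transformers versions).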
    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)

    # Trim off the prompt
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()

# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate a caption for an image and save it to a text file.")
    parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
    parser.add_argument("output_path", type=str, help="Path to save the output images and captions")
    args = parser.parse_args()

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Determine if input is a directory or a single file
    input_path = Path(args.input_path)
    if input_path.is_dir():
        # Collect PNG and JPEG images, matching extensions case-insensitively
        image_extensions = {".png", ".jpg", ".jpeg"}
        image_paths = [p for p in input_path.iterdir() if p.suffix.lower() in image_extensions]
    else:
        image_paths = [input_path]

    # Create output directory if it doesn't exist
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Process each image
    for image_path in tqdm(image_paths, desc="Processing images"):
        try:
            # Open the input image
            input_image = Image.open(image_path).convert("RGB")  # Normalize mode so palette/alpha images don't trip the processor

            # Generate caption
            caption = generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter)

            # Copy image to output path
            image_name = image_path.name.replace(" ", "_")  # Replace spaces with underscores
            output_image_path = output_path / image_name
            shutil.copy(image_path, output_image_path)

            # Save caption to txt file
            txt_file_path = output_path / f"{output_image_path.stem}.txt"
            with open(txt_file_path, "w", encoding="utf-8") as f:
                f.write(caption)

            print(f"Caption saved to {txt_file_path}")

        except Exception as e:
            print(f"Error processing {image_path}: {e}")

if __name__ == "__main__":
    main()
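
# Example usage (the script name "caption.py" is illustrative; paths are placeholders):
#   python caption.py ./images ./captions_out      # caption every image in a directory
#   python caption.py photo.jpg ./captions_out     # caption a single image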