'''
Caption every image in a Hugging Face dataset with a SigLIP vision encoder, a trained
image adapter, and Meta-Llama-3.1-8B, then save the captioned dataset to disk.

Setup (download the base LLM locally):
    git clone https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B

Usage:
    python run_caption_ds.py "svjack/Genshin-Impact-Couple-with-Tags-IID-Gender-Only-Two" --caption_column="joy-caption" --output_path="gen_couple_cap_dir"
'''

import argparse
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from datasets import load_dataset  # Hugging Face Datasets
from tqdm import tqdm  # progress bars

# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
#MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
MODEL_PATH = "Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path("wpkklhc6")
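# NOTE: as written, the script expects a local "Meta-Llama-3.1-8B" directory (see the
# git clone command in the module docstring) and a "wpkklhc6" checkpoint directory
# containing "image_adapter.pt" (the trained image-adapter weights) in the working directory.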

# Image adapter: a two-layer MLP (Linear -> GELU -> Linear) that projects
# vision-tower features into the LLM's token-embedding space
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
    
    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x
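
# Shape sketch (assumed sizes: SigLIP so400m hidden size 1152, Llama-3.1-8B hidden
# size 4096, 729 patch tokens for a 384x384 input):
#     adapter = ImageAdapter(1152, 4096)
#     adapter(torch.zeros(1, 729, 1152)).shape  # -> torch.Size([1, 729, 4096])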

# Load models
def load_models():
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter
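
# Device placement: SigLIP and the image adapter are moved to "cuda" explicitly,
# while the LLM is placed/sharded by device_map="auto"; everything is set to eval
# mode and used for inference only.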

# Generate caption
@torch.no_grad()
def generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
    torch.cuda.empty_cache()

    # Preprocess image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values
    image = image.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image
    with torch.amp.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')
    
    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))

    # Build the full input sequence: [BOS embedding] + [projected image tokens] + [prompt embeddings]
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)
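    # NOTE: input_ids mirrors the embedded sequence (BOS, one placeholder id per image
    # token, then the prompt) so that the attention mask and the prompt trimming below
    # line up; the content the model actually attends to comes from inputs_embeds.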

    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)

    # Trim off the prompt
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()
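
# Quick interactive sketch (hypothetical image path, outside of main()):
#     from PIL import Image
#     models = load_models()
#     print(generate_caption(Image.open("example.jpg"), *models))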

# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate captions for images in a Hugging Face Dataset.")
    parser.add_argument("dataset_name", type=str, help="Name of the Hugging Face Dataset")
    parser.add_argument("--image_column", type=str, default="image", help="Name of the column containing images (default: 'image')")
    parser.add_argument("--caption_column", type=str, default="caption", help="Name of the column to save captions (default: 'caption')")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the dataset with captions")
    args = parser.parse_args()

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Load dataset
    print(f"Loading dataset: {args.dataset_name}")
    dataset = load_dataset(args.dataset_name)
    len_ = len(dataset["train"])
    # len_ = 10  # uncomment to smoke-test on the first 10 images only
    
    # Initialize a list to store captions
    captions = []

    # Generate captions for each image in the dataset
    print("Generating captions...")
    for idx, example in enumerate(tqdm(dataset["train"].select(range(len_)), desc="Processing images")):  # assumes the dataset has a "train" split
        try:
            # Generate caption
            caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
            captions.append(caption)
            # Print the generated caption
            print(f"Caption for image {idx + 1}: {caption}")
        except Exception as e:
            print(f"Error processing image {idx + 1}: {e}")
            captions.append("")  # store an empty string if captioning fails
            print(f"Caption for image {idx + 1}: [Error]")

    # Add captions to the dataset
    print("Adding captions to the dataset...")
    dataset = dataset["train"].select(range(len_)).add_column(args.caption_column, captions)  # attach the captions as a new column

    # Save the dataset with captions
    print(f"Saving dataset to {args.output_path}")
    dataset.save_to_disk(args.output_path)

    print("Done!")

if __name__ == "__main__":
    main()
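
# To reload the captioned dataset later (sketch; path and column name taken from the
# usage example in the module docstring):
#     from datasets import load_from_disk
#     ds = load_from_disk("gen_couple_cap_dir")
#     print(ds[0]["joy-caption"])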