'''
git clone https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B
python run_caption_ds.py "svjack/Genshin-Impact-Couple-with-Tags-IID-Gender-Only-Two" --caption_column="joy-caption" --output_path="gen_couple_cap_dir"
'''
import argparse
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from datasets import load_dataset  # Hugging Face Datasets
from tqdm import tqdm  # tqdm for progress bars
# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
#MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
MODEL_PATH = "Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path("wpkklhc6")
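# MODEL_PATH points at a local clone of Meta-Llama-3.1-8B (see the usage notes in the
# module docstring); CHECKPOINT_PATH is expected to contain the trained adapter weights
# (image_adapter.pt) loaded in load_models() below.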
# Image Adapter
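# The adapter is a small two-layer MLP (Linear -> GELU -> Linear) that projects
# SigLIP vision hidden states into the Llama hidden/embedding space, so the image
# tokens can be concatenated with text token embeddings in generate_caption().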
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x
# Load models
def load_models():
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter
# Generate caption
@torch.no_grad()
def generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
    torch.cuda.empty_cache()

    # Preprocess image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values
    image = image.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image
    with torch.amp.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')

    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))

    # Construct the input sequence: <BOS> token, image embeddings, then the text prompt
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)

    # Trim off the prompt
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
    return caption.strip()
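
# Example (illustrative only, not part of the CLI flow below): captioning a single
# local image with the loaded components. Assumes Pillow is installed and that
# "example.jpg" is a placeholder path.
#
#   from PIL import Image
#   clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()
#   print(generate_caption(Image.open("example.jpg"), clip_processor, clip_model,
#                          tokenizer, text_model, image_adapter))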
# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate captions for images in a Hugging Face Dataset.")
    parser.add_argument("dataset_name", type=str, help="Name of the Hugging Face Dataset")
    parser.add_argument("--image_column", type=str, default="image", help="Name of the column containing images (default: 'image')")
    parser.add_argument("--caption_column", type=str, default="caption", help="Name of the column to save captions (default: 'caption')")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the dataset with captions")
    args = parser.parse_args()

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Load dataset
    print(f"Loading dataset: {args.dataset_name}")
    dataset = load_dataset(args.dataset_name)
    len_ = len(dataset["train"])
    #len_ = 10

    # Initialize a list to store captions
    captions = []

    # Generate captions for each image in the dataset
    print("Generating captions...")
    for idx, example in enumerate(tqdm(dataset["train"].select(range(len_)), desc="Processing images")):  # assumes the dataset has a "train" split
        try:
            # Generate caption
            caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
            captions.append(caption)

            # Print the generated caption
            print(f"Caption for image {idx + 1}: {caption}")
        except Exception as e:
            print(f"Error processing image {idx + 1}: {e}")
            captions.append("")  # on error, store an empty string
            print(f"Caption for image {idx + 1}: [Error]")

    # Add captions to the dataset
    print("Adding captions to the dataset...")
    dataset = dataset["train"].select(range(len_)).add_column(args.caption_column, captions)  # attach the captions as a new column

    # Save the dataset with captions
    print(f"Saving dataset to {args.output_path}")
    dataset.save_to_disk(args.output_path)
    print("Done!")
if __name__ == "__main__":
    main()