# joy-caption-pre-alpha / run_caption.py
import argparse
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from PIL import Image
import shutil
from tqdm import tqdm  # tqdm for a progress bar over the input images
# Constants
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path("wpkklhc6")
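
# NOTE: CHECKPOINT_PATH must contain the trained adapter weights as
# "image_adapter.pt" (loaded in load_models below); "wpkklhc6" is assumed to
# be the checkpoint directory shipped with the joy-caption pre-alpha release.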

# Image Adapter
class ImageAdapter(nn.Module):
    """Two-layer MLP that projects SigLIP patch features into the LLM's
    token-embedding space, so image patches can be spliced into the prompt
    as if they were ordinary token embeddings."""

    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x

# Load models
def load_models():
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cuda")

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    print("Loading LLM")
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
    text_model.eval()

    print("Loading image adapter")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
    image_adapter.eval()
    image_adapter.to("cuda")

    return clip_processor, clip_model, tokenizer, text_model, image_adapter
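
# Rough footprint (an estimate, not measured): Llama-3.1-8B in bfloat16 is
# about 16 GB plus the SigLIP tower and adapter, so a 24 GB GPU should fit
# comfortably; device_map="auto" will shard the LLM across GPUs if needed.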

# Generate caption
@torch.no_grad()
def generate_caption(input_image: Image.Image, clip_processor, clip_model, tokenizer, text_model, image_adapter):
    torch.cuda.empty_cache()

    # Preprocess image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values
    image = image.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image: the adapter projects the penultimate SigLIP hidden states
    # into the LLM's embedding space
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')

    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))

    # Construct the input sequence: <BOS> [image embeddings] [prompt embeddings]
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    # Matching input_ids; the image positions are zero-filled placeholders
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)

    # Trim off the prompt, and a trailing EOS token if present
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
    return caption.strip()
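
# Note: generation is sampled (do_sample=True, top_k=10, temperature=0.5), so
# repeated runs can produce different captions for the same image; setting
# do_sample=False in the generate() call above gives deterministic (greedy) output.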

# Main function
def main():
    parser = argparse.ArgumentParser(description="Generate a caption for an image and save it to a text file.")
    parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images")
    parser.add_argument("output_path", type=str, help="Path to save the output images and captions")
    args = parser.parse_args()

    # Load models
    clip_processor, clip_model, tokenizer, text_model, image_adapter = load_models()

    # Determine if input is a directory or a single file
    input_path = Path(args.input_path)
    if input_path.is_dir():
        # Collect PNG and JPEG files, matching extensions case-insensitively
        image_paths = sorted(p for p in input_path.iterdir()
                             if p.is_file() and p.suffix.lower() in {".png", ".jpg", ".jpeg"})
    else:
        image_paths = [input_path]

    # Create output directory if it doesn't exist
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Process each image
    for image_path in tqdm(image_paths, desc="Processing images"):
        try:
            # Open the input image and normalize to RGB so RGBA/greyscale
            # files do not break the vision processor
            input_image = Image.open(image_path).convert("RGB")

            # Generate caption
            caption = generate_caption(input_image, clip_processor, clip_model, tokenizer, text_model, image_adapter)

            # Copy the image to the output path, replacing spaces in the name with underscores
            image_name = image_path.name.replace(" ", "_")
            output_image_path = output_path / image_name
            shutil.copy(image_path, output_image_path)

            # Save the caption to a .txt file with the same stem as the image
            txt_file_path = output_path / f"{output_image_path.stem}.txt"
            with open(txt_file_path, "w") as f:
                f.write(caption)
            print(f"Caption saved to {txt_file_path}")
        except Exception as e:
            print(f"Error processing {image_path}: {e}")


if __name__ == "__main__":
    main()
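
# Example usage (assumes a CUDA GPU and access to the gated
# meta-llama/Meta-Llama-3.1-8B weights on Hugging Face):
#   python run_caption.py ./images ./captions
# Each image is copied into ./captions alongside a matching <stem>.txt caption.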