svjack commited on
Commit
6590a5e
·
verified ·
1 Parent(s): 2f0271b

Update run_caption_ds.py

Browse files
Files changed (1) hide show
  1. run_caption_ds.py +16 -9
run_caption_ds.py CHANGED
@@ -123,22 +123,29 @@ def main():
123
  # Load dataset
124
  print(f"Loading dataset: {args.dataset_name}")
125
  dataset = load_dataset(args.dataset_name)
 
 
 
 
 
126
 
127
  # Generate captions for each image in the dataset
128
- def add_caption(example):
 
129
  try:
130
  # Generate caption
131
  caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
132
- # Add caption to the example
133
- example[args.caption_column] = caption
 
134
  except Exception as e:
135
- print(f"Error processing image: {e}")
136
- example[args.caption_column] = "" # 如果出错,保存空字符串
137
- return example
138
 
139
- # Apply the function to the dataset
140
- print("Generating captions...")
141
- dataset = dataset.map(add_caption, desc="Generating captions")
142
 
143
  # Save the dataset with captions
144
  print(f"Saving dataset to {args.output_path}")
 
123
  # Load dataset
124
  print(f"Loading dataset: {args.dataset_name}")
125
  dataset = load_dataset(args.dataset_name)
126
+ len_ = len(dataset["train"])
127
+ #len_ = 10
128
+
129
+ # Initialize a list to store captions
130
+ captions = []
131
 
132
  # Generate captions for each image in the dataset
133
+ print("Generating captions...")
134
+ for idx, example in enumerate(tqdm(dataset["train"].select(range(len_)), desc="Processing images")): # 假设数据集是 "train" 拆分
135
  try:
136
  # Generate caption
137
  caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
138
+ captions.append(caption)
139
+ # Print the generated caption
140
+ print(f"Caption for image {idx + 1}: {caption}")
141
  except Exception as e:
142
+ print(f"Error processing image {idx + 1}: {e}")
143
+ captions.append("") # 如果出错,保存空字符串
144
+ print(f"Caption for image {idx + 1}: [Error]")
145
 
146
+ # Add captions to the dataset
147
+ print("Adding captions to the dataset...")
148
+ dataset = dataset["train"].select(range(len_)).add_column(args.caption_column, captions) # captions 添加到数据集
149
 
150
  # Save the dataset with captions
151
  print(f"Saving dataset to {args.output_path}")