Spaces:
Runtime error
Runtime error
Update run_caption_ds.py
Browse files- run_caption_ds.py +16 -9
run_caption_ds.py
CHANGED
@@ -123,22 +123,29 @@ def main():
|
|
123 |
# Load dataset
|
124 |
print(f"Loading dataset: {args.dataset_name}")
|
125 |
dataset = load_dataset(args.dataset_name)
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
# Generate captions for each image in the dataset
|
128 |
-
|
|
|
129 |
try:
|
130 |
# Generate caption
|
131 |
caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
|
132 |
-
|
133 |
-
|
|
|
134 |
except Exception as e:
|
135 |
-
print(f"Error processing image: {e}")
|
136 |
-
|
137 |
-
|
138 |
|
139 |
-
#
|
140 |
-
print("
|
141 |
-
dataset = dataset.
|
142 |
|
143 |
# Save the dataset with captions
|
144 |
print(f"Saving dataset to {args.output_path}")
|
|
|
123 |
# Load dataset
|
124 |
print(f"Loading dataset: {args.dataset_name}")
|
125 |
dataset = load_dataset(args.dataset_name)
|
126 |
+
len_ = len(dataset["train"])
|
127 |
+
#len_ = 10
|
128 |
+
|
129 |
+
# Initialize a list to store captions
|
130 |
+
captions = []
|
131 |
|
132 |
# Generate captions for each image in the dataset
|
133 |
+
print("Generating captions...")
|
134 |
+
for idx, example in enumerate(tqdm(dataset["train"].select(range(len_)), desc="Processing images")): # assumes the dataset has a "train" split
|
135 |
try:
|
136 |
# Generate caption
|
137 |
caption = generate_caption(example[args.image_column], clip_processor, clip_model, tokenizer, text_model, image_adapter)
|
138 |
+
captions.append(caption)
|
139 |
+
# Print the generated caption
|
140 |
+
print(f"Caption for image {idx + 1}: {caption}")
|
141 |
except Exception as e:
|
142 |
+
print(f"Error processing image {idx + 1}: {e}")
|
143 |
+
captions.append("") # on error, store an empty string
|
144 |
+
print(f"Caption for image {idx + 1}: [Error]")
|
145 |
|
146 |
+
# Add captions to the dataset
|
147 |
+
print("Adding captions to the dataset...")
|
148 |
+
dataset = dataset["train"].select(range(len_)).add_column(args.caption_column, captions) # add the captions to the dataset
|
149 |
|
150 |
# Save the dataset with captions
|
151 |
print(f"Saving dataset to {args.output_path}")
|