VNAT
commited on
Commit
·
daba3f8
1
Parent(s):
a142643
Add batched_inference.py with some goodies
Browse files- batched_inference.py +180 -0
batched_inference.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.multiprocessing as multiprocessing
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from torch import autocast
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from torchvision.transforms import InterpolationMode
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
|
12 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
13 |
+
torch.backends.cudnn.allow_tf32 = True
|
14 |
+
torch.autograd.set_detect_anomaly(False)
|
15 |
+
torch.autograd.profiler.emit_nvtx(enabled=False)
|
16 |
+
torch.autograd.profiler.profile(enabled=False)
|
17 |
+
torch.backends.cudnn.benchmark = True
|
18 |
+
|
19 |
+
|
20 |
+
class ImageDataset(Dataset):
|
21 |
+
def __init__(self, image_folder_path, allowed_extensions):
|
22 |
+
self.allowed_extensions = allowed_extensions
|
23 |
+
self.all_image_paths, self.all_image_names, self.image_base_paths = self.get_image_paths(image_folder_path)
|
24 |
+
self.train_size = len(self.all_image_paths)
|
25 |
+
print(f"Number of images to be tagged: {self.train_size}")
|
26 |
+
self.thin_transform = transforms.Compose([
|
27 |
+
transforms.Resize(448, interpolation=InterpolationMode.BICUBIC),
|
28 |
+
transforms.CenterCrop(448),
|
29 |
+
transforms.ToTensor(),
|
30 |
+
# Normalize image
|
31 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
32 |
+
])
|
33 |
+
self.normal_transform = transforms.Compose([
|
34 |
+
transforms.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
|
35 |
+
transforms.ToTensor(),
|
36 |
+
# Normalize image
|
37 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
38 |
+
|
39 |
+
])
|
40 |
+
|
41 |
+
def get_image_paths(self, folder_path):
|
42 |
+
image_paths = []
|
43 |
+
image_file_names = []
|
44 |
+
image_base_paths = []
|
45 |
+
for root, dirs, files in os.walk(folder_path):
|
46 |
+
for file in files:
|
47 |
+
if file.lower().split(".")[-1] in self.allowed_extensions:
|
48 |
+
image_paths.append((os.path.abspath(os.path.join(root, file))))
|
49 |
+
image_file_names.append(file.split(".")[0])
|
50 |
+
image_base_paths.append(root)
|
51 |
+
return image_paths, image_file_names, image_base_paths
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return len(self.all_image_paths)
|
55 |
+
|
56 |
+
def __getitem__(self, index):
|
57 |
+
image = Image.open(self.all_image_paths[index]).convert("RGB")
|
58 |
+
ratio = image.height / image.width
|
59 |
+
if ratio > 2.0 or ratio < 0.5:
|
60 |
+
image = self.thin_transform(image)
|
61 |
+
else:
|
62 |
+
image = self.normal_transform(image)
|
63 |
+
|
64 |
+
return {
|
65 |
+
'image': image,
|
66 |
+
"image_name": self.all_image_names[index],
|
67 |
+
"image_root": self.image_base_paths[index]
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
def prepare_model(model_path: str):
|
72 |
+
model = torch.load(model_path)
|
73 |
+
model.to(memory_format=torch.channels_last)
|
74 |
+
model = model.eval()
|
75 |
+
return model
|
76 |
+
|
77 |
+
|
78 |
+
def train(tagging_is_running, model, dataloader, train_data, output_queue):
|
79 |
+
print('Begin tagging')
|
80 |
+
model.eval()
|
81 |
+
counter = 0
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
|
85 |
+
this_data = data['image'].to("cuda")
|
86 |
+
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
87 |
+
outputs = model(this_data)
|
88 |
+
|
89 |
+
probabilities = torch.nn.functional.sigmoid(outputs)
|
90 |
+
output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))
|
91 |
+
|
92 |
+
counter += 1
|
93 |
+
_ = tagging_is_running.get()
|
94 |
+
print("Tagging finished!")
|
95 |
+
|
96 |
+
|
97 |
+
def tag_writer(tagging_is_running, output_queue, threshold):
|
98 |
+
with open("tags_8034.json", "r") as file:
|
99 |
+
tags = json.load(file)
|
100 |
+
allowed_tags = sorted(tags)
|
101 |
+
del tags
|
102 |
+
allowed_tags.extend(["placeholder0"])
|
103 |
+
tag_count = len(allowed_tags)
|
104 |
+
assert tag_count == 8035, f"The length of tag list is not correct. Correct: 8035, current: {tag_count}"
|
105 |
+
|
106 |
+
while not (tagging_is_running.qsize() > 0 and output_queue.qsize() > 0):
|
107 |
+
tag_probabilities, image_names, image_roots = output_queue.get()
|
108 |
+
tag_probabilities = tag_probabilities.tolist()
|
109 |
+
|
110 |
+
for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
|
111 |
+
strict=True):
|
112 |
+
this_image_tags = []
|
113 |
+
this_image_tag_probabilities = []
|
114 |
+
for index, per_tag_probability in enumerate(per_image_tag_probabilities):
|
115 |
+
if per_tag_probability > threshold:
|
116 |
+
tag = allowed_tags[index]
|
117 |
+
if "placeholder" not in tag:
|
118 |
+
this_image_tags.append(tag)
|
119 |
+
this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
|
120 |
+
output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt")
|
121 |
+
with open(output_file, "w", encoding="utf-8") as this_output:
|
122 |
+
# set this to true if you want tags separated with commas instead of spaces (will output "tag0, tag1...")
|
123 |
+
use_comma_sep = True
|
124 |
+
sep = " "
|
125 |
+
if use_comma_sep:
|
126 |
+
sep = ", "
|
127 |
+
# set this to true if you want to replace underscores with spaces
|
128 |
+
remove_underscores = True
|
129 |
+
if remove_underscores:
|
130 |
+
this_image_tags = map(lambda e: e.replace('_', ' '), this_image_tags)
|
131 |
+
this_output.write(sep.join(this_image_tags))
|
132 |
+
# change output_probabilities to True if you want probabilities
|
133 |
+
output_probabilities = False
|
134 |
+
if output_probabilities:
|
135 |
+
this_output.write("\n")
|
136 |
+
this_output.write(sep.join(this_image_tag_probabilities))
|
137 |
+
|
138 |
+
|
139 |
+
def main():
|
140 |
+
image_folder_path = "/path/to/img/folder"
|
141 |
+
# all images should be in this folder and/or its subfolders.
|
142 |
+
# I will generate a text file for every image.
|
143 |
+
model_path = "/path/to/your/model.pth"
|
144 |
+
allowed_extensions = {"jpg", "jpeg", "png", "webp"}
|
145 |
+
batch_size = 64
|
146 |
+
# if you have a 24GB card, you can try 256
|
147 |
+
threshold = 0.3
|
148 |
+
|
149 |
+
multiprocessing.set_start_method('spawn')
|
150 |
+
output_queue = multiprocessing.Queue()
|
151 |
+
tagging_is_running = multiprocessing.Queue(maxsize=5)
|
152 |
+
tagging_is_running.put("Running!")
|
153 |
+
|
154 |
+
if not torch.cuda.is_available():
|
155 |
+
raise RuntimeError("CUDA is not available!")
|
156 |
+
|
157 |
+
model = prepare_model(model_path).to("cuda")
|
158 |
+
|
159 |
+
dataset = ImageDataset(image_folder_path, allowed_extensions)
|
160 |
+
|
161 |
+
batched_loader = DataLoader(
|
162 |
+
dataset,
|
163 |
+
batch_size=batch_size,
|
164 |
+
shuffle=False,
|
165 |
+
num_workers=12, # if you have a big batch size, a good cpu, and enough cpu memory, try 12
|
166 |
+
pin_memory=True,
|
167 |
+
drop_last=False,
|
168 |
+
)
|
169 |
+
process_writer = multiprocessing.Process(target=tag_writer,
|
170 |
+
args=(tagging_is_running, output_queue, threshold))
|
171 |
+
process_writer.start()
|
172 |
+
process_tagger = multiprocessing.Process(target=train,
|
173 |
+
args=(tagging_is_running, model, batched_loader, dataset, output_queue,))
|
174 |
+
process_tagger.start()
|
175 |
+
process_writer.join()
|
176 |
+
process_tagger.join()
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
main()
|