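# Overview of the pipeline implemented below (derived from the code in this file):
#   1. Read Open Images URLs, download each image, and center-crop it to a square.
#   2. Caption the image with BLIP-2.
#   3. Prompt MPT-7B-Instruct to extract foreground object nouns from the caption.
#   4. Segment each extracted tag with CLIPSeg, keep the largest connected region,
#      and cut a square object crop plus a background-whitened masked crop around it.
#   5. Write caption, tags, and base64-encoded PNGs to sharded TSV files.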
import base64
import io
import multiprocessing
import os
import random
from argparse import ArgumentParser
from multiprocessing import Process

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import label, find_objects, grey_dilation
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Blip2Processor, Blip2ForConditionalGeneration, \
    CLIPSegProcessor, CLIPSegForImageSegmentation

Image.MAX_IMAGE_PIXELS = 1000000000
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "### Input:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
Extract all objects mentioned in the caption and separate them using commas. Exclude background elements (site, location, environment) and only include foreground objects. Ensure that only nouns are included and exclude adjectives entirely.
{input_key}
{input}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
)
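
# save_tsv runs in one worker process pinned to a single GPU. For every image in its
# shard it chains BLIP-2 captioning -> MPT tag extraction -> CLIPSeg segmentation and
# writes one TSV row per image, rolling over to a new output file every 1000 images.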
def save_tsv(args, shard_id, shard, device):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.set_device(device)
    model_dtype = torch.float16

    # blip2 (processor and weights loaded from the same checkpoint)
    blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=model_dtype)
    blip2_model.eval().to(device)
    # mpt
    mpt_config = AutoConfig.from_pretrained('mosaicml/mpt-7b-instruct', trust_remote_code=True)
    mpt_config.init_device = device
    mpt_config.attn_config['attn_impl'] = args.attn_impl
    mpt_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
    mpt_tokenizer.pad_token = mpt_tokenizer.eos_token
    mpt_tokenizer.padding_side = 'left'
    mpt_model = AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-instruct', config=mpt_config,
                                                     torch_dtype=model_dtype, trust_remote_code=True)
    mpt_model.eval()
    mpt_generate_kwargs = {
        'max_new_tokens': args.max_new_tokens,
        'temperature': args.temperature,
        'top_p': args.top_p,
        'top_k': args.top_k,
        'repetition_penalty': args.repetition_penalty,
        'no_repeat_ngram_size': args.no_repeat_ngram_size,
        'use_cache': args.use_cache,
        'do_sample': False if args.temperature == 0 else args.do_sample,
        'eos_token_id': mpt_tokenizer.eos_token_id,
        'pad_token_id': mpt_tokenizer.pad_token_id,
    }
    # clipseg
    clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined", torch_dtype=model_dtype)
    clipseg_model.eval().to(device)
    cnt = 0
    f = None
    for image in tqdm(shard):
        if image is None:
            continue
        if cnt % 1000 == 0:
            # close previous file if any
            if cnt > 0:
                f.close()
            f = open(os.path.join(args.output_dir, f"cnt_{args.machine_id}_{shard_id}_{cnt // 1000}.tsv"), "w",
                     encoding='utf-8')
        cnt += 1
        blip2_input = blip2_processor(images=image, return_tensors="pt").to(device, model_dtype)
        blip2_gen = blip2_model.generate(**blip2_input)
        caption = blip2_processor.batch_decode(blip2_gen, skip_special_tokens=True)[0] \
            .replace('\t', '').replace('\n', '').strip()
        # tag extraction
        prompt = PROMPT_FOR_GENERATION_FORMAT.format(input=caption)
        # Run HF generate
        mpt_input = mpt_tokenizer(prompt, return_tensors='pt', padding=True)
        for key, value in mpt_input.items():
            mpt_input[key] = value.to(device)
        mpt_gen = mpt_model.generate(
            input_ids=mpt_input['input_ids'],
            attention_mask=mpt_input['attention_mask'],
            **mpt_generate_kwargs,
        )
        tags = mpt_tokenizer.batch_decode(mpt_gen, skip_special_tokens=True)[0][len(prompt):]
        if '#' in tags:
            continue
        tags = tags.split(",")
        tags = [tag.replace('\t', '').replace('\n', '').strip() for tag in tags]
        tags = [tag for tag in tags if len(tag) > 0 and tag in caption]
        if len(tags) == 0:
            continue
        clipseg_input = clipseg_processor(text=tags, images=[image] * len(tags), padding=True, return_tensors="pt")
        for key, value in clipseg_input.items():
            clipseg_input[key] = value.to(device)
            if value.dtype == torch.float32:
                clipseg_input[key] = value.to(device, model_dtype)
        # predict
        clipseg_gen = clipseg_model(**clipseg_input).logits
        if len(tags) == 1:
            clipseg_gen = clipseg_gen.unsqueeze(0)
        image_size = image.height
        # interpolate to original size
        clipseg_gen = F.interpolate(clipseg_gen.unsqueeze(1), size=image_size, mode='bilinear')
        masks = torch.sigmoid(clipseg_gen).squeeze(1)
        masks = masks.cpu().numpy()
        sub_images = []
        tags_to_keep = []
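        # For each tag: threshold and dilate its CLIPSeg probability map, keep only the
        # largest connected component, and cut a square window (side length = the larger
        # of the component's height and width) centred on it, clamped to the image bounds.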
        # save the masked image
        for mask_id, mask in enumerate(masks):
            image_array = np.array(image)
            thresholded_mask = mask > args.threshold
            if thresholded_mask.max() == 0:
                continue
            thresholded_mask = grey_dilation(thresholded_mask, size=(image_size // 100, image_size // 100))
            labeled_matrix, num_features = label(thresholded_mask)
            regions = find_objects(labeled_matrix)
            sizes = [np.sum(thresholded_mask[region]) for region in regions]
            max_index = np.argmax(sizes)
            max_region = regions[max_index]
            thresholded_mask[labeled_matrix != (max_index + 1)] = False
            tags_to_keep.append(tags[mask_id])
            # Determine the dimensions of the region
            y_start, y_stop = max_region[0].start, max_region[0].stop
            x_start, x_stop = max_region[1].start, max_region[1].stop
            height = y_stop - y_start
            width = x_stop - x_start
            # Calculate the desired side length for a square region
            side_length = max(height, width)
            # Calculate the center of the region
            center_y = (y_start + y_stop) // 2
            center_x = (x_start + x_stop) // 2
            # Calculate the new boundaries for the region
            new_y_start = center_y - (side_length // 2)
            new_y_stop = new_y_start + side_length
            new_x_start = center_x - (side_length // 2)
            new_x_stop = new_x_start + side_length
            # Adjust the boundaries if they exceed the image boundaries
            if new_y_start < 0:
                new_y_start = 0
                new_y_stop = side_length
            elif new_y_stop > image_array.shape[0]:
                new_y_start = image_array.shape[0] - side_length
                new_y_stop = image_array.shape[0]
            if new_x_start < 0:
                new_x_start = 0
                new_x_stop = side_length
            elif new_x_stop > image_array.shape[1]:
                new_x_start = image_array.shape[1] - side_length
                new_x_stop = image_array.shape[1]
            # Create a new mask with the adjusted boundaries
            object_image = image_array[new_y_start:new_y_stop, new_x_start:new_x_stop]
            max_region_mask = thresholded_mask[new_y_start:new_y_stop, new_x_start:new_x_stop]
            masked_image = object_image.copy()
            masked_image[~max_region_mask] = 255
            object_image = Image.fromarray(object_image).resize((512, 512))
            masked_image = Image.fromarray(masked_image).resize((512, 512))
            sub_images.extend([object_image, masked_image])
        if len(sub_images) == 0:
            continue
        image = image.resize((512, 512))
        # encode image using base64
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image = base64.b64encode(buffer.getvalue()).decode('utf-8')
        for j, im in enumerate(sub_images):
            buffer = io.BytesIO()
            im.save(buffer, format='PNG')
            sub_images[j] = base64.b64encode(buffer.getvalue()).decode('utf-8')
        # write one tsv row: caption, comma-joined tags, full image, then object/masked crop pairs
        f.write('\t'.join([
            caption,
            ','.join(tags_to_keep),
            image,
            *sub_images
        ]) + '\n')
    # close the last open shard file, if one was created
    if f is not None:
        f.close()
class OpenImageDataset(Dataset):
    def __init__(self, url_data):
        self.data = url_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            items = self.data[idx].split(',')
            image = Image.open(requests.get(items[2], stream=True).raw).convert('RGB')
            # center-crop to a square on the shortest side
            width, height = image.size
            shortest_side = min(width, height)
            left = (width - shortest_side) // 2
            top = (height - shortest_side) // 2
            right = left + shortest_side
            bottom = top + shortest_side
            image = image.crop((left, top, right, bottom))
            return image
        except Exception:
            return None


def collate_fn(batch):
    return batch[0] if batch is not None else None
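
# main() splits the URL list first across machines (--machine-id of --num-machine),
# then round-robin into --num-process shards; each shard gets its own DataLoader and
# a spawned save_tsv process bound to one of the GPUs listed in --cuda-device.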
def main():
    """Parse commandline arguments and launch one worker process per shard."""
    parser = ArgumentParser()
    parser.add_argument('--data-dir', type=str,
                        default='/path/to/image_ids_and_rotation.csv')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--num-process', type=int, default=8)
    parser.add_argument('--cuda-device', type=int, nargs='+', default=[0, 1, 2, 3, 4, 5, 6, 7])
    parser.add_argument('--num-machine', type=int, default=1)
    parser.add_argument('--machine-id', type=int, default=0)
    parser.add_argument('--max-seq-len', type=int, default=None)
    parser.add_argument('--max-new-tokens', type=int, default=10)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top-k', type=int, default=50)
    parser.add_argument('--top-p', type=float, default=0.95)
    parser.add_argument('--repetition-penalty', type=float, default=1.0)
    parser.add_argument('--no-repeat-ngram-size', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--do-sample', type=bool, default=True)
    parser.add_argument('--use-cache', type=bool, default=True)
    parser.add_argument('--trust-remote-code', type=bool, default=True)
    parser.add_argument('--attn-impl', type=str, default='torch')
    parser.add_argument('--threshold', type=float, default=0.3)
    args = parser.parse_args()
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    with open(args.data_dir, 'r', encoding='utf8') as f:
        url_data = f.read().strip().split('\n')
    # split across machines and keep the part assigned to this machine_id
    url_data = url_data[args.machine_id::args.num_machine]
    # split url data into shards
    url_data = [url_data[i::args.num_process] for i in range(args.num_process)]
    dataloaders = [
        DataLoader(
            OpenImageDataset(url_data[i]),
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            drop_last=False,
            prefetch_factor=4,
            collate_fn=collate_fn
        )
        for i in range(args.num_process)
    ]
    multiprocessing.set_start_method('spawn')
    processes = []
    for shard_id, shard in enumerate(dataloaders):
        p = Process(
            target=save_tsv,
            args=(
                args,
                shard_id,
                shard,
                torch.device('cuda:{}'.format(args.cuda_device[shard_id % len(args.cuda_device)]))
            )
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print('Done!')


if __name__ == '__main__':
    main()