Spaces:
Running
Running
| # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas) | |
| #%% | |
| import sys | |
| import os | |
| from datetime import datetime | |
| import pandas as pd | |
| import contexttimer | |
| from urllib.request import urlopen | |
| import requests | |
| from PIL import Image | |
| import torch | |
| from torchvision.transforms import functional as TF | |
| from multiprocessing import Pool | |
| from tqdm import tqdm | |
| import logging | |
# Setup: start a fresh download.log on every run, and silence the
# InsecureRequestWarning that urllib3 raises for requests made with
# certificate verification disabled.
logging.basicConfig(level=logging.INFO, filename='download.log', filemode='w')
_urllib3 = requests.packages.urllib3
_urllib3.disable_warnings(_urllib3.exceptions.InsecureRequestWarning)
| # # For downloading SVG images (I can't get this to work) | |
| # from io import BytesIO | |
| # import cairosvg | |
| #%% | |
# Load the (url, caption) table and build a url -> row-index lookup.
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
with contexttimer.Timer(prefix="Loading from tsv"):
    df = pd.read_csv('./cc12m.tsv', sep='\t', header=None)

# Each row unpacks as (index, url, caption); the caption is not needed here.
# When a URL appears more than once, the last row index wins.
url_to_idx_map = {}
for row_index, row_url, _caption in df.itertuples():
    url_to_idx_map[row_url] = row_index
print(f'Loaded {len(url_to_idx_map)} urls')
| #%% | |
df.head()
#%%
# Note: it seems that there are no SVG images
# regex=False matches the literal substring '.svg'; with the default regex
# mode the '.' is a wildcard, so e.g. 'xsvg' would be counted as well.
# The sample size is capped so that small tables do not raise.
df.sample(min(10000, len(df)))[1].str.contains('.svg', regex=False).sum()
| #%% | |
| # Resize function | |
def resize(img, max_size_of_short_side=512):
    """Downscale *img* so that its shorter side is at most the given size.

    Args:
        img: An image object exposing ``size`` (as PIL images do); it is
            passed to ``TF.resize`` when it needs shrinking.
        max_size_of_short_side: Largest allowed length, in pixels, of the
            shorter side. Defaults to 512.

    Returns:
        ``img`` itself when its shorter side does not exceed the limit,
        otherwise the result of ``TF.resize`` with LANCZOS interpolation.
    """
    if min(img.size) > max_size_of_short_side:
        # An integer ``size`` is applied to the shorter edge by TF.resize.
        img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
    return img
base_dir = os.path.join(os.getcwd(), 'images')


def process(item):
    """Download one image, shrink it if needed, and save it under base_dir.

    The target filename is ``{image_id:08d}---{stem}.jpg``, where ``stem`` is
    the basename of the URL without its extension. Files that already exist
    are skipped, so an interrupted run can be restarted.

    Args:
        item: A ``(url, image_id)`` pair.

    Returns:
        None. Any failure (network, decoding, saving, bad input) is caught
        and written to the log together with the URL; nothing is raised.
    """
    url, image_id = item
    try:
        stem, _ = os.path.splitext(os.path.basename(url))
        filepath = os.path.join(base_dir, f'{image_id:08d}---{stem}.jpg')
        if os.path.isfile(filepath):
            return  # already downloaded on a previous run

        # NOTE(security): verify=False disables TLS certificate checks.
        # SVG sources are not handled; they fail in Image.open and are logged.
        with requests.get(url, stream=True, timeout=1, verify=False) as response:
            # Do not try to decode an HTTP error page as an image.
            response.raise_for_status()
            # The raw stream is not content-decoded by default; without this
            # a gzip/deflate-encoded body cannot be parsed by PIL.
            response.raw.decode_content = True
            # convert() reads the full image, so the connection can be
            # released as soon as this statement finishes.
            image = Image.open(response.raw).convert('RGB')

        if min(image.size) > 512:
            image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
        image.save(filepath)
    except Exception as e:
        # Best-effort crawl: record the failure and move on to the next URL.
        logging.info(" ".join(repr(e).splitlines()))
        logging.error(url)
| #%% | |
| #for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))): | |
| # process(item) | |
| # if i > 100: | |
| # break | |
# Use multiprocessing for speed
# The __main__ guard keeps worker processes from re-running the pool when
# the module is re-imported under the "spawn" start method.
if __name__ == '__main__':
    # Entries before this offset are skipped; set it to 0 to process
    # the whole list.
    start_index = 10_000_000

    list_of_items = list(url_to_idx_map.items())
    print(len(list_of_items))
    list_of_items = list_of_items[start_index:]
    print(len(list_of_items))

    # Without the output directory every save fails and is only logged.
    os.makedirs(base_dir, exist_ok=True)

    with Pool(128) as p:
        r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
    print('DONE')