Spaces:
Running
Running
'''
This script was adapted from Luke Melas-Kyriazi's code (https://twitter.com/lukemelas).
A few changes were made for this particular dataset. You need the `.tsv` file downloaded into your working directory.
You can find the files here: https://github.com/google-research-datasets/conceptual-captions
'''
| import sys | |
| import os | |
| from datetime import datetime | |
| import pandas as pd | |
| import contexttimer | |
| from urllib.request import urlopen | |
| import requests | |
| from PIL import Image | |
| import torch | |
| from torchvision.transforms import functional as TF | |
| from multiprocessing import Pool | |
| from tqdm import tqdm | |
| import logging | |
| import sys | |
# Setup
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
# Every download below uses verify=False; silence urllib3's per-request warning about it.
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
if len(sys.argv) != 3:
    # sys.exit(str) writes the message to stderr and exits with status 1.
    # (The bare `exit` builtin is only injected by the `site` module.)
    sys.exit("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
# Load data
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
with contexttimer.Timer(prefix="Loading from tsv"):
    df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
    # Expects exactly two columns (caption, url). Keyed by URL, so duplicate URLs
    # collapse to a single entry; the last row's index wins.
    url_to_idx_map = {url: index for index, _caption, url in df.itertuples()}
print(f'Loaded {len(url_to_idx_map)} urls')
base_dir = os.path.join(os.getcwd(), sys.argv[2])
# Create the output directory up front. Otherwise every image.save() in the
# workers fails with FileNotFoundError, and that is only recorded in download.log.
os.makedirs(base_dir, exist_ok=True)
def process(item):
    """Download one image and save it as an RGB JPEG under ``base_dir``.

    Args:
        item: ``(url, image_id)`` pair. ``image_id`` must be an int because it
            is zero-padded into the filename with ``:08d``.

    The file is named ``{image_id:08d}---{url stem}.jpg``. If that file
    already exists, the download is skipped, which makes re-runs resumable.
    Images whose shorter side exceeds 512 px are downscaled so that side is
    512 px. Any failure is logged to ``download.log`` and swallowed, so one bad
    URL cannot abort the worker pool.
    """
    url, image_id = item
    try:
        stem = os.path.splitext(os.path.basename(url))[0]
        filepath = os.path.join(base_dir, f'{image_id:08d}---{stem}.jpg')
        if os.path.isfile(filepath):
            return  # already fetched on a previous run
        # Use the response as a context manager so the pooled connection is
        # released even when decoding fails. Millions of calls would otherwise leak sockets.
        with requests.get(url, stream=True, timeout=1, verify=False) as response:
            # Reject 4xx/5xx pages instead of handing HTML to PIL.
            response.raise_for_status()
            # response.raw is the undecoded socket stream. Have urllib3 undo
            # gzip/deflate Content-Encoding before PIL reads it.
            response.raw.decode_content = True
            # convert() forces a full load, so all reads happen while the stream is open.
            image = Image.open(response.raw).convert('RGB')
        if min(image.size) > 512:
            image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
        image.save(filepath)
    except Exception as e:
        # Best-effort downloader: record the failure and move on.
        logging.info(" ".join(repr(e).splitlines()))
        logging.error(url)
# Guarded so that worker processes started with the "spawn" method (the default
# on Windows and macOS) do not re-run the dispatch when they re-import this
# module. Without the guard, each worker would try to start its own Pool.
if __name__ == '__main__':
    list_of_items = list(url_to_idx_map.items())
    print(len(list_of_items))
    with Pool(128) as p:
        # process() returns nothing, so completion order is irrelevant.
        # imap_unordered keeps the progress bar moving past slow URLs, and
        # draining the iterator avoids holding millions of None results.
        for _ in tqdm(p.imap_unordered(process, list_of_items), total=len(list_of_items)):
            pass
    print('DONE')