	| import argparse | |
| import html | |
| import json | |
| import os | |
| import random | |
| import re | |
| from functools import partial | |
| from glob import glob | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import torchvision | |
| from tqdm import tqdm | |
| from .utils import IMG_EXTENSIONS | |
| tqdm.pandas() | |
# Optional parallelism: use pandarallel when it is installed, otherwise
# fall back to tqdm's progress_apply (registered via tqdm.pandas() above).
try:
    from pandarallel import pandarallel

    PANDA_USE_PARALLEL = True
except ImportError:
    PANDA_USE_PARALLEL = False


def apply(df, func, **kwargs):
    """Apply ``func`` over ``df``, in parallel when pandarallel is available."""
    runner = df.parallel_apply if PANDA_USE_PARALLEL else df.progress_apply
    return runner(func, **kwargs)
# Columns kept by --train-column: the minimal schema needed for training.
TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"]
# ======================================================
# --info
# ======================================================
def get_video_length(cap, method="header"):
    """Return the number of frames in an opened ``cv2.VideoCapture``.

    Args:
        cap: an opened ``cv2.VideoCapture``.
        method: "header" reads the frame count from the container header
            (fast but may be inaccurate); "set" seeks to the end of the
            stream and reads the resulting frame position.

    Returns:
        Frame count as an int.

    Raises:
        ValueError: if ``method`` is not "header" or "set".
    """
    # Raise instead of assert: asserts are stripped under ``python -O``.
    if method not in ("header", "set"):
        raise ValueError(f"method must be 'header' or 'set', got {method!r}")
    if method == "header":
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)
    return int(cap.get(cv2.CAP_PROP_POS_FRAMES))
def get_info(path):
    """Probe a media file and return (num_frames, height, width, aspect_ratio, fps, resolution).

    Images yield num_frames=1 and fps=NaN. On any failure a six-element
    placeholder (0, 0, 0, nan, nan, nan) is returned so callers can always
    unpack six values.
    """
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in IMG_EXTENSIONS:
            im = cv2.imread(path)
            if im is None:
                # BUG FIX: the original returned a 5-tuple here, breaking the
                # six-way ``zip(*info)`` unpacking done by the caller.
                return 0, 0, 0, np.nan, np.nan, np.nan
            height, width = im.shape[:2]
            num_frames, fps = 1, np.nan
        else:
            cap = cv2.VideoCapture(path)
            try:
                num_frames = get_video_length(cap, method="header")
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                fps = float(cap.get(cv2.CAP_PROP_FPS))
            finally:
                cap.release()  # avoid leaking the capture handle
        hw = height * width
        aspect_ratio = height / width if width > 0 else np.nan
        return num_frames, height, width, aspect_ratio, fps, hw
    except Exception:
        # Narrowed from bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
        return 0, 0, 0, np.nan, np.nan, np.nan
def get_video_info(path):
    """Decode a video with torchvision and return
    (num_frames, height, width, aspect_ratio, fps, resolution).

    fps is always NaN (torchvision metadata is not queried here). On any
    failure a placeholder (0, 0, 0, nan, nan, nan) is returned.
    """
    try:
        vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
        num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
        aspect_ratio = height / width
        fps = np.nan
        resolution = height * width
        return num_frames, height, width, aspect_ratio, fps, resolution
    except Exception:
        # Narrowed from bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
        return 0, 0, 0, np.nan, np.nan, np.nan
# ======================================================
# --refine-llm-caption
# ======================================================
# Boilerplate openers produced by LLaVA-style captioners, stripped by
# remove_caption_prefix. (Duplicate entries in the original list removed;
# they were redundant for startswith matching.)
LLAVA_PREFIX = [
    "The video shows",
    "The video captures",
    "The video features",
    "The video depicts",
    "The video presents",
    "The video is ",
    "In the video,",
    "The image shows",
    "The image captures",
    "The image features",
    "The image depicts",
    "The image presents",
    "The image is ",
    "The image portrays",
    "In the image,",
]


def remove_caption_prefix(caption):
    """Strip a leading LLaVA boilerplate prefix and capitalize the remainder.

    Returns the caption unchanged when no prefix matches.
    """
    for prefix in LLAVA_PREFIX:
        if caption.startswith(prefix) or caption.startswith(prefix.lower()):
            caption = caption[len(prefix) :].strip()
            # BUG FIX: a caption that is exactly a prefix becomes "" after
            # stripping; the original then raised IndexError on caption[0].
            if caption and caption[0].islower():
                caption = caption[0].upper() + caption[1:]
            return caption
    return caption
# ======================================================
# --merge-cmotion
# ======================================================
# Sentence appended to the caption for each camera-motion label; a ``None``
# value (e.g. "unknown") never modifies the caption.
CMOTION_TEXT = {
    "static": "The camera is static.",
    "dynamic": "The camera is moving.",
    "unknown": None,
    "zoom in": "The camera is zooming in.",
    "zoom out": "The camera is zooming out.",
    "pan left": "The camera is panning left.",
    "pan right": "The camera is panning right.",
    "tilt up": "The camera is tilting up.",
    "tilt down": "The camera is tilting down.",
    "pan/tilt": "The camera is panning.",
}
# Per-label probability of actually appending the sentence.
CMOTION_PROBS = {
    # hard-coded probabilities
    "static": 1.0,
    "dynamic": 1.0,
    "unknown": 0.0,
    "zoom in": 1.0,
    "zoom out": 1.0,
    "pan left": 1.0,
    "pan right": 1.0,
    "tilt up": 1.0,
    "tilt down": 1.0,
    "pan/tilt": 1.0,
}


def merge_cmotion(caption, cmotion):
    """Probabilistically append the camera-motion sentence for ``cmotion``."""
    sentence = CMOTION_TEXT[cmotion]
    if sentence is not None and random.random() < CMOTION_PROBS[cmotion]:
        return f"{caption} {sentence}"
    return caption
# ======================================================
# --lang
# ======================================================
def build_lang_detector(lang_to_detect):
    """Build a predicate telling whether a caption is in ``lang_to_detect``.

    Only "en" is currently mapped. The returned callable reports True when
    the target language appears among the five most confident guesses that
    lingua produces for the caption.
    """
    from lingua import Language, LanguageDetectorBuilder

    lang_dict = dict(en=Language.ENGLISH)
    assert lang_to_detect in lang_dict
    target = lang_dict[lang_to_detect]
    detector = LanguageDetectorBuilder.from_all_spoken_languages().with_low_accuracy_mode().build()

    def detect_lang(caption):
        top_five = detector.compute_language_confidence_values(caption)[:5]
        return target in [entry.language for entry in top_five]

    return detect_lang
# ======================================================
# --clean-caption
# ======================================================
def basic_clean(text):
    """Fix mojibake via ftfy, unescape doubly-encoded HTML entities, strip."""
    import ftfy

    fixed = ftfy.fix_text(text)
    # Unescape twice to handle double-encoded entities like "&amp;amp;".
    fixed = html.unescape(html.unescape(fixed))
    return fixed.strip()
# Runs of "bad" punctuation (symbols, brackets, slashes, asterisks, ...)
# that clean_caption collapses into a single space.
BAD_PUNCT_REGEX = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\\/\*]{1,}")  # noqa
def clean_caption(caption):
    """Aggressively normalize a caption for T5-style training.

    Lower-cases, strips URLs/HTML/@mentions/CJK text/ids/filenames,
    normalizes dashes and quotes, and collapses whitespace. Ported from the
    DeepFloyd IF text-preprocessing pipeline.
    """
    import urllib.parse as ul

    from bs4 import BeautifulSoup

    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text
    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)
    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################
    # все виды тире / all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )
    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)
    # &quot;
    # BUG FIX: this pattern had been HTML-unescaped into a syntax error;
    # restored to the upstream entity-stripping regex.
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)
    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)
    # \n
    caption = re.sub(r"\\n", " ", caption)
    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
    caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)
    caption = basic_clean(caption)
    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231
    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)
    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)
    # BUG FIX: the original called caption.strip() and discarded the result.
    caption = caption.strip()
    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)
    return caption.strip()
def text_preprocessing(text, use_text_preprocessing: bool = True):
    """Return the caption cleaned for training, or just lowercased/stripped.

    When preprocessing is enabled, clean_caption is applied twice — the
    exact text cleaning used in the training stage.
    """
    if not use_text_preprocessing:
        return text.lower().strip()
    return clean_caption(clean_caption(text))
# ======================================================
# load caption
# ======================================================
def load_caption(path, ext):
    """Load the caption stored in a sidecar file next to a media file.

    Args:
        path: path to the media file; the caption file shares its stem.
        ext: caption format — "json" (reads the "caption" key) or "txt"
            (reads the whole file, stripped).

    Returns:
        The caption string, or "" when the sidecar is missing/invalid or
        ``ext`` is unsupported.
    """
    try:
        # BUG FIX: the original used path.split(".")[0], which truncates
        # paths containing dots in directory names or multi-dot stems.
        stem = os.path.splitext(path)[0]
        if ext == "json":
            with open(stem + ".json", "r") as f:
                return json.load(f)["caption"]
        elif ext == "txt":
            # Generalization: the CLI offers --load-caption txt, but the
            # original asserted ext == "json" and always returned "".
            with open(stem + ".txt", "r") as f:
                return f.read().strip()
        return ""
    except Exception:
        # Best-effort by design: a missing/corrupt sidecar yields "".
        return ""
# ======================================================
# read & write
# ======================================================
def read_file(input_path):
    """Read a .csv or .parquet table into a DataFrame.

    Raises NotImplementedError for any other extension.
    """
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    if input_path.endswith(".parquet"):
        return pd.read_parquet(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
def save_file(data, output_path):
    """Write ``data`` to .csv or .parquet, creating parent dirs as needed.

    Raises NotImplementedError for any other extension.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir != "":
        # exist_ok avoids the check-then-create race of the original
        # ``if not exists: makedirs`` pattern.
        os.makedirs(output_dir, exist_ok=True)
    if output_path.endswith(".csv"):
        return data.to_csv(output_path, index=False)
    elif output_path.endswith(".parquet"):
        return data.to_parquet(output_path, index=False)
    else:
        raise NotImplementedError(f"Unsupported file format: {output_path}")
def read_data(input_paths):
    """Load and concatenate the tables matched by the given glob patterns.

    Returns:
        (DataFrame, name): the concatenated table and a name built by
        joining the input file basenames (without extension) with '+'.
    """
    matched = []
    for pattern in input_paths:
        matched.extend(glob(pattern))
    print("Input files:", matched)

    frames = []
    names = []
    for input_path in matched:
        assert os.path.exists(input_path)
        frames.append(read_file(input_path))
        names.append(os.path.basename(input_path).split(".")[0])
        print(f"Loaded {len(frames[-1])} samples from {input_path}.")

    data = pd.concat(frames, ignore_index=True, sort=False)
    print(f"Total number of samples: {len(data)}.")
    return data, "+".join(names)
# ======================================================
# main
# ======================================================
# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation
def main(args):
    """Run the dataset-table transformation pipeline described by ``args``.

    Stage order matters: set operations come first, info columns are
    computed before filters that need them, caption processing runs before
    token counting, and score filters run last. The result is saved as one
    file or as ``--shard`` numbered parts.
    """
    # reading data
    data, input_name = read_data(args.input)

    # make difference
    if args.difference is not None:
        data_diff = pd.read_csv(args.difference)
        print(f"Difference csv contains {len(data_diff)} samples.")
        # keep only rows whose path does NOT appear in the difference csv
        data = data[~data["path"].isin(data_diff["path"])]
        input_name += f"-{os.path.basename(args.difference).split('.')[0]}"
        print(f"Filtered number of samples: {len(data)}.")

    # make intersection
    if args.intersection is not None:
        data_new = pd.read_csv(args.intersection)
        print(f"Intersection csv contains {len(data_new)} samples.")
        # merge in only the columns this table does not already have
        cols_to_use = data_new.columns.difference(data.columns)
        cols_to_use = cols_to_use.insert(0, "path")
        data = pd.merge(data, data_new[cols_to_use], on="path", how="inner")
        print(f"Intersection number of samples: {len(data)}.")

    # train columns
    if args.train_column:
        all_columns = data.columns
        columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
        data = data.drop(columns=columns_to_drop)

    # get output path
    output_path = get_output_path(args, input_name)

    # preparation
    if args.lang is not None:
        detect_lang = build_lang_detector(args.lang)
    if args.count_num_token == "t5":
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl")

    # IO-related
    if args.load_caption is not None:
        assert "path" in data.columns
        data["text"] = apply(data["path"], load_caption, ext=args.load_caption)
    if args.info:
        # get_info returns a 6-tuple per row; unpack into six columns
        info = apply(data["path"], get_info)
        (
            data["num_frames"],
            data["height"],
            data["width"],
            data["aspect_ratio"],
            data["fps"],
            data["resolution"],
        ) = zip(*info)
    if args.video_info:
        info = apply(data["path"], get_video_info)
        (
            data["num_frames"],
            data["height"],
            data["width"],
            data["aspect_ratio"],
            data["fps"],
            data["resolution"],
        ) = zip(*info)
    if args.ext:
        assert "path" in data.columns
        # drop rows whose file does not exist on disk
        data = data[apply(data["path"], os.path.exists)]

    # filtering
    if args.remove_url:
        assert "text" in data.columns
        data = data[~data["text"].str.contains(r"(?P<url>https?://[^\s]+)", regex=True)]
    if args.lang is not None:
        assert "text" in data.columns
        data = data[data["text"].progress_apply(detect_lang)]  # cannot parallelize
    if args.remove_empty_caption:
        assert "text" in data.columns
        data = data[data["text"].str.len() > 0]
        data = data[~data["text"].isna()]
    if args.remove_path_duplication:
        assert "path" in data.columns
        data = data.drop_duplicates(subset=["path"])

    # processing
    if args.relpath is not None:
        data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath))
    if args.abspath is not None:
        data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x))
    if args.merge_cmotion:
        # row-wise (axis=1): needs both the "text" and "cmotion" columns
        data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
    if args.refine_llm_caption:
        assert "text" in data.columns
        data["text"] = apply(data["text"], remove_caption_prefix)
    if args.clean_caption:
        assert "text" in data.columns
        data["text"] = apply(
            data["text"],
            partial(text_preprocessing, use_text_preprocessing=True),
        )
    if args.count_num_token is not None:
        assert "text" in data.columns
        data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))

    # sort
    if args.sort is not None:
        data = data.sort_values(by=args.sort, ascending=False)
    if args.sort_ascending is not None:
        data = data.sort_values(by=args.sort_ascending, ascending=True)

    # filtering
    if args.remove_empty_caption:
        # re-run: the caption processing above may have emptied some captions
        assert "text" in data.columns
        data = data[data["text"].str.len() > 0]
        data = data[~data["text"].isna()]
    if args.fmin is not None:
        assert "num_frames" in data.columns
        data = data[data["num_frames"] >= args.fmin]
    if args.fmax is not None:
        assert "num_frames" in data.columns
        data = data[data["num_frames"] <= args.fmax]
    if args.hwmax is not None:
        if "resolution" not in data.columns:
            height = data["height"]
            width = data["width"]
            data["resolution"] = height * width
        data = data[data["resolution"] <= args.hwmax]
    if args.aesmin is not None:
        assert "aes" in data.columns
        data = data[data["aes"] >= args.aesmin]
    if args.matchmin is not None:
        assert "match" in data.columns
        data = data[data["match"] >= args.matchmin]
    if args.flowmin is not None:
        assert "flow" in data.columns
        data = data[data["flow"] >= args.flowmin]
    if args.remove_text_duplication:
        data = data.drop_duplicates(subset=["text"], keep="first")
    print(f"Filtered number of samples: {len(data)}.")

    # shard data
    if args.shard is not None:
        sharded_data = np.array_split(data, args.shard)
        for i in range(args.shard):
            # insert the shard index before the extension: name_0.csv, ...
            output_path_part = output_path.split(".")
            output_path_s = ".".join(output_path_part[:-1]) + f"_{i}." + output_path_part[-1]
            save_file(sharded_data[i], output_path_s)
            print(f"Saved {len(sharded_data[i])} samples to {output_path_s}.")
    else:
        save_file(data, output_path)
        print(f"Saved {len(data)} samples to {output_path}.")
def parse_args():
    """Build and parse the CLI: positional input tables plus the option
    groups (special cases, IO, path processing, caption filtering/processing,
    score filtering) consumed by ``main``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+", help="path to the input dataset")
    parser.add_argument("--output", type=str, default=None, help="output path")
    parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
    parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
    parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
    parser.add_argument("--seed", type=int, default=None, help="random seed")
    # special case
    parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
    parser.add_argument("--sort", type=str, default=None, help="sort by column")
    parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)")
    parser.add_argument("--difference", type=str, default=None, help="get difference from the dataset")
    parser.add_argument(
        "--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
    )
    parser.add_argument("--train-column", action="store_true", help="only keep the train column")
    # IO-related
    parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
    parser.add_argument("--video-info", action="store_true", help="get the basic information of each video")
    parser.add_argument("--ext", action="store_true", help="check if the file exists")
    parser.add_argument(
        "--load-caption", type=str, default=None, choices=["json", "txt"], help="load the caption from json or txt"
    )
    # path processing
    parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given")
    parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given")
    # caption filtering
    parser.add_argument(
        "--remove-empty-caption",
        action="store_true",
        help="remove rows with empty caption",
    )
    parser.add_argument("--remove-url", action="store_true", help="remove rows with url in caption")
    parser.add_argument("--lang", type=str, default=None, help="remove rows with other language")
    parser.add_argument("--remove-path-duplication", action="store_true", help="remove rows with duplicated path")
    parser.add_argument("--remove-text-duplication", action="store_true", help="remove rows with duplicated caption")
    # caption processing
    parser.add_argument("--refine-llm-caption", action="store_true", help="modify the caption generated by LLM")
    parser.add_argument(
        "--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training"
    )
    parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption")
    parser.add_argument(
        "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
    )
    # score filtering
    parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames")
    parser.add_argument("--fmax", type=int, default=None, help="filter the dataset by maximum number of frames")
    parser.add_argument("--hwmax", type=int, default=None, help="filter the dataset by maximum resolution")
    parser.add_argument("--aesmin", type=float, default=None, help="filter the dataset by minimum aes score")
    parser.add_argument("--matchmin", type=float, default=None, help="filter the dataset by minimum match score")
    parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
    return parser.parse_args()
def get_output_path(args, input_name):
    """Derive the output file path from the CLI flags.

    When --output is given it is returned verbatim; otherwise a name is
    built by appending a suffix per active option to ``input_name`` and the
    file is placed next to the first input, with the --format extension.
    """
    if args.output is not None:
        return args.output
    name = input_name
    dir_path = os.path.dirname(args.input[0])

    # sort
    # --sort and --sort-ascending are mutually exclusive; both map to "_sort"
    if args.sort is not None:
        assert args.sort_ascending is None
        name += "_sort"
    if args.sort_ascending is not None:
        assert args.sort is None
        name += "_sort"

    # IO-related
    # for IO-related, the function must be wrapped in try-except
    if args.info:
        name += "_info"
    if args.video_info:
        name += "_vinfo"
    if args.ext:
        name += "_ext"
    if args.load_caption:
        name += f"_load{args.load_caption}"

    # path processing
    if args.relpath is not None:
        name += "_relpath"
    if args.abspath is not None:
        name += "_abspath"

    # caption filtering
    if args.remove_empty_caption:
        name += "_noempty"
    if args.remove_url:
        name += "_nourl"
    if args.lang is not None:
        name += f"_{args.lang}"
    if args.remove_path_duplication:
        name += "_noduppath"
    if args.remove_text_duplication:
        name += "_noduptext"

    # caption processing
    if args.refine_llm_caption:
        name += "_llm"
    if args.clean_caption:
        name += "_clean"
    if args.merge_cmotion:
        name += "_cmcaption"
    if args.count_num_token:
        name += "_ntoken"

    # score filtering
    if args.fmin is not None:
        name += f"_fmin{args.fmin}"
    if args.fmax is not None:
        name += f"_fmax{args.fmax}"
    if args.hwmax is not None:
        name += f"_hwmax{args.hwmax}"
    if args.aesmin is not None:
        name += f"_aesmin{args.aesmin}"
    if args.matchmin is not None:
        name += f"_matchmin{args.matchmin}"
    if args.flowmin is not None:
        name += f"_flowmin{args.flowmin}"

    output_path = os.path.join(dir_path, f"{name}.{args.format}")
    return output_path
if __name__ == "__main__":
    args = parse_args()
    # Rebinds the module-level flag read by apply() at call time.
    if args.disable_parallel:
        PANDA_USE_PARALLEL = False
    if PANDA_USE_PARALLEL:
        if args.num_workers is not None:
            pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
        else:
            pandarallel.initialize(progress_bar=True)
    # Seed both RNGs used in this module (merge_cmotion uses random,
    # np.array_split/pandas paths may use numpy).
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
    main(args)