| import os | |
| from findfile import find_files, find_dir | |
# Keywords handed to findfile as ``exclude_key`` when looking up dataset
# directories and files in this module; entries are extensions or name
# fragments of artefacts that should not be picked up as datasets.
filter_key_words = [
    ".py", ".md", "readme",
    "log", "result", "zip",
    ".state_dict", ".model", ".png",
    "acc_", "f1_",
    ".backup", ".bak",
]
def detect_infer_dataset(dataset_path, task="apc"):
    """Collect inference dataset files for the given path(s) or dataset name(s).

    Args:
        dataset_path: A single string or an iterable of strings. A string
            naming an existing file is returned unchanged in a one-element
            list. Any other entry is either an existing path that is searched
            directly, or (when it does not exist) a keyword used together
            with ``task`` and ``"dataset"`` to locate a directory below the
            current working directory.
        task: Task keyword used as a search key (default ``"apc"``).

    Returns:
        list: The accumulated results of ``find_files`` for every entry, in
        iteration order (or ``[dataset_path]`` for a single existing file).
    """
    if isinstance(dataset_path, str):
        if os.path.isfile(dataset_path):
            return [dataset_path]
        # A bare string that is not a file (a directory or a dataset name)
        # must be handled as ONE entry; iterating it directly would walk the
        # string character by character and search for each letter.
        dataset_path = [dataset_path]

    dataset_file = []
    for d in dataset_path:
        if os.path.exists(d):
            # Existing path: search inside it for this task's inference files.
            dataset_file += find_files(
                d, [".inference", task], exclude_key=["train."] + filter_key_words
            )
        else:
            # Not a path: treat the entry as a dataset name and look for a
            # matching dataset directory under the current working directory.
            # NOTE(review): behaviour when find_dir finds nothing depends on
            # findfile (its return value is passed on unchecked) — confirm.
            search_path = find_dir(
                os.getcwd(),
                [d, task, "dataset"],
                exclude_key=filter_key_words,
                disable_alert=False,
            )
            dataset_file += find_files(
                search_path,
                [".inference", d],
                exclude_key=["train."] + filter_key_words,
            )
    return dataset_file