# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os
import functools
import io
import datetime
import itertools
import json
from tqdm import tqdm
import numpy as np
import torch
import torch.distributed as dist
from collections import defaultdict

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference, create_positive_dict, clean_name
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
from maskrcnn_benchmark.utils.stats import get_model_complexity_info

import omnilabeltools as olt
from omnilabeltools import OmniLabel, OmniLabelEval, visualize_image_sample

import time
import tempfile
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPTokenizerFast
import pdb
import wandb
import openai  # required by the LLM helper class below
from multiprocessing import Pool
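

# Thin wrapper around the OpenAI completion APIs. Given a single description
# (or a list of descriptions), it substitutes the text into the template read
# from `prompt_file`, calls either the chat or the plain completion endpoint,
# and retries with a short sleep on transient API errors. In this script it is
# only used to extract the center noun phrase of a free-form object description.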
class LLM:
    def __init__(self, version, prompt_file=None, temp=1.0):
        self.version = version
        self.prompt_file = prompt_file
        self.temp = temp
        with open(self.prompt_file, "r") as f:
            self.prompt = f.read()

    def __call__(self, entity):
        time.sleep(0.1)
        success = False
        fail_count = 0
        if isinstance(entity, list):
            prompt = [self.prompt.replace("PROMPT", e) for e in entity]
        else:
            if self.version == "chat":
                raw_prompt = self.prompt.replace("PROMPT", entity)
                try:
                    prompt = json.loads(raw_prompt)
                except json.JSONDecodeError:
                    prompt = [{"role": "user", "content": raw_prompt}]
            else:
                prompt = self.prompt.replace("PROMPT", entity)

        while not success:
            try:
                if self.version == "chat":
                    model = "gpt-3.5-turbo"
                    response = openai.ChatCompletion.create(
                        model=model,
                        messages=prompt,
                        temperature=self.temp,
                    )
                else:
                    if self.version == "curie":
                        model = "curie"
                    else:
                        model = "text-davinci-003"
                    response = openai.Completion.create(
                        model=model,
                        prompt=prompt,
                        temperature=self.temp,
                        max_tokens=128,
                        top_p=1,
                        frequency_penalty=0.0,
                        presence_penalty=0.0,
                    )
                success = True
                fail_count = 0
            except Exception as e:
                print(f"Exception: {e}")
                time.sleep(0.1)
                fail_count += 1
                if fail_count > 10:
                    print("Too many failures")
                    return "Too many failures"

        if isinstance(entity, list):
            if self.version == "chat":
                return [r["message"]["content"] for r in response["choices"]]
            else:
                return [r["text"] for r in response["choices"]]
        else:
            if self.version == "chat":
                return response["choices"][0]["message"]["content"]
            else:
                return response["choices"][0]["text"]
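

# Distributed helpers: init_distributed_mode() reads the rank/world size from the
# launcher's environment variables (RANK/WORLD_SIZE/LOCAL_RANK) or from SLURM_PROCID
# and initializes the NCCL process group; setup_for_distributed() then silences
# print() on all non-master ranks.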
def init_distributed_mode(args):
    """Initialize distributed training, if appropriate."""
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    # args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(0, 7200),
    )
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def remove_full_stop(description_list):
    ret_list = []
    for descript in description_list:
        if descript and descript[-1] == '.':
            descript = descript[:-1]  # remove the trailing '.'
        ret_list.append(descript)
    return ret_list


def num_of_words(text):
    return len(text.split(' '))
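

# Build a single grounding caption of the form "Detect: <desc 1>. <desc 2>. ..."
# from the current chunk of descriptions, record the character span of each
# description inside that caption (tokens_positive), and convert those spans
# into a label-to-token positive map via create_positive_dict(). When
# override_tokens_positive is given (the center-noun-phrase case), only that
# sub-span of the single description is used as the positive region.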
def create_queries_and_maps(labels, label_list, tokenizer, additional_labels=None, cfg=None,
                            center_nouns_length=None, override_tokens_positive=None):
    # Clean label list
    label_list = [clean_name(i) for i in label_list]
    # Form the query and get the mapping
    tokens_positive = []
    start_i = 0
    end_i = 0
    objects_query = "Detect: "
    # objects_query = ""
    prefix_length = len(objects_query)

    # sep between tokens, follow training
    separation_tokens = cfg.DATASETS.SEPARATION_TOKENS
    caption_prompt = cfg.DATASETS.CAPTION_PROMPT
    use_caption_prompt = cfg.DATASETS.USE_CAPTION_PROMPT and caption_prompt is not None

    for _index, label in enumerate(label_list):
        if use_caption_prompt:
            objects_query += caption_prompt[_index]["prefix"]

        start_i = len(objects_query)

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["name"]
        else:
            objects_query += label

        if "a kind of " in label:
            end_i = len(label.split(",")[0]) + start_i
        else:
            end_i = len(objects_query)
        tokens_positive.append([(start_i, end_i)])  # Every label has a [(start, end)]

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["suffix"]

        if _index != len(label_list) - 1:
            objects_query += separation_tokens

    if additional_labels is not None:
        objects_query += separation_tokens
        for _index, label in enumerate(additional_labels):
            objects_query += label
            if _index != len(additional_labels) - 1:
                objects_query += separation_tokens

    # print(objects_query)
    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
        tokenized = tokenizer(objects_query, return_tensors="pt")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
        tokenized = tokenizer(objects_query, return_tensors="pt")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        tokenized = tokenizer(
            objects_query, max_length=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, truncation=True, return_tensors="pt"
        )
    else:
        raise NotImplementedError

    if override_tokens_positive is not None:
        new_tokens_positive = []
        for override in override_tokens_positive:
            new_tokens_positive.append((override[0] + prefix_length, override[1] + prefix_length))
        tokens_positive = [new_tokens_positive]  # this is because we only have one label

    # Create the mapping between tokenized sentence and the original label
    # if one_hot:
    #     positive_map_token_to_label, positive_map_label_to_token = create_one_hot_dict(labels, no_minus_one_for_one_hot=cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT)
    # else:
    positive_map_token_to_label, positive_map_label_to_token = create_positive_dict(
        tokenized, tokens_positive, labels=labels
    )  # from token position to original label

    return objects_query, positive_map_label_to_token
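

# Main entry point: builds the GLIP-style detection model from the config,
# loads the requested checkpoint, iterates over the OmniLabel validation
# loader, runs grounding inference over chunks of object descriptions,
# gathers the predictions from all ranks, and finally scores them with
# OmniLabelEval, optionally logging the results to Weights & Biases.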
def main():
    parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
    parser.add_argument(
        "--config-file",
        default="configs/pretrain/glip_Swin_T_O365_GoldG.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight",
        default=None,
        metavar="FILE",
        help="path to model weight file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER
    )
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
    parser.add_argument("--task_config", default=None)
    parser.add_argument("--chunk_size", default=20, type=int, help="number of descriptions per forward pass")
    parser.add_argument("--threshold", default=None, type=float, help="score threshold for keeping boxes")
    parser.add_argument("--topk_per_eval", default=None, type=int, help="number of top-scoring boxes kept per forward pass")
    parser.add_argument("--group_query", action="store_true", help="group query")
    parser.add_argument("--noun_phrase_file", default=None, type=str, help="noun phrase file")
    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        # torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(
        #     backend="nccl", init_method="env://"
        # )
        init_distributed_mode(args)
        print("Passed distributed init")

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    log_dir = cfg.OUTPUT_DIR
    if args.weight:
        log_dir = os.path.join(log_dir, "eval", os.path.splitext(os.path.basename(args.weight))[0])
    if log_dir:
        mkdir(log_dir)
    logger = setup_logger("maskrcnn_benchmark", log_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    # logger.info("Collecting env info (might take some time)")
    # logger.info("\n" + collect_env_info())

    device = cfg.MODEL.DEVICE
    cpu_device = torch.device("cpu")
    model = build_detection_model(cfg)
    model.to(device)

    # we currently disable this
    # params, flops = get_model_complexity_info(model,
    #                                           (3, cfg.INPUT.MAX_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST),
    #                                           input_constructor=lambda x: {'images': [torch.rand(x).cuda()]})
    # print("FLOPs: {}, #Parameter: {}".format(params, flops))

    checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    if args.weight:
        _ = checkpointer.load(args.weight, force=True)
    else:
        _ = checkpointer.load(cfg.MODEL.WEIGHT)

    if args.weight:
        weight_iter = os.path.splitext(os.path.basename(args.weight))[0].split("_")[-1]
        try:
            weight_iter = int(weight_iter)
        except ValueError:
            weight_iter = 1
    else:
        weight_iter = 1

    # get the wandb name
    train_wandb_name = os.path.basename(cfg.OUTPUT_DIR)
    eval_wandb_name = train_wandb_name + "_eval" + "_Fixed{}_Chunk{}".format(
        not cfg.DATASETS.LVIS_USE_NORMAL_AP, cfg.TEST.CHUNKED_EVALUATION)

    if is_main_process() and train_wandb_name != "__test__":
        api = wandb.Api()
        runs = api.runs('haroldli/language_det_eval')
        matched_run = None
        history = []
        exclude_keys = ['_runtime', '_timestamp']
        for run in runs:
            if run.name == eval_wandb_name and str(run._state) == "finished":
                print("run found", run.name)
                print(run.summary)
                matched_run = run
                run_his = matched_run.scan_history()
                # print([len(i) for i in run_his])
                for stat in run_his:
                    stat_i = {k: v for k, v in stat.items() if k not in exclude_keys and v is not None}
                    if len(stat_i) > 1:
                        history.append(stat_i)
                # matched_run.delete()
                break  # only update one
        wandb_run = wandb.init(
            project='language_det_eval',
            job_type='evaluate',
            name=eval_wandb_name,
        )
        # pprint(history)
        # exclude_keys = ['_step', '_runtime', '_timestamp']
        # for stat in history:
        #     wandb.log(
        #         {k: v for k, v in stat.items() if k not in exclude_keys},
        #         step=stat['_step'],
        #     )
    else:
        wandb_run = None
        history = None

    print("weight_iter: ", weight_iter)
    print("train_wandb_name: ", train_wandb_name)
    print("eval_wandb_name: ", eval_wandb_name)

    # build tokenizer to process data
    # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
        tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            tokenizer = CLIPTokenizerFast.from_pretrained(
                "openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij</w>"
            )
        else:
            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
    else:
        tokenizer = None
        raise NotImplementedError

    ### inference & evaluation
    topk_per_eval = args.topk_per_eval
    threshold = args.threshold
    model.eval()
    chunk_size = args.chunk_size  # num of texts each time

    if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
        class_plus = 1
    else:
        class_plus = 0

    task_config = args.task_config
    assert task_config is not None, "task_config should be assigned"
    cfg_ = cfg.clone()
    cfg_.defrost()
    cfg_.merge_from_file(task_config)
    cfg_.merge_from_list(args.opts)

    dataset_name = cfg_.DATASETS.TEST[0]
    output_folder = os.path.join(log_dir, "inference", dataset_name)
    if not os.path.exists(output_folder):
        mkdir(output_folder)

    data_loaders_val = make_data_loader(cfg_, is_train=False, is_distributed=distributed)
    _iterator = tqdm(data_loaders_val[0])  # only for the first test set
    predictions = []

    # ad hoc: the block below loads the OmniLabel ground truth so that per-chunk positive rates can be computed
    # if "coco" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json'
    # elif "oi_v5" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json'
    # elif "oi_v6" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv6.json'
    # else:
    #     assert(0)
    # omni_label = OmniLabel(path_json=gt_json)

    if args.noun_phrase_file is not None:
        try:
            noun_phrase = json.load(open(args.noun_phrase_file))
        except (FileNotFoundError, json.JSONDecodeError):
            noun_phrase = {}
            print("No noun phrase file found, will generate one")
        llm = LLM(version="chat", prompt_file="tools/data_process/prompts/noun.v1.txt", temp=0.0)
    else:
        noun_phrase = {}

    # stats
    pos_rates = []
    query_length = []
    all_info = []
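
    # Per-image inference loop: for every validation image, split its object
    # descriptions into chunks (free-form descriptions and plain category names
    # are handled separately), build a grounding caption per chunk, run the
    # model, and convert the predicted boxes back to OmniLabel-style records
    # keyed by description id.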
    for iidx, batch in enumerate(_iterator):
        images, targets, image_ids, *_ = batch
        # import ipdb
        # ipdb.set_trace()
        images = images.to(device)
        text_queries = targets[0].get_field('inference_obj_descriptions')
        text_queries_ids = targets[0].get_field("inference_obj_description_ids")
        image_size = targets[0].size
        image_id = image_ids[0]
        # pdb.set_trace()
        # print(data_loaders_val[0].dataset.dataset_dicts[iidx])
        # all_info.append(data_loaders_val[0].dataset.dataset_dicts[iidx])

        # get the positive labels if the OmniLabel ground truth (omni_label) was loaded above
        try:
            positive_info = omni_label.get_image_sample(image_id)
            positive_instances = positive_info['instances']
            positive_labels = []
            for i in positive_instances:
                positive_labels.extend(i['description_ids'])
            positive_labels = list(set(positive_labels))
        except Exception:
            positive_labels = None

        des_id_start = 0
        # rearrange the queries: long free-form descriptions first, then plain category names
        query_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) > 2]
        cat_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) <= 2]
        if args.group_query:
            text_queries_ids = [text_queries_ids[i] for i in query_indexes] + [text_queries_ids[i] for i in cat_indexes]
            text_queries = [text_queries[i] for i in query_indexes] + [text_queries[i] for i in cat_indexes]

        while des_id_start < len(text_queries_ids):
            # single description or a chunk of descriptions each time
            _det_phrase = False  # set to True when one free-form description is processed alone
            if args.group_query:
                if num_of_words(text_queries[des_id_start]) > 2:
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + 8])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + 8]
                    des_id_start += 8
                else:
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + chunk_size])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + chunk_size]
                    des_id_start += chunk_size
            else:
                if num_of_words(text_queries[des_id_start]) > 2:
                    _det_phrase = True
                    description_list = remove_full_stop([text_queries[des_id_start]])
                    description_id_list = [text_queries_ids[des_id_start]]
                    des_id_start += 1
                else:
                    _det_phrase = False
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + chunk_size])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + chunk_size]
                    des_id_start += chunk_size

            # create positive map, always use continuous labels starting from 1
            continue_labels = np.arange(0, chunk_size) + class_plus
            if _det_phrase and args.noun_phrase_file is not None:
                # try to find the center noun phrase of the description
                center_noun = noun_phrase.get(description_list[0], None)
                if center_noun is None:
                    center_noun = llm(description_list[0])
                    if len(center_noun) == 0:
                        center_noun = description_list[0]  # failed case
                    noun_phrase[description_list[0]] = center_noun
                start = description_list[0].lower().find(center_noun.lower())
                end = start + len(center_noun)
                override_tokens_positive = [(start, end)]
                print(description_list[0], center_noun, override_tokens_positive)
                cur_queries, positive_map_label_to_token = create_queries_and_maps(
                    continue_labels, description_list, tokenizer, cfg=cfg,
                    override_tokens_positive=override_tokens_positive)
            else:
                cur_queries, positive_map_label_to_token = create_queries_and_maps(
                    continue_labels, description_list, tokenizer, cfg=cfg)

            set_description_id_list = set(description_id_list)
            # intersection between positive labels and current description ids
            if positive_labels is not None:
                pos_rate = len(set_description_id_list.intersection(set(positive_labels))) / len(set_description_id_list)
                pos_rates.append(pos_rate)
            query_length.append(len(set_description_id_list))

            # print(cur_queries)
            with torch.no_grad():
                output = model(images, captions=[cur_queries], positive_map=positive_map_label_to_token)
                output = output[0].to(cpu_device).convert(mode="xywh")
                output = output.resize(image_size)  # back to the original scale
            # print(output)
            # import ipdb
            # ipdb.set_trace()

            # thresholding
            if threshold is not None:
                scores = output.get_field('scores')
                output = output[scores > threshold]

            # sorted by scores
            if topk_per_eval is not None:
                scores = output.get_field('scores')
                _, sortIndices = scores.sort(descending=True)
                output = output[sortIndices]
                # topk
                output = output[:topk_per_eval]

            # map continuous id to description id
            cont_ids_2_descript_ids = {i: v for i, v in enumerate(description_id_list)}
            pred_boxes = output.bbox
            pred_labels = output.get_field('labels') - class_plus  # continuous ids, starting from 0
            pred_scores = output.get_field('scores')
            # convert continuous id to description id
            for box_idx, box in enumerate(pred_boxes):
                predictions.append({
                    "image_id": image_id,
                    "bbox": box.cpu().tolist(),
                    "description_ids": [cont_ids_2_descript_ids[pred_labels[box_idx].item()]],
                    "scores": [pred_scores[box_idx].item()],
                })
| #print("pos_rate: %.2f"%(np.mean(pos_rates)), pos_rates) | |
| #print("query_length: %.2f"%(np.mean(query_length)), query_length) | |
| # draw a histogram of pos_rate | |
| plt.hist(pos_rates, bins=10) | |
| plt.savefig(os.path.join(output_folder, "pos_rate.png")) | |
| plt.close() | |
| if args.noun_phrase_file is not None: | |
| with open(args.noun_phrase_file, "w") as f: | |
| json.dump(noun_phrase, f, indent=4) | |
| # collect predictions from all GPUs | |
| synchronize() | |
| all_predictions = all_gather(predictions) | |
| all_predictions = list(itertools.chain(*all_predictions)) | |
| if not is_main_process(): | |
| return | |
| result_save_json = "%s_results.json"%(dataset_name) | |
| results_path = os.path.join(output_folder, result_save_json) | |
| print('Saving to', results_path) | |
| json.dump(all_predictions, open(results_path, 'w')) | |

    from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
    datasetMeta = DatasetCatalog.get(dataset_name)
    gt_path_json = datasetMeta['args']['ann_file']
    # import ipdb
    # ipdb.set_trace()

    # evaluation
    gt = OmniLabel(gt_path_json)    # load ground truth dataset
    dt = gt.load_res(results_path)  # load prediction results
    ole = OmniLabelEval(gt, dt)
    # ole.params.resThrs = ...  # set evaluation parameters as desired
    ole.evaluate()
    ole.accumulate()
    score = ole.summarize()

    # OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json
    # with open("tools/files/omnilabel_coco.json", "a") as f:
    #     json.dump(all_info, f)

    if is_main_process():
        if wandb_run is not None:
            dataset_name = cfg.DATASETS.TEST[0]
            write_to_wandb_log(score, dataset_name, weight_iter, history)
            with open("{}/detailed.json".format(output_folder), "w") as f:
                json.dump(score, f)
            wandb_run.save("{}/detailed.json".format(output_folder))
        print(score)
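

# Merge the new AP numbers for this checkpoint into the (optionally restored)
# wandb run history and re-log everything step by step, since wandb expects
# consecutive, monotonically increasing steps.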
def write_to_wandb_log(score, dataset_name, weight_iter, history):
    all_results = defaultdict(dict)
    exclude_keys = ['_step', '_runtime', '_timestamp']
    if history is not None:
        for stat in history:
            all_results[stat['_step']].update({k: v for k, v in stat.items() if k not in exclude_keys})

    result_dict = {}
    for score_i in score:
        if score_i["metric"]['metric'] == "AP" and score_i["metric"]['iou'] == "0.50:0.95" and score_i["metric"]['area'] == "all":
            result_dict[f"{dataset_name}_AP_{score_i['metric']['description']}"] = score_i['value']
    # wandb.log({f"{dataset_name}_mAP_all": mAP_all, f"{dataset_name}_mAP_rare": mAP_rare, f"{dataset_name}_mAP_common": mAP_common, f"{dataset_name}_mAP_frequent": mAP_frequent}, step=weight_iter)
    all_results[weight_iter].update(result_dict)

    # sort all results
    max_key = max(all_results.keys())
    for i in range(max_key + 1):
        if i in all_results:
            wandb.log(all_results[i], step=i)
        else:
            wandb.log({}, step=i)
    # for k in sorted(all_results.keys()):
    #     # need to do consecutive logging
    #     wandb.log(all_results[k], step=k)


if __name__ == "__main__":
    main()
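

# The block below is a commented-out snippet for re-scoring a saved result file
# offline with omnilabeltools, kept here for reference.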
'''
from omnilabeltools import OmniLabel, OmniLabelEval

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")  # load prediction results
ole = OmniLabelEval(gt, dt)
ole.evaluate()
ole.accumulate()
ole.summarize()

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_object365.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")
'''