import gradio as gr
from model import DemoModel
import json
import numpy as np
from fetch_prod import Scraper
from copy import deepcopy
from custom_label import format_labels_html
from pdb import set_trace as bp

model = DemoModel()
examples = json.load(open('amzn_examples.json'))
# cache = {x['text']: {'label': x['label']} for x in examples}
cache = json.load(open('cache.json'))
unseen_labels = {x.strip() for x in open('cleaned_code/datasets/Amzn13K/unseen_labels_split6500_2.txt')}
all_labels = {x.strip() for x in open('cleaned_code/datasets/Amzn13K/all_labels.txt')}
descriptions_visible = False
scraper = Scraper()


def format_gold_label_text(tex):
    # Return the cached gold labels for a text, each tagged as Unseen / Seen /
    # No Descriptions Available, with unseen labels sorted first.
    if tex not in cache:
        return []
    if 'label' not in cache[tex]:
        return []
    return sorted(
        [(x, 'Unseen' if x in unseen_labels else 'Seen' if x in all_labels else 'No Descriptions Available')
         for x in cache[tex]['label']],
        key=lambda x: (x[1], x[0]),
    )[::-1]


def extract_topk(preds, is_unseen, k=5):
    # Keep only the k highest-scoring predictions, restricted to unseen labels when requested.
    preds_clone = deepcopy(preds)
    preds_dic = preds_clone['preds']
    # bp()
    if is_unseen:
        preds_dic = {k: preds_dic[k] for k in set(preds_dic.keys()).intersection(unseen_labels)}
        if 'label' in preds_clone:
            preds_clone['label'] = list(set(preds_clone['label']).intersection(unseen_labels))
    else:
        if 'label' in preds_clone:
            preds_clone['label'] = list(set(preds_clone['label']).intersection(all_labels))
    preds_dic = {k: v for k, v in sorted(preds_dic.items(), key=lambda x: -x[1])[:k]}
    # bp()
    preds_clone['preds'] = preds_dic
    return preds_clone


def classify(text, is_unseen):
    # Classify the input text, serving cached predictions when available and
    # persisting new predictions back to cache.json.
    print(is_unseen)
    print('See this', text)
    if text in cache and 'preds' in cache[text]:
        print('Using Cached Result')
        return extract_topk(cache[text], is_unseen)  # ['preds']
    preds, descs = model.classify(text, unseen_labels if is_unseen else None)
    if text not in cache:
        cache[text] = dict()
    cache[text]['preds'] = preds
    cache[text]['descs'] = descs
    json.dump(cache, open('cache.json', 'w'), indent=2)
    print(text, preds)
    # return preds
    return extract_topk(cache[text], is_unseen)


def scrape_click(url):
    # Fetch an Amazon product page and populate the description text box.
    out = scraper.get_product(url)
    if isinstance(out, str):
        if out == 'Invalid URL':
            raise gr.Error("Please enter a valid Amazon URL")
        else:
            print('Error occurred', out)
            raise gr.Error("An error occurred. Check the URL or try again later.")
    text = out['description']
    if text not in cache:
        cache[text] = {'label': out['labels']}
    return gr.update(value=out['description'])


def get_random_example():
    return np.random.choice(examples)['text']


def toggle_descriptions_fn():
    print('Toggling descriptions visibility')
    global descriptions_visible
    descriptions_visible = not descriptions_visible
    return descriptions_visible


# for example in examples:
#     classify(example['text'], False)

with gr.Blocks(css="#warning {height: 100%}") as demo:
    with gr.Column():
        title = "Pranjal Aggarwal, Ameet Deshpande, Karthik Narasimhan" \
            + " | Project Page | Paper | Github Repo"
        description = "Extreme classification (XC) considers the scenario of predicting over a very large number of classes " \
            + "(thousands to millions), with real-world applications including serving search engine results, " \
            + "e-commerce product tagging, and news article classification. A real-life requirement in this domain is " \
            + "to predict labels unseen during training (zero-shot); however, there has been little success in this setting. " \
            + "To this end, we propose SemSup-XC, a model that achieves state-of-the-art zero-shot (ZS) and few-shot (FS) " \
            + "performance on three extreme classification benchmarks spanning various domains. Instead of treating labels " \
            + "as class ids, our model learns from diverse descriptions of them, thereby attaining a better understanding " \
            + "of the label space, as is evident from qualitative and quantitative results."
        # gr.HTML(description)
        gr.Markdown(description)
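        # Note on data shapes, inferred from the handlers above (the values below
        # are illustrative only): each cache.json entry is keyed by the product
        # text and is assumed to look roughly like
        #   {
        #     "label": ["label_a", "label_b"],              # gold labels (optional)
        #     "preds": {"label_a": 0.93, "label_c": 0.41},  # model scores, label -> score
        #     "descs": {...}                                # label descriptions returned by the model
        #   }
        # classify() fills in "preds"/"descs" on a cache miss, and extract_topk()
        # keeps only the k highest-scoring labels (restricted to unseen_labels
        # when is_unseen is set) before the result is displayed.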