Spaces:
Runtime error
Runtime error
import gradio as gr | |
from model import DemoModel | |
import json | |
import numpy as np | |
from fetch_prod import Scraper | |
from copy import deepcopy | |
from custom_label import format_labels_html | |
from pdb import set_trace as bp | |
# Module-level state shared by every request served by this process.
model = DemoModel()


def _read_label_set(path):
    """Return the set of non-empty, whitespace-stripped lines of the text file at *path*."""
    with open(path, encoding='utf-8') as fh:
        return {line.strip() for line in fh if line.strip()}


# Curated examples: a list of dicts, read below through their 'text' and 'label' keys.
with open('amzn_examples.json', encoding='utf-8') as fh:
    examples = json.load(fh)
# text -> {'label': gold labels}; classify() later stores model outputs in these entries.
cache = {x['text']: {'label': x['label']} for x in examples}
unseen_labels = _read_label_set('cleaned_code/datasets/Amzn13K/unseen_labels_split6500_2.txt')
all_labels = _read_label_set('cleaned_code/datasets/Amzn13K/all_labels.txt')
# Whether label descriptions are rendered; flipped by toggle_descriptions_fn().
# NOTE(review): this module global is shared by all users of the app — consider gr.State.
descriptions_visible = False
scraper = Scraper()
def format_gold_label_text(tex):
    """Build HighlightedText data for the gold labels cached for *tex*.

    Each gold label becomes a ``(label, category)`` pair, where category is
    'Unseen' for labels in ``unseen_labels``, 'Seen' for other labels in
    ``all_labels`` and 'No Descriptions Available' for everything else.
    Pairs are ordered by (category, label) descending. An empty list is
    returned when *tex* has no cache entry or the entry has no gold labels.
    """
    if tex not in cache or 'label' not in cache[tex]:
        return []

    def _category(name):
        if name in unseen_labels:
            return 'Unseen'
        if name in all_labels:
            return 'Seen'
        return 'No Descriptions Available'

    pairs = [(name, _category(name)) for name in cache[tex]['label']]
    pairs.sort(key=lambda pair: (pair[1], pair[0]))
    pairs.reverse()
    return pairs
def extract_topk(preds, is_unseen, k=5, *, unseen=None, known=None):
    """Return a deep copy of *preds* reduced to its *k* highest-scoring labels.

    Parameters
    ----------
    preds : dict
        Must contain 'preds' (label -> score) and may contain 'label'
        (a list of gold labels). Any other entries are copied through.
    is_unseen : bool
        When true, both predictions and gold labels are restricted to
        *unseen*. Otherwise predictions are left as-is and only the gold
        labels are restricted to *known*.
    k : int
        Maximum number of predictions to keep.
    unseen, known : container of str, optional
        Label sets used for filtering. They default to the module-level
        ``unseen_labels`` and ``all_labels``.

    Returns
    -------
    dict
        A copy whose 'preds' holds at most *k* entries in descending score
        order (ties keep their input order). Its 'label', when present, holds
        only allowed labels, de-duplicated and in their original order.
        *preds* itself is not modified.
    """
    if unseen is None:
        unseen = unseen_labels
    if known is None:
        known = all_labels

    result = deepcopy(preds)
    scores = result['preds']
    allowed = unseen if is_unseen else known

    if is_unseen:
        # Filter by iterating the dict rather than a set intersection, so the
        # order seen by the (stable) sort below is deterministic.
        scores = {label: score for label, score in scores.items() if label in unseen}

    if 'label' in result:
        # dict.fromkeys de-duplicates like the former set() but keeps input order.
        result['label'] = list(dict.fromkeys(
            label for label in result['label'] if label in allowed
        ))

    ranked = sorted(scores.items(), key=lambda item: -item[1])
    result['preds'] = dict(ranked[:k])
    return result
def classify(text, is_unseen):
    """Run the model on *text* and return its top predictions.

    Parameters
    ----------
    text : str
        Product description to classify.
    is_unseen : bool
        True to pass ``unseen_labels`` to the model ("Unseen Labels" mode);
        False to pass ``None`` ("All Labels" mode).

    Returns
    -------
    dict
        ``extract_topk`` applied to a dict holding 'preds' and 'descs' from the
        model, plus 'label' when gold labels are cached for *text*.

    Model outputs are memoised in ``cache[text]`` separately for each mode.
    Previously one 'preds' entry was shared by both modes, so results computed
    with ``unseen_labels`` (presumably restricted to that set — confirm in
    DemoModel) were reused after switching to "All Labels".
    """
    print(is_unseen)
    print('See this', text)
    entry = cache.setdefault(text, {})
    per_mode = entry.setdefault('preds_by_mode', {})
    mode = bool(is_unseen)
    if mode in per_mode:
        print('Using Cached Result')
    else:
        preds, descs = model.classify(text, unseen_labels if is_unseen else None)
        per_mode[mode] = (preds, descs)
        print(text, preds)
    preds, descs = per_mode[mode]
    # Keep the original keys pointing at the result for the requested mode.
    entry['preds'] = preds
    entry['descs'] = descs
    # Hand extract_topk a shallow view without the per-mode store so it does
    # not deep-copy both modes' outputs.
    view = {key: value for key, value in entry.items() if key != 'preds_by_mode'}
    return extract_topk(view, is_unseen)
def scrape_click(url):
    """Fetch an Amazon product and return a Gradio update for the text box.

    ``scraper.get_product`` returns a ``str`` (an error message) on failure.
    Otherwise it returns a mapping that is read here through its
    'description' and 'labels' keys. The product's labels are recorded in
    ``cache`` as gold labels for its description.

    On failure the message is printed and a no-op ``gr.update()`` is returned,
    so the text box keeps its current content. The previous ``None`` return
    could clear it.
    """
    out = scraper.get_product(url)
    if isinstance(out, str):
        print('Error Occured', out)
        return gr.update()
    text = out['description']
    # Also fills in gold labels when the text was classified before being fetched.
    cache.setdefault(text, {}).setdefault('label', out['labels'])
    return gr.update(value=text)
def get_random_example():
    """Pick one entry of ``examples`` uniformly at random and return its 'text' field."""
    chosen = np.random.choice(examples)
    return chosen['text']
def toggle_descriptions_fn():
    """Flip the module-wide ``descriptions_visible`` flag and return its new value."""
    global descriptions_visible
    print('Toggling descriptions visibility')
    new_state = not descriptions_visible
    descriptions_visible = new_state
    return new_state
# Gradio UI: the page header and abstract, then an "Amazon" tab containing
# URL fetching, the text to classify, a seen/unseen label switch, and panels
# for predicted and gold labels.
with gr.Blocks(css="#warning {height: 100%}") as demo:
    with gr.Column():
        # The title was previously an <h1> nested inside another <h1> (invalid HTML).
        title = "SemSup-XC: Semantic Supervision for Extreme Classification"
        gr.Markdown(
            "<h1 style='text-align: center; margin-bottom: 1rem'>"
            + title
            + "</h1>"
        )
        # Authors, project links and abstract. CSS font-weight takes numeric
        # values; the former 'w300'/'w600' were ignored by browsers.
        # TODO(review): two author hrefs and the arXiv link are incomplete.
        description = (
            "<p style='font-size: 14px; margin: 5px; font-weight: 300; text-align: center'> "
            "<a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Ameet Deshpande, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Karthik Narasimhan </a> </p>"
            "<p style='font-size: 16px; margin: 5px; font-weight: 600; text-align: center'> "
            "<a href='https://sites.google.com/view/semsup-xc/home' target='_blank'>Project Page</a> | "
            "<a href='https://arxiv.org/abs/' target='_blank'>Arxiv</a> | "
            "<a href='https://github.com/princeton-nlp/SemSup-XC' target='_blank'>Github Repo</a></p>"
            "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: 300;'> "
            "Extreme classification (XC) considers the scenario of predicting over a very large number of classes (thousands to millions), "
            "with real-world applications including serving search engine results, e-commerce product tagging, and news article classification. "
            "The zero-shot version of this task involves the addition of new categories at test time, requiring models to generalize to novel classes "
            "without additional training data (e.g. one may add a new class “fidget spinner” for e-commerce product tagging). "
            "In this paper, we develop SEMSUP-XC, a model that achieves state-of-the-art zero-shot (ZS) and few-shot (FS) performance "
            "on three extreme classification benchmarks spanning the domains of law, e-commerce, and Wikipedia. "
            "SEMSUP-XC builds upon the recently proposed framework of semantic supervision that uses semantic label descriptions "
            "to represent and generalize to classes (e.g., “fidget spinner” described as “A popular spinning toy intended as a stress reliever”). "
            "Specifically, we use a combination of contrastive learning, a hybrid lexico-semantic similarity module and automated description collection "
            "to train SEMSUP-XC efficiently over extremely large class spaces. "
            "SEMSUP-XC significantly outperforms baselines and state-of-the-art models on all three datasets, "
            "by up to 6-10 precision@1 points on zero-shot classification and >10 precision points on few-shot classification, "
            "with similar gains for recall@10 (3 for zero-shot and 2 for few-shot). "
            "Our ablation studies show the relative importance of various components and conclude the combined importance "
            "of the proposed architecture and automatically scraped descriptions with improvements up to 33 precision@1 points. "
            "Furthermore, qualitative analyses demonstrate SEMSUP-XC's better understanding of label space than other state-of-the-art models. "
            "</p>"
        )
        gr.Markdown(description)
        gr.Markdown(
            """
<br>
<br>
Our model was trained on over 1 million product descriptions from Amazon on 6500 different categories.
SemSup-XC can generalize to unseen labels.
You can also fetch product descriptions by simply entering the product link, and classify categories on both seen and unseen labels.
"""
        )
        with gr.Tab(label="Amazon"):
            url_textbox = gr.Textbox(
                label='URL for Amazon Product',
                lines=1,
                interactive=True,
            )
            scrape_btn = gr.Button('Fetch', variant='primary')
            text_box = gr.Textbox(
                label="Text to Classify",
                lines=4,
                interactive=True,
                value=get_random_example(),
            )
            with gr.Row():
                classify_btn = gr.Button("Classify", variant='primary')
                random_example_btn = gr.Button("Try Random")
                radio_btn = gr.Radio(
                    choices=['Unseen Labels', 'All Labels'],
                    value='Unseen Labels',
                    label='Classify on',
                    interactive=True,
                )
            with gr.Column(variant='panel'):
                label_html = gr.HTML('', visible=False)
                with gr.Row():
                    with gr.Column(scale=8):
                        gold_labels = gr.HighlightedText(
                            label="Gold Labels",
                            value=[("Label 1", "Seen"), ("Label 2", "Seen"), ("Label 3", "Unseen"), ("Label 4", "No Descriptions Available")],
                            disabled=True,
                            visible=False,
                        )
                    with gr.Column(scale=1):
                        toggle_descriptions = gr.Button(
                            "Toggle Descriptions",
                            visible=False,
                            elem_id='warning',  # targeted by the "#warning" CSS rule above
                        )
                # Keys must equal the categories emitted by format_gold_label_text;
                # the former 'No Descriptions' key never matched 'No Descriptions Available'.
                gold_labels.style(color_map={'Seen': 'green', 'Unseen': 'blue', 'No Descriptions Available': 'gray'})

    def _show_predictions(text, mode):
        """Classify *text* with the selected label set and reveal the rendered result."""
        html = format_labels_html(classify(text, mode == 'Unseen Labels'), desc_is_visible=descriptions_visible)
        return gr.update(value=html, visible=True)

    def _refresh_predictions(text, mode):
        """Re-render predictions after the label-set switch changes; visibility is left as-is."""
        html = format_labels_html(classify(text, mode == 'Unseen Labels'), desc_is_visible=descriptions_visible)
        return gr.update(value=html)

    def _show_gold_labels(text):
        """Show cached gold labels for *text*, hiding the panel when there are none."""
        has_gold = text in cache and 'label' in cache[text]
        return gr.update(value=format_gold_label_text(text), visible=has_gold)

    def _toggle_and_show(text, mode):
        """Flip description visibility and re-render the predictions with the new setting."""
        result = classify(text, mode == 'Unseen Labels')
        html = format_labels_html(result, desc_is_visible=toggle_descriptions_fn())
        return gr.update(value=html, visible=True)

    classify_btn.click(_show_predictions, inputs=[text_box, radio_btn], outputs=label_html)
    # A new random example invalidates everything shown for the previous text.
    random_example_btn.click(lambda: gr.update(value=get_random_example()), inputs=None, outputs=text_box)
    random_example_btn.click(
        lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
        inputs=None,
        outputs=[label_html, gold_labels, toggle_descriptions],
    )
    radio_btn.change(_refresh_predictions, inputs=[text_box, radio_btn], outputs=label_html)
    scrape_btn.click(scrape_click, inputs=url_textbox, outputs=text_box)
    classify_btn.click(_show_gold_labels, inputs=text_box, outputs=gold_labels)
    classify_btn.click(lambda: gr.update(visible=True), inputs=None, outputs=toggle_descriptions)
    toggle_descriptions.click(_toggle_and_show, inputs=[text_box, radio_btn], outputs=label_html)
if __name__ == "__main__":
    # Entry point when run as a script: start the Gradio app with a share link requested.
    demo.launch(share=True)