Spaces:
Runtime error
Runtime error
import json
import os
from copy import deepcopy
from pdb import set_trace as bp

import gradio as gr
import numpy as np

from custom_label import format_labels_html
from fetch_prod import Scraper
from model import DemoModel
def _read_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _read_label_set(path):
    """Return the set of whitespace-stripped lines of the text file at *path*."""
    with open(path, encoding='utf-8') as f:
        return {line.strip() for line in f}


model = DemoModel()
examples = _read_json('amzn_examples.json')
# cache = {x['text']: {'label': x['label']} for x in examples}
# Maps input text -> dict that may hold 'label' (gold labels, set by
# scrape_click) and 'preds' / 'descs' (outputs of model.classify, set by classify).
try:
    cache = _read_json('cache.json')
except FileNotFoundError:
    # classify() repopulates the cache on demand, so a deployment without
    # cache.json can start empty instead of failing at import time.
    cache = {}
unseen_labels = _read_label_set('cleaned_code/datasets/Amzn13K/unseen_labels_split6500_2.txt')
all_labels = _read_label_set('cleaned_code/datasets/Amzn13K/all_labels.txt')
# Process-wide flag flipped by toggle_descriptions_fn().
descriptions_visible = False
scraper = Scraper()
def format_gold_label_text(tex):
    """Return the gold labels of *tex* as (label, category) pairs for display.

    The category is 'Unseen' for labels in ``unseen_labels``, 'Seen' for other
    labels in ``all_labels``, and 'No Descriptions Available' otherwise.
    Pairs are ordered by (category, label) descending, so unseen labels come
    first. Returns an empty list when *tex* is not cached or its cache entry
    has no 'label' field.
    """
    if tex not in cache or 'label' not in cache[tex]:
        return []

    def category(label):
        if label in unseen_labels:
            return 'Unseen'
        if label in all_labels:
            return 'Seen'
        return 'No Descriptions Available'

    pairs = [(label, category(label)) for label in cache[tex]['label']]
    # Descending (category, label): 'Unseen' > 'Seen' > 'No Descriptions Available'.
    return sorted(pairs, key=lambda pair: (pair[1], pair[0]), reverse=True)
def extract_topk(preds, is_unseen, k=5):
    """Return a copy of a cached prediction entry restricted to its top-k labels.

    Args:
        preds: cache entry with a 'preds' mapping (label -> score) and an
            optional 'label' list of gold labels. It is deep-copied and never
            mutated.
        is_unseen: if true, predictions and gold labels are restricted to
            ``unseen_labels``; otherwise only the gold labels are restricted
            (to ``all_labels``) and predictions are left unfiltered.
        k: number of highest-scoring predictions to keep.

    Returns:
        The copied entry, whose 'preds' holds at most *k* items in descending
        score order and whose 'label' (if present) is a sorted, de-duplicated
        list.
    """
    entry = deepcopy(preds)
    scores = entry['preds']
    if is_unseen:
        # Iterate the dict rather than a set intersection so the order among
        # equal scores does not depend on the string hash seed.
        scores = {label: score for label, score in scores.items() if label in unseen_labels}
    if 'label' in entry:
        allowed = unseen_labels if is_unseen else all_labels
        # Sorted so the gold-label order is deterministic across runs.
        entry['label'] = sorted(set(entry['label']) & allowed)
    top = sorted(scores.items(), key=lambda item: -item[1])[:k]
    entry['preds'] = dict(top)
    return entry
def _save_cache(path='cache.json'):
    """Write ``cache`` to *path* atomically (temp file + rename).

    Writing in place could leave a truncated file if the process dies
    mid-write, and the module-level ``json.load`` of cache.json would then
    fail on the next start.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(cache, f, indent=2)
    os.replace(tmp_path, path)


def _cached_preds_serve(entry, is_unseen):
    """Return True if the cached 'preds' in *entry* can answer this request.

    extract_topk filters to unseen labels itself, so predictions made over
    either label set can serve an unseen-labels request. An all-labels request
    needs predictions that were computed over all labels.
    """
    if 'preds' not in entry:
        return False
    if is_unseen:
        return True
    scope = entry.get('preds_scope')
    if scope is None:
        # Entry written before the scope marker existed. An unseen-restricted
        # run can only contain unseen labels, so any other label proves it was
        # computed over all labels. A false negative only costs a recompute.
        return not set(entry['preds']) <= unseen_labels
    return scope == 'all'


def classify(text, is_unseen):
    """Classify *text* and return its top-k predictions via extract_topk.

    Results are memoised in ``cache[text]`` ('preds', 'descs' and a
    'preds_scope' marker) and persisted to cache.json. The marker records
    whether the model was run over unseen labels only or over all labels, so
    unseen-only predictions are never served for an all-labels request.

    Args:
        text: input text; also the cache key.
        is_unseen: restrict classification to ``unseen_labels`` when true.
    """
    print(is_unseen)
    print('See this', text)
    entry = cache.get(text)
    if entry is not None and _cached_preds_serve(entry, is_unseen):
        print('Using Cached Result')
        result = extract_topk(entry, is_unseen)
        result.pop('preds_scope', None)  # internal bookkeeping, not for display
        return result
    preds, descs = model.classify(text, unseen_labels if is_unseen else None)
    entry = cache.setdefault(text, {})
    entry['preds'] = preds
    entry['descs'] = descs
    entry['preds_scope'] = 'unseen' if is_unseen else 'all'
    _save_cache()
    print(text, preds)
    result = extract_topk(entry, is_unseen)
    result.pop('preds_scope', None)  # internal bookkeeping, not for display
    return result
def scrape_click(url):
    """Scrape an Amazon product page and load its description into the text box.

    The scraped gold labels are stored in ``cache[description]['label']`` when
    the entry does not already have them, so they can be shown after
    classification.

    Args:
        url: product URL entered by the user.

    Returns:
        A ``gr.update`` setting the text box value to the scraped description.

    Raises:
        gr.Error: when ``scraper.get_product`` returns a string (its error
            signal), either 'Invalid URL' or any other failure message.
    """
    out = scraper.get_product(url)
    if isinstance(out, str):
        if out == 'Invalid URL':
            raise gr.Error("Please enter a valid Amazon URL")
        print('Error Occured', out)
        raise gr.Error("Error Occured. Check the URL or try again later.")
    text = out['description']
    entry = cache.setdefault(text, {})
    # Also fill in gold labels for texts that were cached earlier without them
    # (e.g. classified before being scraped); existing labels are kept.
    if 'label' not in entry:
        entry['label'] = out['labels']
    return gr.update(value=text)
def get_random_example():
    """Return the 'text' field of a uniformly random entry of ``examples``."""
    # Draw an index instead of np.random.choice(examples), which converts the
    # whole list of dicts into a NumPy object array on every call.
    return examples[np.random.randint(len(examples))]['text']
def toggle_descriptions_fn():
    """Flip the module-level ``descriptions_visible`` flag and return its new value.

    NOTE(review): this is process-global state, so one visitor's toggle
    affects every session served by this process; per-session state would be
    needed to isolate users.
    """
    global descriptions_visible
    print('Toggling descriptions visibility')
    descriptions_visible = not descriptions_visible
    return descriptions_visible
# for example in examples: | |
# classify(example['text'], False) | |
with gr.Blocks(css="#warning {height: 100%}") as demo:
    with gr.Column():
        title = "<h1 style='margin-bottom: -10px; text-align: center'>SemSup-XC: Semantic Supervision for Extreme Classification</h1>"
        # `title` is already an <h1>; wrap it in a <div> rather than a second
        # <h1> so the page does not contain an invalid nested heading.
        gr.Markdown(
            "<div style='text-align: center; margin-bottom: 1rem'>"
            + title
            + "</div>"
        )
        # CSS font-weight takes numeric values (300/600); the former
        # 'w300'/'w600' were invalid declarations that browsers ignore.
        # NOTE(review): the two co-author links and the arXiv link are empty
        # placeholders — fill in the real URLs.
        description = (
            "<p style='font-size: 14px; margin: 5px; font-weight: 300; text-align: center'> "
            "<a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Ameet Deshpande, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Karthik Narasimhan </a> </p>"
            "<p style='font-size: 16px; margin: 5px; font-weight: 600; text-align: center'> "
            "<a href='https://sites.google.com/view/semsup-xc/home' target='_blank'>Project Page</a> | "
            "<a href='https://arxiv.org/abs/' target='_blank'>Paper</a> | "
            "<a href='https://github.com/princeton-nlp/semsup-xc' target='_blank'>Github Repo</a></p>"
            "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: 300;'> "
            "Extreme classification (XC) considers the scenario of predicting over a very large number of classes "
            "(thousands to millions), with real-world applications including serving search engine results, "
            "e-commerce product tagging, and news article classification. A real-life requirement in this domain "
            "is to predict from labels unseen during training (Zero-Shot); however, there has been very little "
            "success in this domain. To this end, we propose SemSup-XC, a model that achieves state-of-the-art "
            "zero-shot (ZS) and few-shot (FS) performance on three extreme classification benchmarks spanning "
            "various domains. Instead of treating labels as class ids, our model learns from diverse descriptions "
            "of them, thereby attaining a better understanding of the label space, evident from qualitative and "
            "quantitative results. "
            "</p>"
        )
        gr.Markdown(description)
        # Built without indentation inside the string: Markdown treats lines
        # indented by 4+ spaces as code blocks.
        gr.Markdown(
            "\n"
            "<br>\n"
            "<br>\n"
            "Our model was trained on over 1 million product descriptions from Amazon on 6500 different categories.\n"
            "SemSup-XC can generalize to both seen and unseen labels.\n"
            "You can either use already available examples or enter your own text to classify.\n"
            "You can also fetch product descriptions by simply entering the product link, "
            "and classify categories on both seen and unseen labels.\n"
        )
        with gr.Tab(label="Amazon"):
            url_textbox = gr.Textbox(
                label='URL for Amazon Product',
                lines=1,
                interactive=True,
            )
            scrape_btn = gr.Button('Fetch', variant='primary')
            text_box = gr.Textbox(
                label="Text to Classify",
                lines=4,
                interactive=True,
                value=get_random_example(),
            )
            with gr.Row():
                classify_btn = gr.Button("Classify", variant='primary')
                random_example_btn = gr.Button("Try Random")
                radio_btn = gr.Radio(
                    choices=['Unseen Labels', 'All Labels'],
                    value='Unseen Labels',
                    label='Classify on',
                    interactive=True,
                )
            with gr.Column(variant='panel'):
                label_html = gr.HTML('', visible=False)
                with gr.Row():
                    with gr.Column(scale=8):
                        gold_labels = gr.HighlightedText(
                            label="Gold Labels",
                            value=[("Label 1", "Seen"), ("Label 2", "Seen"), ("Label 3", "Unseen"), ("Label 4", "No Descriptions Available")],
                            disabled=True,
                            visible=False,
                        )
                    with gr.Column(scale=1):
                        toggle_descriptions = gr.Button(
                            "Toggle Descriptions",
                            visible=False,
                            elem_id='warning',
                        )
                # Keys must match the categories produced by
                # format_gold_label_text; the previous 'No Descriptions' key
                # never matched 'No Descriptions Available'.
                gold_labels.style(color_map={'Seen': 'green', 'Unseen': 'blue', 'No Descriptions Available': 'gray'})

    def _on_classify(text, scope):
        """Classify *text* and reveal the predictions, gold labels and toggle button."""
        html = format_labels_html(classify(text, scope == 'Unseen Labels'), desc_is_visible=descriptions_visible)
        has_gold = text in cache and 'label' in cache[text]
        return (
            gr.update(value=html, visible=True),
            gr.update(value=format_gold_label_text(text), visible=has_gold),
            gr.update(visible=True),
        )

    def _on_random_example():
        """Load a random example and hide the results that belonged to the old text."""
        return (
            gr.update(value=get_random_example()),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    def _on_scope_change(text, scope):
        """Re-render the predictions when the label scope radio changes."""
        preds = classify(text, scope == 'Unseen Labels')
        return gr.update(value=format_labels_html(preds, desc_is_visible=descriptions_visible))

    def _on_toggle_descriptions(text, scope):
        """Flip description visibility and re-render the predictions."""
        preds = classify(text, scope == 'Unseen Labels')
        return gr.update(value=format_labels_html(preds, desc_is_visible=toggle_descriptions_fn()), visible=True)

    # One handler per event, so related outputs update together instead of
    # through several independent listeners on the same button.
    classify_btn.click(_on_classify, inputs=[text_box, radio_btn], outputs=[label_html, gold_labels, toggle_descriptions])
    random_example_btn.click(_on_random_example, inputs=None, outputs=[text_box, label_html, gold_labels, toggle_descriptions])
    radio_btn.change(_on_scope_change, inputs=[text_box, radio_btn], outputs=label_html)
    scrape_btn.click(scrape_click, inputs=url_textbox, outputs=text_box)
    toggle_descriptions.click(_on_toggle_descriptions, inputs=[text_box, radio_btn], outputs=label_html)
# Launch the Gradio app when run as a script; share=False disables the public
# share link.
if __name__ == '__main__':
    demo.launch(share=False)