Spaces:
Runtime error
Runtime error
File size: 10,831 Bytes
4014562 c4e69b4 4014562 c4e69b4 4014562 d4218cc ad64dda d4218cc ad64dda d4218cc 4014562 c4e69b4 4014562 0690c88 4014562 d4218cc 4014562 d4218cc 4014562 d4218cc 4014562 2034b44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import gradio as gr
from model import DemoModel
import json
import numpy as np
from fetch_prod import Scraper
from copy import deepcopy
from custom_label import format_labels_html
from pdb import set_trace as bp
# Module-level state shared by the Gradio callbacks defined in this file.

# Classifier backend; `classify` calls its `.classify(text, labels)` method.
model = DemoModel()

# Example records; `get_random_example` reads the 'text' field of one of them.
with open('amzn_examples.json') as examples_file:
    examples = json.load(examples_file)

# Per-text cache keyed by the input text. Entries may carry 'label' (gold
# labels), 'preds' and 'descs' (both written by `classify`).
with open('cache.json') as cache_file:
    cache = json.load(cache_file)

# Label vocabularies, one label per line. Files are opened with `with` so the
# handles are closed as soon as the sets are built.
with open('cleaned_code/datasets/Amzn13K/unseen_labels_split6500_2.txt') as labels_file:
    unseen_labels = {line.strip() for line in labels_file}
with open('cleaned_code/datasets/Amzn13K/all_labels.txt') as labels_file:
    all_labels = {line.strip() for line in labels_file}

# Flipped by `toggle_descriptions_fn`; read when rendering predictions.
descriptions_visible = False

# Product fetcher; `scrape_click` calls its `.get_product(url)` method.
scraper = Scraper()
def format_gold_label_text(tex):
    """Return the cached gold labels of `tex` as (label, category) pairs.

    The category is 'Unseen' when the label is in `unseen_labels`, 'Seen' when
    it is in `all_labels`, and 'No Descriptions Available' otherwise. Pairs are
    ordered descending by (category, label). An empty list is returned when
    `tex` has no cache entry or the entry has no 'label' key.
    """
    if tex not in cache or 'label' not in cache[tex]:
        return []

    def _category(label):
        # Membership in the unseen set takes precedence over the full set.
        if label in unseen_labels:
            return 'Unseen'
        if label in all_labels:
            return 'Seen'
        return 'No Descriptions Available'

    tagged = [(label, _category(label)) for label in cache[tex]['label']]
    tagged.sort(key=lambda pair: (pair[1], pair[0]))
    tagged.reverse()
    return tagged
def extract_topk(preds, is_unseen, k=5):
    """Return a deep copy of `preds` trimmed to its `k` highest-scoring predictions.

    `preds` is a dict with a 'preds' mapping of label -> score and, optionally,
    a 'label' list. When `is_unseen` is true, predictions are first restricted
    to labels in `unseen_labels`. A 'label' list, if present, is reduced to the
    labels found in `unseen_labels` (unseen mode) or `all_labels` (otherwise).
    The input dict is not modified.
    """
    result = deepcopy(preds)
    scores = result['preds']

    if is_unseen:
        kept_names = set(scores.keys()).intersection(unseen_labels)
        scores = {name: scores[name] for name in kept_names}

    if 'label' in result:
        label_pool = unseen_labels if is_unseen else all_labels
        result['label'] = list(set(result['label']).intersection(label_pool))

    # Highest score first; keep only the leading k entries.
    ranked = sorted(scores.items(), key=lambda item: -item[1])
    result['preds'] = dict(ranked[:k])
    return result
def classify(text, is_unseen):
    """Classify `text` and return its top predictions via `extract_topk`.

    If the cache already holds predictions for `text`, they are reused.
    Otherwise `model.classify` is called (restricted to `unseen_labels` when
    `is_unseen` is true, unrestricted otherwise), the predictions and
    descriptions are stored in the cache, and the whole cache is written back
    to 'cache.json'.

    Args:
        text: The text to classify; also the cache key.
        is_unseen: Whether to classify on unseen labels only.

    Returns:
        The dict produced by `extract_topk` for the cache entry of `text`.
    """
    print(is_unseen)
    print('See this', text)
    if text in cache and 'preds' in cache[text]:
        # NOTE(review): cached predictions are reused regardless of which
        # label set produced them — confirm that is intended.
        print('Using Cached Result')
        return extract_topk(cache[text], is_unseen)

    preds, descs = model.classify(text, unseen_labels if is_unseen else None)
    # Keep an existing entry (it may already hold gold labels) and add to it.
    entry = cache.setdefault(text, {})
    entry['preds'] = preds
    entry['descs'] = descs
    # Use a context manager so the file is flushed and closed right away
    # instead of whenever the handle happens to be garbage collected.
    with open('cache.json', 'w') as cache_file:
        json.dump(cache, cache_file, indent=2)
    print(text, preds)
    return extract_topk(entry, is_unseen)
def scrape_click(url):
    """Fetch a product via `scraper` and return an update for the text box.

    `scraper.get_product` signals failure by returning a string; the value
    'Invalid URL' is reported as such and any other string as a generic error.
    On success the product's labels are stored as gold labels in the cache,
    keyed by the product description.

    Args:
        url: The product URL entered by the user.

    Returns:
        A `gr.update` that sets the value to the product description.

    Raises:
        gr.Error: If the scraper returned an error string.
    """
    out = scraper.get_product(url)
    if isinstance(out, str):
        if out == 'Invalid URL':
            raise gr.Error("Please enter a valid Amazon URL")
        print('Error Occured', out)
        raise gr.Error("Error Occured. Check the URL or try again later.")

    text = out['description']
    # Also record gold labels when the text is already cached without them
    # (e.g. it was classified before being scraped); never overwrite labels.
    entry = cache.setdefault(text, {})
    if 'label' not in entry:
        entry['label'] = out['labels']
    return gr.update(value=text)
def get_random_example():
    """Return the 'text' field of one entry of `examples` drawn with `np.random.choice`."""
    picked = np.random.choice(examples)
    return picked['text']
def toggle_descriptions_fn():
    """Invert the module-level `descriptions_visible` flag and return its new value."""
    global descriptions_visible
    print('Toggling descriptions visibility')
    flipped = not descriptions_visible
    descriptions_visible = flipped
    return flipped
# for example in examples:
# classify(example['text'], False)
# Build the demo page: header and description text, the "Amazon" tab with its
# inputs, the result widgets, and the event wiring between them.
# NOTE(review): the container nesting below (Column / Tab / Row) was
# reconstructed from statement order — confirm it matches the intended layout.
with gr.Blocks(css="#warning {height: 100%}") as demo:
    with gr.Column():
        title = "<h1 style='margin-bottom: -10px; text-align: center'>SemSup-XC: Semantic Supervision for Extreme Classification</h1>"
        # NOTE(review): `title` already carries its own <h1>, so this nests one
        # <h1> inside another — confirm that is intended.
        gr.Markdown(
            "<h1 style='text-align: center; margin-bottom: 1rem'>"
            + title
            + "</h1>"
        )
        # Adjacent string literals inside the parentheses are concatenated;
        # this replaces backslash line continuations, one of which ended right
        # before a comment line.
        description = (
            "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal, </a> <a href='' style='text-decoration:none' target='_blank'>Ameet Deshpande, </a> <a href='' style='text-decoration:none' target='_blank'>Karthik Narasimhan </a> </p>"
            "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://sites.google.com/view/semsup-xc/home' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/' target='_blank'>Paper</a> | <a href='https://github.com/princeton-nlp/semsup-xc' target='_blank'>Github Repo</a></p>"
            "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> "
            "Extreme classification (XC) considers the scenario of predicting over a very large number of classes (thousands to millions), with real-world applications including serving search engine results, e-commerce product tagging, and news article classification. A real-life requirement in this domain is to predict from labels unseen during training(Zero-Shot), however there have been very little success in this domain. To this end, we propose SemSup-XC, a model that achieves state-of-the-art zero-shot (ZS) and few-shot (FS) performance on three extreme classification benchmarks spanning various domains. Instead of treating labels as class ids, our model learns from diverse descriptions of them, thereby attaining a more better understanding of the label space, evident from qualitative and quantitative results. "
            "</p>"
        )
        gr.Markdown(description)
        gr.Markdown(
            "\n"
            "<br>\n"
            "<br>\n"
            "Our model was trained on over 1 million product descriptions from Amazon on 6500 different categories.\n"
            "SemSup-XC can generalize to both seen and unseen labels.\n"
            "You can either use already available examples or enter your own text to classify.\n"
            "You can also fetch product descriptions by simply entering the product link, and classify categories on both seen and unseen labels.\n"
        )

    with gr.Tab(label="Amazon"):
        # --- Inputs -------------------------------------------------------
        url_textbox = gr.Textbox(
            label='URL for Amazon Product',
            lines=1,
            interactive=True,
        )
        scrape_btn = gr.Button('Fetch', variant='primary')
        text_box = gr.Textbox(
            label="Text to Classify",
            lines=4,
            interactive=True,
            value=get_random_example(),
        )
        with gr.Row():
            classify_btn = gr.Button("Classify", variant='primary')
            random_example_btn = gr.Button("Try Random")
            radio_btn = gr.Radio(
                choices=['Unseen Labels', 'All Labels'],
                value='Unseen Labels',
                label='Classify on',
                interactive=True,
            )

        # --- Outputs (hidden until the first classification) ---------------
        with gr.Column(variant='panel'):
            label_html = gr.HTML('', visible=False)
            with gr.Row():
                with gr.Column(scale=8):
                    gold_labels = gr.HighlightedText(
                        label="Gold Labels",
                        value=[("Label 1", "Seen"), ("Label 2", "Seen"), ("Label 3", "Unseen"), ("Label 4", "No Descriptions Available")],
                        disabled=True,
                        visible=False,
                    )
                with gr.Column(scale=1):
                    toggle_descriptions = gr.Button(
                        "Toggle Descriptions",
                        visible=False,
                        elem_id='warning',
                    )
        # The keys must match the category strings emitted by
        # `format_gold_label_text`; the third one was previously
        # 'No Descriptions', which never matched any emitted category.
        gold_labels.style(color_map={'Seen': 'green', 'Unseen': 'blue', 'No Descriptions Available': 'gray'})

        # --- Event wiring ---------------------------------------------------
        # Classify: render the predictions and reveal the result panel.
        classify_btn.click(
            lambda value, is_unseen: gr.update(
                value=format_labels_html(
                    classify(value, is_unseen == 'Unseen Labels'),
                    desc_is_visible=descriptions_visible,
                ),
                visible=True,
            ),
            inputs=[text_box, radio_btn],
            outputs=label_html,
        )
        classify_btn.click(lambda x: gr.update(visible=True), inputs=classify_btn, outputs=label_html)

        # Try Random: load a new example and hide the results of the old one.
        random_example_btn.click(
            lambda value: gr.update(value=get_random_example()),
            inputs=random_example_btn,
            outputs=text_box,
        )
        random_example_btn.click(
            lambda value: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
            inputs=random_example_btn,
            outputs=[label_html, gold_labels, toggle_descriptions],
        )

        # Switching the label set re-renders the predictions for the current text.
        radio_btn.change(
            lambda value, is_unseen: gr.update(
                value=format_labels_html(
                    classify(value, is_unseen == 'Unseen Labels'),
                    desc_is_visible=descriptions_visible,
                ),
            ),
            inputs=[text_box, radio_btn],
            outputs=label_html,
        )

        # Fetch: put the scraped product description into the text box.
        scrape_btn.click(scrape_click, inputs=url_textbox, outputs=text_box)

        # Gold labels are shown only when the cache has labels for the text.
        classify_btn.click(
            lambda x: gr.update(
                value=format_gold_label_text(x),
                visible=x in cache and 'label' in cache[x],
            ),
            inputs=text_box,
            outputs=gold_labels,
        )
        classify_btn.click(lambda x: gr.update(visible=True), inputs=text_box, outputs=toggle_descriptions)

        # Toggle Descriptions: flip the flag, then re-render with the new value.
        toggle_descriptions.click(
            lambda value, is_unseen: gr.update(
                value=format_labels_html(
                    classify(value, is_unseen == 'Unseen Labels'),
                    desc_is_visible=toggle_descriptions_fn(),
                ),
                visible=True,
            ),
            inputs=[text_box, radio_btn],
            outputs=label_html,
        )
if __name__ == '__main__':
    # Launch the demo only when this file is run as a script.
    demo.launch(share=False)