File size: 10,831 Bytes
4014562
 
 
 
 
 
 
 
 
 
 
 
 
c4e69b4
 
4014562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4e69b4
4014562
 
 
 
 
 
 
 
d4218cc
ad64dda
d4218cc
 
ad64dda
d4218cc
4014562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4e69b4
 
 
4014562
 
 
 
 
 
 
 
 
 
 
 
0690c88
4014562
d4218cc
4014562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4218cc
 
4014562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4218cc
4014562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2034b44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import gradio as gr
from model import DemoModel
import json
import numpy as np
from fetch_prod import Scraper
from copy import deepcopy
from custom_label import format_labels_html
from pdb import set_trace as bp

    

def _load_json(path):
    """Return the parsed JSON document stored at *path*, closing the file afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _load_label_set(path):
    """Return the set of lines in the file at *path*, each stripped of surrounding whitespace."""
    with open(path, encoding='utf-8') as f:
        return {line.strip() for line in f}


model = DemoModel()
examples = _load_json('amzn_examples.json')
# Prediction/label cache keyed by product text; classify() writes it back to
# cache.json. It was originally seeded from the examples' gold labels, i.e.
# {x['text']: {'label': x['label']} for x in examples}.
cache = _load_json('cache.json')
unseen_labels = _load_label_set('cleaned_code/datasets/Amzn13K/unseen_labels_split6500_2.txt')
all_labels = _load_label_set('cleaned_code/datasets/Amzn13K/all_labels.txt')

# Whether label descriptions are shown in the rendered predictions;
# flipped by toggle_descriptions_fn().
descriptions_visible = False

scraper = Scraper()

def format_gold_label_text(tex):
    """Return the cached gold labels of *tex* as (label, category) pairs.

    The category is 'Unseen' for labels in ``unseen_labels``, 'Seen' for
    other labels in ``all_labels``, and 'No Descriptions Available'
    otherwise. Pairs are ordered by category, then label, both descending.
    An empty list is returned when *tex* has no cached 'label' entry.
    """
    if tex not in cache or 'label' not in cache[tex]:
        return []

    def _category(label):
        if label in unseen_labels:
            return 'Unseen'
        if label in all_labels:
            return 'Seen'
        return 'No Descriptions Available'

    pairs = [(label, _category(label)) for label in cache[tex]['label']]
    return sorted(pairs, key=lambda pair: (pair[1], pair[0]), reverse=True)

def extract_topk(preds, is_unseen, k=5, unseen_label_set=None, known_label_set=None):
    """Return a deep copy of *preds* trimmed for display.

    Args:
        preds: Cache entry holding a 'preds' mapping of label -> score and,
            optionally, a 'label' list of gold labels. Other keys are copied
            through unchanged.
        is_unseen: If truthy, keep only scores and gold labels that are in
            the unseen label set. Otherwise keep all scores and only the gold
            labels that are in the known label set.
        k: Number of highest-scoring labels to keep.
        unseen_label_set: Unseen labels; defaults to module-level ``unseen_labels``.
        known_label_set: Known labels; defaults to module-level ``all_labels``.

    Returns:
        The copy, with 'preds' replaced by a dict of at most *k* entries
        ordered by descending score, and 'label' (if present) replaced by a
        sorted list. The input is not modified.
    """
    if unseen_label_set is None:
        unseen_label_set = unseen_labels
    if known_label_set is None:
        known_label_set = all_labels

    result = deepcopy(preds)
    scores = result['preds']
    if is_unseen:
        # Filter in the mapping's own order (not via a set intersection) so
        # that ties in score are broken deterministically across runs.
        scores = {label: score for label, score in scores.items() if label in unseen_label_set}

    if 'label' in result:
        allowed = unseen_label_set if is_unseen else known_label_set
        # Sorted so the gold-label order is stable rather than set-order.
        result['label'] = sorted(set(result['label']).intersection(allowed))

    # sorted() is stable, so equal scores keep their original relative order.
    top = sorted(scores.items(), key=lambda item: -item[1])[:k]
    result['preds'] = dict(top)
    return result

# Cache-entry key recording that the stored 'preds' were produced by calling
# the model with ``unseen_labels`` (see classify below).
_UNSEEN_ONLY_KEY = 'preds_unseen_only'


def _save_cache(path='cache.json'):
    """Write ``cache`` to *path* atomically (temp file, then rename).

    A crash mid-write therefore cannot leave a truncated JSON file that
    would make the app fail when it reloads the cache at startup.
    """
    import os

    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(cache, f, indent=2)
    os.replace(tmp_path, path)


def classify(text, is_unseen):
    """Return the top predictions for *text*, reusing ``cache`` when valid.

    Args:
        text: Product description to classify.
        is_unseen: When truthy, the model is called with ``unseen_labels``
            and results are filtered to unseen labels; otherwise all labels.

    Returns:
        The ``extract_topk`` view of the (possibly refreshed) cache entry.
    """
    print(is_unseen)
    print('See this', text)
    entry = cache.get(text)
    # NOTE(review): in unseen mode model.classify receives ``unseen_labels``,
    # which presumably restricts its predictions to that subset. Such results
    # cannot answer an "All Labels" request, so the cache records which mode
    # produced them. Entries without the flag are treated as all-label
    # predictions, exactly as before.
    reusable = (
        entry is not None
        and 'preds' in entry
        and (is_unseen or not entry.get(_UNSEEN_ONLY_KEY, False))
    )
    if reusable:
        print('Using Cached Result')
    else:
        preds, descs = model.classify(text, unseen_labels if is_unseen else None)
        entry = cache.setdefault(text, {})
        entry['preds'] = preds
        entry['descs'] = descs
        entry[_UNSEEN_ONLY_KEY] = bool(is_unseen)
        _save_cache()
        print(text, preds)
    # Hide the bookkeeping flag so downstream consumers (extract_topk and the
    # HTML renderer) see the same keys as before the flag existed.
    visible = {key: value for key, value in entry.items() if key != _UNSEEN_ONLY_KEY}
    return extract_topk(visible, is_unseen)

def scrape_click(url):
    """Fetch the Amazon product at *url* and return its description for the text box.

    The product's gold labels are recorded in ``cache`` under the
    description text so they can be displayed after classification.

    Raises:
        gr.Error: if the scraper reports an invalid URL or any other failure.
    """
    out = scraper.get_product(url)
    if isinstance(out, str):
        # The scraper reports failures as a plain error string.
        if out == 'Invalid URL':
            raise gr.Error("Please enter a valid Amazon URL")
        print('Error Occured', out)
        raise gr.Error("Error occurred. Check the URL or try again later.")

    text = out['description']
    # Attach gold labels even when the text is already cached (e.g. it was
    # classified earlier without labels), but never overwrite existing ones.
    entry = cache.setdefault(text, {})
    if 'label' not in entry:
        entry['label'] = out['labels']
    return gr.update(value=text)

def get_random_example():
    """Return the 'text' field of one entry drawn at random from ``examples``."""
    chosen = np.random.choice(examples)
    return chosen['text']

def toggle_descriptions_fn():
    """Invert the module-level ``descriptions_visible`` flag and return its new value."""
    global descriptions_visible
    print('Toggling descriptions visibility')
    descriptions_visible = False if descriptions_visible else True
    return descriptions_visible

# for example in examples:
#     classify(example['text'], False)


# NOTE(review): ``descriptions_visible`` is a module-level flag (see
# toggle_descriptions_fn), so toggling descriptions affects every visitor of
# the demo, not only the one who clicked.
with gr.Blocks(css="#warning {height: 100%}") as demo:
    with gr.Column():
        title = "<h1 style='margin-bottom: -10px; text-align: center'>SemSup-XC: Semantic Supervision for Extreme Classification</h1>"
        gr.Markdown(
            "<h1 style='text-align: center; margin-bottom: 1rem'>"
            + title
            + "</h1>"
        )

        # Authors, links and abstract. Implicit string concatenation keeps this
        # a single HTML line without fragile backslash continuations.
        description = (
            "<p style='font-size: 14px; margin: 5px; font-weight: 300; text-align: center'> "
            "<a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Ameet Deshpande, </a> "
            "<a href='' style='text-decoration:none' target='_blank'>Karthik Narasimhan </a> </p>"
            "<p style='font-size: 16px; margin: 5px; font-weight: 600; text-align: center'>  "
            "<a href='https://sites.google.com/view/semsup-xc/home' target='_blank'>Project Page</a> | "
            "<a href='https://arxiv.org/abs/' target='_blank'>Paper</a> | "
            "<a href='https://github.com/princeton-nlp/semsup-xc' target='_blank'>Github Repo</a></p>"
            "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: 300;'>  "
            "Extreme classification (XC) considers the scenario of predicting over a very large number of classes "
            "(thousands to millions), with real-world applications including serving search engine results, "
            "e-commerce product tagging, and news article classification. A real-life requirement in this domain "
            "is to predict from labels unseen during training (zero-shot); however, there has been very little "
            "success in this setting. To this end, we propose SemSup-XC, a model that achieves state-of-the-art "
            "zero-shot (ZS) and few-shot (FS) performance on three extreme classification benchmarks spanning "
            "various domains. Instead of treating labels as class ids, our model learns from diverse descriptions "
            "of them, thereby attaining a better understanding of the label space, as evident from qualitative "
            "and quantitative results."
            "</p>"
        )
        gr.Markdown(description)

        gr.Markdown(
        """ 
        <br>  
        <br>  
        Our model was trained on over 1 million product descriptions from Amazon on 6500 different categories.
        SemSup-XC can generalize to both seen and unseen labels.
        You can either use already available examples or enter your own text to classify.
        You can also fetch product descriptions by simply entering the product link, and classify categories on both seen and unseen labels.
        """
        )

        with gr.Tab(label="Amazon"):
            url_textbox = gr.Textbox(
                label='URL for Amazon Product',
                lines=1,
                interactive=True,
            )
            scrape_btn = gr.Button('Fetch', variant='primary')

            text_box = gr.Textbox(
                label="Text to Classify",
                lines=4,
                interactive=True,
                value=get_random_example(),
            )

            with gr.Row():
                classify_btn = gr.Button("Classify", variant='primary')
                random_example_btn = gr.Button("Try Random")

            radio_btn = gr.Radio(choices=['Unseen Labels', 'All Labels'], value='Unseen Labels', label='Classify on', interactive=True)

            # Result panel: hidden until the first classification.
            with gr.Column(variant='panel'):
                label_html = gr.HTML('''''', visible=False)

                with gr.Row():
                    with gr.Column(scale=8):
                        gold_labels = gr.HighlightedText(
                            label="Gold Labels",
                            value=[("Label 1", "Seen"), ("Label 2", "Seen"), ("Label 3", "Unseen"), ("Label 4", "No Descriptions Available")],
                            disabled=True,
                            visible=False,
                        )
                    with gr.Column(scale=1):
                        toggle_descriptions = gr.Button(
                            "Toggle Descriptions",
                            visible=False,
                            elem_id='warning',
                        )

            # Keys must match the categories produced by format_gold_label_text.
            gold_labels.style(color_map={'Seen': 'green', 'Unseen': 'blue', 'No Descriptions Available': 'gray'})

            # Classify: render (and reveal) predictions, gold labels, and the toggle button.
            classify_btn.click(
                lambda value, is_unseen: gr.update(
                    value=format_labels_html(classify(value, is_unseen == 'Unseen Labels'), desc_is_visible=descriptions_visible),
                    visible=True,
                ),
                inputs=[text_box, radio_btn],
                outputs=label_html,
            )
            classify_btn.click(
                lambda x: gr.update(value=format_gold_label_text(x), visible=x in cache and 'label' in cache[x]),
                inputs=text_box,
                outputs=gold_labels,
            )
            classify_btn.click(lambda x: gr.update(visible=True), inputs=text_box, outputs=toggle_descriptions)

            # Try Random: load a new example and hide results of the previous text.
            random_example_btn.click(lambda value: gr.update(value=get_random_example()), inputs=random_example_btn, outputs=text_box)
            random_example_btn.click(
                lambda value: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
                inputs=random_example_btn,
                outputs=[label_html, gold_labels, toggle_descriptions],
            )

            # Switching the label set re-renders with the current description visibility.
            radio_btn.change(
                lambda value, is_unseen: gr.update(
                    value=format_labels_html(classify(value, is_unseen == 'Unseen Labels'), desc_is_visible=descriptions_visible),
                ),
                inputs=[text_box, radio_btn],
                outputs=label_html,
            )

            scrape_btn.click(scrape_click, inputs=url_textbox, outputs=text_box)

            # Toggle flips the shared visibility flag and re-renders the predictions.
            toggle_descriptions.click(
                lambda value, is_unseen: gr.update(
                    value=format_labels_html(classify(value, is_unseen == 'Unseen Labels'), desc_is_visible=toggle_descriptions_fn()),
                    visible=True,
                ),
                inputs=[text_box, radio_btn],
                outputs=label_html,
            )


if __name__ == '__main__':
    demo.launch(share=False)