app.py CHANGED
@@ -2,10 +2,10 @@ from dotenv import load_dotenv
2
  import gradio as gr
3
  load_dotenv()
4
 
5
- from interfaces import landing_interface, main_pipeline
6
 
7
- demo = gr.TabbedInterface([landing_interface, main_pipeline],
8
- ["Introduction", "Reranking"],
9
  title="GLiClass Reranker",
10
  theme=gr.themes.Base())
11
 
 
2
  import gradio as gr
3
  load_dotenv()
4
 
5
+ from interfaces import landing_interface, main_pipeline, compare_st
6
 
7
+ demo = gr.TabbedInterface([landing_interface, main_pipeline, compare_st],
8
+ ["Introduction", "Reranking", 'Compare'],
9
  title="GLiClass Reranker",
10
  theme=gr.themes.Base())
11
 
interfaces/.DS_Store ADDED
Binary file (6.15 kB). View file
 
interfaces/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from .landing import landing_interface
2
- from .main_pipeline import main_pipeline
 
 
1
  from .landing import landing_interface
2
+ from .main_pipeline import main_pipeline
3
+ from .compare import compare_st
interfaces/compare.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from typing import List
5
+ import pandas as pd
6
+ from transformers import AutoTokenizer
7
+ from gliclass import GLiClassModel, ZeroShotClassificationPipeline
8
+ from sentence_transformers import CrossEncoder
9
+
10
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
+
12
+ model = GLiClassModel.from_pretrained(os.getenv("GLICLASS_MODEL_PATH")).eval().to(device)
13
+ tokenizer = AutoTokenizer.from_pretrained(os.getenv("GLICLASS_MODEL_PATH"))
14
+ multi_label_pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label',
15
+ device=device)
16
+ st = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
17
+ example_1 = [
18
+ "I want to live in New York.",
19
+ 'York is a cathedral city in North Yorkshire, England, with Roman origins',
20
+ 'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
21
+ 'New York, often called New York City (NYC),[b] is the most populous city in the United States',
22
+ "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
23
+ "New York City was an American R&B vocal group.",
24
+ "New York City is an album by the Peter Malick Group featuring Norah Jones.",
25
+ "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
26
+ '"New York City" is a song by British new wave band The Armoury Show',
27
+ ]
28
+
29
+ example_2 = [
30
+ "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
31
+ "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
32
+ "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
33
+ "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
34
+ "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
35
+ "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
36
+ "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
37
+ "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
38
+ "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes."
39
+ ]
40
+
41
+ example_3 = [
42
+ "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
43
+ "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
44
+ "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
45
+ "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
46
+ "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
47
+ "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
48
+ "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
49
+ "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
50
+ "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods."
51
+ ]
52
+
53
+ example_4 = [
54
+ "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
55
+ "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
56
+ "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
57
+ "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
58
+ "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
59
+ "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
60
+ "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
61
+ "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
62
+ "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever."
63
+ ]
64
+
65
+ example_5 = [
66
+ "How can I set up a recurring payment for my monthly rent via online banking?",
67
+ "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
68
+ "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
69
+ "Wire transfers are typically one-off payments that do not recur automatically.",
70
+ "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
71
+ "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
72
+ "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
73
+ "Standing orders can be modified or canceled at any time via your online banking dashboard.",
74
+ "International transfers may incur additional fees and are not ideal for domestic rent payments."
75
+ ]
76
+
77
+
78
+ def compute_scores(*args):
79
+ labels = [arg for arg in args[1:]]
80
+ labels = list(filter(None, labels))
81
+ query = args[0]
82
+
83
+ ranks_st = st.rank(query, labels)
84
+ ranks_gliclass = sorted(multi_label_pipeline(query, labels, threshold=0.0)[0], key=lambda x: x["score"], reverse=True)
85
+
86
+ docs_gliclass = []
87
+ scores_gliclass = []
88
+ docs_st = []
89
+ scores_st = []
90
+
91
+ label_to_text = {str(i): label for i, label in enumerate(labels)}
92
+
93
+
94
+ for predict in ranks_gliclass:
95
+ docs_gliclass.append(predict["label"])
96
+ scores_gliclass.append(round(predict["score"], 2))
97
+
98
+ for predict in ranks_st:
99
+ doc_id = predict["corpus_id"]
100
+ docs_st.append(label_to_text.get(str(doc_id), ""))
101
+ scores_st.append(round(predict["score"], 2))
102
+ for _ in range(int(os.getenv("MAX_DOCS")) - len(docs_st)):
103
+ docs_st.append("")
104
+ scores_st.append("")
105
+ for _ in range(int(os.getenv("MAX_DOCS")) - len(docs_gliclass)):
106
+ docs_gliclass.append("")
107
+ scores_gliclass.append("")
108
+
109
+ return docs_gliclass + scores_gliclass, docs_st + scores_st
110
+
111
+
112
+ def compute_table(*args):
113
+ gliclass_results, st_results = compute_scores(*args)
114
+ max_docs = int(os.getenv("MAX_DOCS"))
115
+ gliclass_labels = gliclass_results[:max_docs]
116
+ st_labels = st_results[:max_docs]
117
+ df = pd.DataFrame({
118
+ "Rank": list(range(1, max_docs + 1)),
119
+ "GLiClass Label": gliclass_labels,
120
+ "CrossEncoder Label": st_labels,
121
+ })
122
+
123
+ return df
124
+
125
+
126
+ examples = [
127
+ example + [""] * (int(os.getenv("MAX_DOCS")) - len(example) - 1) for example in
128
+ [example_1, example_2, example_3, example_4, example_5]
129
+ ]
130
+
131
+ with gr.Blocks(title="GLiClass-Reranker") as compare_st:
132
+ inputs = []
133
+ query = gr.Textbox(
134
+ value=examples[0][0], label="Text query", placeholder="Enter your query here", lines=4
135
+ )
136
+ labels = [gr.Textbox(value=label, label=f"Label {i+1}") for i, label in enumerate(examples[0][1:])]
137
+ submit_btn = gr.Button("Compare")
138
+ result_table = gr.Dataframe(headers=["Rank", "GLiClass Label", "CrossEncoder Label"],
139
+ label="Comparison Table",
140
+ interactive=False)
141
+
142
+ inputs = [query] + labels
143
+ submit_btn.click(fn=compute_table, inputs=inputs, outputs=result_table)
144
+