__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .landing import landing_interface
2
+ from .scores_pipeline import scores_pipeline
3
+ from .compare_pipeline import compare_pipeline
app.py CHANGED
@@ -2,10 +2,10 @@ from dotenv import load_dotenv
2
  import gradio as gr
3
  load_dotenv()
4
 
5
- from interfaces import landing_interface, main_pipeline
6
 
7
- demo = gr.TabbedInterface([landing_interface, main_pipeline],
8
- ["Introduction", "Reranking"],
9
  title="GLiClass Reranker",
10
  theme=gr.themes.Base())
11
 
 
2
  import gradio as gr
3
  load_dotenv()
4
 
5
+ from interfaces import landing_interface, scores_pipeline, compare_pipeline
6
 
7
+ demo = gr.TabbedInterface([landing_interface, scores_pipeline, compare_pipeline],
8
+ ["Introduction", "Reranking", 'Compare'],
9
  title="GLiClass Reranker",
10
  theme=gr.themes.Base())
11
 
compare_pipeline.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from .initialize_models import multi_label_pipeline, st
5
+
6
+ example_1 = [
7
+ "I want to live in New York.",
8
+ 'York is a cathedral city in North Yorkshire, England, with Roman origins',
9
+ 'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
10
+ 'New York, often called New York City (NYC),[b] is the most populous city in the United States',
11
+ "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
12
+ "New York City was an American R&B vocal group.",
13
+ "New York City is an album by the Peter Malick Group featuring Norah Jones.",
14
+ "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
15
+ '"New York City" is a song by British new wave band The Armoury Show',
16
+ ]
17
+
18
+ example_2 = [
19
+ "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
20
+ "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
21
+ "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
22
+ "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
23
+ "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
24
+ "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
25
+ "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
26
+ "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
27
+ "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes."
28
+ ]
29
+
30
+ example_3 = [
31
+ "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
32
+ "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
33
+ "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
34
+ "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
35
+ "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
36
+ "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
37
+ "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
38
+ "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
39
+ "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods."
40
+ ]
41
+
42
+ example_4 = [
43
+ "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
44
+ "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
45
+ "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
46
+ "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
47
+ "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
48
+ "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
49
+ "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
50
+ "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
51
+ "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever."
52
+ ]
53
+
54
+ example_5 = [
55
+ "How can I set up a recurring payment for my monthly rent via online banking?",
56
+ "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
57
+ "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
58
+ "Wire transfers are typically one-off payments that do not recur automatically.",
59
+ "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
60
+ "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
61
+ "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
62
+ "Standing orders can be modified or canceled at any time via your online banking dashboard.",
63
+ "International transfers may incur additional fees and are not ideal for domestic rent payments."
64
+ ]
65
+
66
def compute_scores(*args):
    """Rank the candidate documents against the query with both rerankers.

    Args:
        *args: The query text followed by the candidate documents, in the
            order of the UI inputs. Falsy candidates (for example blank
            textboxes) are dropped before scoring.

    Returns:
        A pair ``(gliclass, cross_encoder)``. Each element is one flat list:
        the documents ordered best-first and padded with ``""`` up to
        ``MAX_DOCS`` entries, followed by the matching scores (rounded to two
        decimals) padded the same way.

    Raises:
        IndexError: If called without a query argument.
        TypeError: If the ``MAX_DOCS`` environment variable is not set.
        ValueError: If ``MAX_DOCS`` is not an integer.
    """
    query = args[0]
    labels = [arg for arg in args[1:] if arg]
    max_docs = int(os.getenv("MAX_DOCS"))

    # Nothing to rank: skip both models and hand back blank columns instead
    # of passing an empty candidate list to them.
    if not labels:
        blank = [""] * max_docs
        return blank + blank, blank + blank

    ranks_st = st.rank(query, labels)
    ranks_gliclass = sorted(
        multi_label_pipeline(query, labels, threshold=0.0)[0],
        key=lambda hit: hit["score"],
        reverse=True,
    )

    docs_gliclass = [hit["label"] for hit in ranks_gliclass]
    scores_gliclass = [round(hit["score"], 2) for hit in ranks_gliclass]
    # Each cross-encoder hit identifies its document by "corpus_id", which is
    # used here as the position of that document in `labels`.
    docs_st = [labels[hit["corpus_id"]] for hit in ranks_st]
    scores_st = [round(hit["score"], 2) for hit in ranks_st]

    # Pad every column to MAX_DOCS entries; a negative count adds nothing.
    for column in (docs_gliclass, scores_gliclass, docs_st, scores_st):
        column.extend([""] * (max_docs - len(column)))

    return docs_gliclass + scores_gliclass, docs_st + scores_st
98
+
99
+
100
def compute_table(*args):
    """Build the side-by-side ranking table for the Compare tab.

    Args:
        *args: The query text followed by the candidate documents; forwarded
            unchanged to ``compute_scores``.

    Returns:
        A ``pandas.DataFrame`` with ``MAX_DOCS`` rows and the columns
        ``Rank``, ``GLiClass Label`` and ``CrossEncoder Label``.
    """
    gliclass_flat, st_flat = compute_scores(*args)
    row_count = int(os.getenv("MAX_DOCS"))

    # Each flat result starts with the ranked documents; the scores that
    # follow them are not shown in the table.
    columns = {
        "Rank": list(range(1, row_count + 1)),
        "GLiClass Label": gliclass_flat[:row_count],
        "CrossEncoder Label": st_flat[:row_count],
    }
    return pd.DataFrame(columns)
112
+
113
# Give every example row one query plus MAX_DOCS document slots. The first
# entry of each example is the query, so a full row holds MAX_DOCS + 1
# entries. Rows that are already long enough are left unchanged, because
# multiplying a list by a non-positive count yields an empty list.
examples = [
    example + [""] * (int(os.getenv("MAX_DOCS")) + 1 - len(example))
    for example in [example_1, example_2, example_3, example_4, example_5]
]
117
+
118
# Compare tab: a query box, one textbox per candidate document and a table
# that lists the GLiClass and CrossEncoder rankings side by side.
with gr.Blocks(title="GLiClass-Reranker") as compare_pipeline:
    # The first example pre-fills the form: entry 0 is the query, the
    # remaining entries are the candidate documents.
    query = gr.Textbox(
        value=examples[0][0],
        label="Text query",
        placeholder="Enter your query here",
        lines=4,
    )
    labels = [
        gr.Textbox(value=label, label=f"Label {i+1}")
        for i, label in enumerate(examples[0][1:])
    ]
    submit_btn = gr.Button("Compare")
    result_table = gr.Dataframe(
        headers=["Rank", "GLiClass Label", "CrossEncoder Label"],
        label="Comparison Table",
        interactive=False,
    )

    # compute_table receives the query first, then every document box.
    inputs = [query, *labels]
    submit_btn.click(fn=compute_table, inputs=inputs, outputs=result_table)
131
+
initialize_models.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from sentence_transformers import CrossEncoder
4
+ from gliclass import GLiClassModel, ZeroShotClassificationPipeline
5
+ from transformers import AutoTokenizer
6
+
7
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The GLiClass model and its tokenizer are both loaded from the location
# given by the GLICLASS_MODEL_PATH environment variable.
model = GLiClassModel.from_pretrained(os.getenv("GLICLASS_MODEL_PATH"))
model = model.eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(os.getenv("GLICLASS_MODEL_PATH"))

# Shared multi-label zero-shot pipeline imported by the interface modules.
multi_label_pipeline = ZeroShotClassificationPipeline(
    model,
    tokenizer,
    classification_type='multi-label',
    device=device,
)

# Cross-encoder reranker that the interfaces import alongside the pipeline.
st = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
13
+
14
+
interfaces/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .landing import landing_interface
2
- from .main_pipeline import main_pipeline
3
- from .compare import compare_st
 
1
  from .landing import landing_interface
2
+ from .scores_pipeline import scores_pipeline
3
+ from .compare_pipeline import compare_pipeline
interfaces/compare_pipeline.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from .initialize_models import multi_label_pipeline, st
5
+
6
+ example_1 = [
7
+ "I want to live in New York.",
8
+ 'York is a cathedral city in North Yorkshire, England, with Roman origins',
9
+ 'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
10
+ 'New York, often called New York City (NYC),[b] is the most populous city in the United States',
11
+ "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
12
+ "New York City was an American R&B vocal group.",
13
+ "New York City is an album by the Peter Malick Group featuring Norah Jones.",
14
+ "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
15
+ '"New York City" is a song by British new wave band The Armoury Show',
16
+ ]
17
+
18
+ example_2 = [
19
+ "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
20
+ "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
21
+ "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
22
+ "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
23
+ "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
24
+ "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
25
+ "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
26
+ "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
27
+ "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes."
28
+ ]
29
+
30
+ example_3 = [
31
+ "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
32
+ "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
33
+ "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
34
+ "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
35
+ "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
36
+ "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
37
+ "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
38
+ "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
39
+ "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods."
40
+ ]
41
+
42
+ example_4 = [
43
+ "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
44
+ "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
45
+ "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
46
+ "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
47
+ "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
48
+ "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
49
+ "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
50
+ "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
51
+ "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever."
52
+ ]
53
+
54
+ example_5 = [
55
+ "How can I set up a recurring payment for my monthly rent via online banking?",
56
+ "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
57
+ "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
58
+ "Wire transfers are typically one-off payments that do not recur automatically.",
59
+ "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
60
+ "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
61
+ "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
62
+ "Standing orders can be modified or canceled at any time via your online banking dashboard.",
63
+ "International transfers may incur additional fees and are not ideal for domestic rent payments."
64
+ ]
65
+
66
def compute_scores(*args):
    """Rank the candidate documents against the query with both rerankers.

    Args:
        *args: The query text followed by the candidate documents, in the
            order of the UI inputs. Falsy candidates (for example blank
            textboxes) are dropped before scoring.

    Returns:
        A pair ``(gliclass, cross_encoder)``. Each element is one flat list:
        the documents ordered best-first and padded with ``""`` up to
        ``MAX_DOCS`` entries, followed by the matching scores (rounded to two
        decimals) padded the same way.

    Raises:
        IndexError: If called without a query argument.
        TypeError: If the ``MAX_DOCS`` environment variable is not set.
        ValueError: If ``MAX_DOCS`` is not an integer.
    """
    query = args[0]
    labels = [arg for arg in args[1:] if arg]
    max_docs = int(os.getenv("MAX_DOCS"))

    # Nothing to rank: skip both models and hand back blank columns instead
    # of passing an empty candidate list to them.
    if not labels:
        blank = [""] * max_docs
        return blank + blank, blank + blank

    ranks_st = st.rank(query, labels)
    ranks_gliclass = sorted(
        multi_label_pipeline(query, labels, threshold=0.0)[0],
        key=lambda hit: hit["score"],
        reverse=True,
    )

    docs_gliclass = [hit["label"] for hit in ranks_gliclass]
    scores_gliclass = [round(hit["score"], 2) for hit in ranks_gliclass]
    # Each cross-encoder hit identifies its document by "corpus_id", which is
    # used here as the position of that document in `labels`.
    docs_st = [labels[hit["corpus_id"]] for hit in ranks_st]
    scores_st = [round(hit["score"], 2) for hit in ranks_st]

    # Pad every column to MAX_DOCS entries; a negative count adds nothing.
    for column in (docs_gliclass, scores_gliclass, docs_st, scores_st):
        column.extend([""] * (max_docs - len(column)))

    return docs_gliclass + scores_gliclass, docs_st + scores_st
98
+
99
+
100
def compute_table(*args):
    """Build the side-by-side ranking table for the Compare tab.

    Args:
        *args: The query text followed by the candidate documents; forwarded
            unchanged to ``compute_scores``.

    Returns:
        A ``pandas.DataFrame`` with ``MAX_DOCS`` rows and the columns
        ``Rank``, ``GLiClass Label`` and ``CrossEncoder Label``.
    """
    gliclass_flat, st_flat = compute_scores(*args)
    row_count = int(os.getenv("MAX_DOCS"))

    # Each flat result starts with the ranked documents; the scores that
    # follow them are not shown in the table.
    columns = {
        "Rank": list(range(1, row_count + 1)),
        "GLiClass Label": gliclass_flat[:row_count],
        "CrossEncoder Label": st_flat[:row_count],
    }
    return pd.DataFrame(columns)
112
+
113
# Give every example row one query plus MAX_DOCS document slots. The first
# entry of each example is the query, so a full row holds MAX_DOCS + 1
# entries. Rows that are already long enough are left unchanged, because
# multiplying a list by a non-positive count yields an empty list.
examples = [
    example + [""] * (int(os.getenv("MAX_DOCS")) + 1 - len(example))
    for example in [example_1, example_2, example_3, example_4, example_5]
]
117
+
118
# Compare tab: a query box, one textbox per candidate document and a table
# that lists the GLiClass and CrossEncoder rankings side by side.
with gr.Blocks(title="GLiClass-Reranker") as compare_pipeline:
    # The first example pre-fills the form: entry 0 is the query, the
    # remaining entries are the candidate documents.
    query = gr.Textbox(
        value=examples[0][0],
        label="Text query",
        placeholder="Enter your query here",
        lines=4,
    )
    labels = [
        gr.Textbox(value=label, label=f"Label {i+1}")
        for i, label in enumerate(examples[0][1:])
    ]
    submit_btn = gr.Button("Compare")
    result_table = gr.Dataframe(
        headers=["Rank", "GLiClass Label", "CrossEncoder Label"],
        label="Comparison Table",
        interactive=False,
    )

    # compute_table receives the query first, then every document box.
    inputs = [query, *labels]
    submit_btn.click(fn=compute_table, inputs=inputs, outputs=result_table)
131
+
interfaces/initialize_models.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from sentence_transformers import CrossEncoder
4
+ from gliclass import GLiClassModel, ZeroShotClassificationPipeline
5
+ from transformers import AutoTokenizer
6
+
7
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The GLiClass model and its tokenizer are both loaded from the location
# given by the GLICLASS_MODEL_PATH environment variable.
model = GLiClassModel.from_pretrained(os.getenv("GLICLASS_MODEL_PATH"))
model = model.eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(os.getenv("GLICLASS_MODEL_PATH"))

# Shared multi-label zero-shot pipeline imported by the interface modules.
multi_label_pipeline = ZeroShotClassificationPipeline(
    model,
    tokenizer,
    classification_type='multi-label',
    device=device,
)

# Cross-encoder reranker that the interfaces import alongside the pipeline.
st = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
13
+
14
+
interfaces/scores_pipeline.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from typing import List
4
+ from .initialize_models import multi_label_pipeline, st
5
+
6
+
7
+ example_1 =[
8
+ "I want to live in New York.",
9
+ 'York is a cathedral city in North Yorkshire, England, with Roman origins',
10
+ 'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
11
+ 'New York, often called New York City (NYC),[b] is the most populous city in the United States',
12
+ "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
13
+ "New York City was an American R&B vocal group.",
14
+ "New York City is an album by the Peter Malick Group featuring Norah Jones.",
15
+ "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
16
+ '"New York City" is a song by British new wave band The Armoury Show',
17
+ ]
18
+
19
+ example_2 = [
20
+ "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
21
+ "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
22
+ "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
23
+ "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
24
+ "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
25
+ "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
26
+ "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
27
+ "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
28
+ "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes."
29
+ ]
30
+
31
+ example_3 = [
32
+ "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
33
+ "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
34
+ "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
35
+ "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
36
+ "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
37
+ "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
38
+ "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
39
+ "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
40
+ "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods."
41
+ ]
42
+
43
+ example_4 = [
44
+ "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
45
+ "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
46
+ "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
47
+ "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
48
+ "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
49
+ "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
50
+ "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
51
+ "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
52
+ "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever."
53
+ ]
54
+
55
+ example_5 = [
56
+ "How can I set up a recurring payment for my monthly rent via online banking?",
57
+ "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
58
+ "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
59
+ "Wire transfers are typically one-off payments that do not recur automatically.",
60
+ "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
61
+ "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
62
+ "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
63
+ "Standing orders can be modified or canceled at any time via your online banking dashboard.",
64
+ "International transfers may incur additional fees and are not ideal for domestic rent payments."
65
+ ]
66
+
67
# Give every example row one query plus MAX_DOCS document slots. The UI below
# reads examples[0][1 + i] for i in range(MAX_DOCS), so each row must hold
# MAX_DOCS + 1 entries; padding to fewer raises IndexError once MAX_DOCS
# exceeds the number of documents in the example. Rows that are already long
# enough are left unchanged (a non-positive repeat count adds nothing).
examples = [
    example + [""] * (int(os.getenv("MAX_DOCS")) + 1 - len(example))
    for example in [example_1, example_2, example_3, example_4, example_5]
]
70
+
71
+
72
def classification(*args) -> list:
    """Rerank the candidate documents against the query with GLiClass.

    Args:
        *args: The query text followed by the candidate documents, in the
            order of the UI inputs. Falsy candidates (for example blank
            textboxes) are dropped before scoring.

    Returns:
        One flat list: the documents ordered by descending score and padded
        with ``""`` up to ``MAX_DOCS`` entries, followed by the matching
        scores (rounded to two decimals) padded the same way.

    Raises:
        IndexError: If called without a query argument.
        TypeError: If the ``MAX_DOCS`` environment variable is not set.
        ValueError: If ``MAX_DOCS`` is not an integer.
    """
    query = args[0]
    labels = [arg for arg in args[1:] if arg]
    max_docs = int(os.getenv("MAX_DOCS"))

    # Nothing to rank: skip the model and blank out every output box instead
    # of passing an empty candidate list to the pipeline.
    if not labels:
        return [""] * (2 * max_docs)

    results = sorted(
        multi_label_pipeline(query, labels, threshold=0.0)[0],
        key=lambda hit: hit["score"],
        reverse=True,
    )
    docs = [hit["label"] for hit in results]
    scores = [round(hit["score"], 2) for hit in results]

    # Pad both columns to MAX_DOCS entries; a negative count adds nothing.
    for column in (docs, scores):
        column.extend([""] * (max_docs - len(column)))

    return docs + scores
87
+
88
# Reranking tab: a query box followed by MAX_DOCS document/score pairs.
with gr.Blocks(title="GLiClass-Reranker") as scores_pipeline:
    query = gr.Textbox(
        value=examples[0][0],
        label="Text query",
        placeholder="Enter your query here",
        lines=10,
    )
    submit_btn = gr.Button("Rerank")

    # `inputs` is the query followed by every document box; `outputs`
    # collects the score boxes as they are created.
    inputs = [query]
    outputs = []
    for i in range(int(os.getenv("MAX_DOCS"))):
        with gr.Group():
            # Pre-filled from the first example (entry 0 is the query).
            doc_input = gr.Textbox(
                value=examples[0][1+i],
                label=f"Document {i}",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            score_output = gr.Textbox(
                label=f"Score {i}",
                placeholder="Score will appear here",
                scale=2,
            )
        inputs.append(doc_input)
        outputs.append(score_output)

    # classification returns the reordered documents first and the scores
    # second, so the document boxes are outputs as well as inputs.
    outputs = inputs[1:] + outputs

    # NOTE: this rebinds the module-level `examples` list to the Examples
    # helper once the list has been handed over.
    examples = gr.Examples(
        examples=examples,
        fn=classification,
        inputs=inputs,
        outputs=outputs,
        cache_examples=True,
    )

    submit_btn.click(fn=classification, inputs=inputs, outputs=outputs)
landing.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
# Introduction HTML shown at the top of the landing tab. The path is relative,
# so it is resolved against the working directory of the running process.
with open('materials/introduction.html', 'r', encoding='utf-8') as file:
    html_description = file.read()

with gr.Blocks() as landing_interface:
    gr.HTML(html_description)

    # Collapsed by default: installation notes plus a minimal usage sample.
    with gr.Accordion("How to run this model locally", open=False):
        gr.Markdown(
            """
## Installation
To use this model, you must install the GLiClass Python library:
```
!pip install gliclass
```

## Usage
Once you've downloaded the GLiClass library, you can import the GLiClassModel and ZeroShotClassificationPipeline classes.
"""
        )
        # The sample is kept flush-left inside the literal so it is displayed
        # without extra leading whitespace; the loop body is indented so the
        # snippet is valid Python when copied.
        gr.Code(
            '''
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1")

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0] #because we have one text

for result in results:
    print(result["label"], "=>", result["score"])
''',
            language="python",
        )
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gliclass
2
  gradio
3
- dotenv
 
 
1
  gliclass
2
  gradio
3
+ dotenv
4
+ sentence-transformers
scores_pipeline.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from typing import List
4
+ from .initialize_models import multi_label_pipeline, st
5
+
6
+
7
+ example_1 =[
8
+ "I want to live in New York.",
9
+ 'York is a cathedral city in North Yorkshire, England, with Roman origins',
10
+ 'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
11
+ 'New York, often called New York City (NYC),[b] is the most populous city in the United States',
12
+ "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
13
+ "New York City was an American R&B vocal group.",
14
+ "New York City is an album by the Peter Malick Group featuring Norah Jones.",
15
+ "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
16
+ '"New York City" is a song by British new wave band The Armoury Show',
17
+ ]
18
+
19
+ example_2 = [
20
+ "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
21
+ "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
22
+ "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
23
+ "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
24
+ "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
25
+ "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
26
+ "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
27
+ "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
28
+ "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes."
29
+ ]
30
+
31
+ example_3 = [
32
+ "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
33
+ "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
34
+ "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
35
+ "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
36
+ "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
37
+ "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
38
+ "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
39
+ "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
40
+ "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods."
41
+ ]
42
+
43
+ example_4 = [
44
+ "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
45
+ "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
46
+ "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
47
+ "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
48
+ "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
49
+ "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
50
+ "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
51
+ "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
52
+ "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever."
53
+ ]
54
+
55
+ example_5 = [
56
+ "How can I set up a recurring payment for my monthly rent via online banking?",
57
+ "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
58
+ "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
59
+ "Wire transfers are typically one-off payments that do not recur automatically.",
60
+ "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
61
+ "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
62
+ "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
63
+ "Standing orders can be modified or canceled at any time via your online banking dashboard.",
64
+ "International transfers may incur additional fees and are not ideal for domestic rent payments."
65
+ ]
66
+
67
# Give every example row one query plus MAX_DOCS document slots. The UI below
# reads examples[0][1 + i] for i in range(MAX_DOCS), so each row must hold
# MAX_DOCS + 1 entries; padding to fewer raises IndexError once MAX_DOCS
# exceeds the number of documents in the example. Rows that are already long
# enough are left unchanged (a non-positive repeat count adds nothing).
examples = [
    example + [""] * (int(os.getenv("MAX_DOCS")) + 1 - len(example))
    for example in [example_1, example_2, example_3, example_4, example_5]
]
70
+
71
+
72
def classification(*args) -> list:
    """Rerank the candidate documents against the query with GLiClass.

    Args:
        *args: The query text followed by the candidate documents, in the
            order of the UI inputs. Falsy candidates (for example blank
            textboxes) are dropped before scoring.

    Returns:
        One flat list: the documents ordered by descending score and padded
        with ``""`` up to ``MAX_DOCS`` entries, followed by the matching
        scores (rounded to two decimals) padded the same way.

    Raises:
        IndexError: If called without a query argument.
        TypeError: If the ``MAX_DOCS`` environment variable is not set.
        ValueError: If ``MAX_DOCS`` is not an integer.
    """
    query = args[0]
    labels = [arg for arg in args[1:] if arg]
    max_docs = int(os.getenv("MAX_DOCS"))

    # Nothing to rank: skip the model and blank out every output box instead
    # of passing an empty candidate list to the pipeline.
    if not labels:
        return [""] * (2 * max_docs)

    results = sorted(
        multi_label_pipeline(query, labels, threshold=0.0)[0],
        key=lambda hit: hit["score"],
        reverse=True,
    )
    docs = [hit["label"] for hit in results]
    scores = [round(hit["score"], 2) for hit in results]

    # Pad both columns to MAX_DOCS entries; a negative count adds nothing.
    for column in (docs, scores):
        column.extend([""] * (max_docs - len(column)))

    return docs + scores
87
+
88
# Reranking tab: a query box followed by MAX_DOCS document/score pairs.
with gr.Blocks(title="GLiClass-Reranker") as scores_pipeline:
    query = gr.Textbox(
        value=examples[0][0],
        label="Text query",
        placeholder="Enter your query here",
        lines=10,
    )
    submit_btn = gr.Button("Rerank")

    # `inputs` is the query followed by every document box; `outputs`
    # collects the score boxes as they are created.
    inputs = [query]
    outputs = []
    for i in range(int(os.getenv("MAX_DOCS"))):
        with gr.Group():
            # Pre-filled from the first example (entry 0 is the query).
            doc_input = gr.Textbox(
                value=examples[0][1+i],
                label=f"Document {i}",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            score_output = gr.Textbox(
                label=f"Score {i}",
                placeholder="Score will appear here",
                scale=2,
            )
        inputs.append(doc_input)
        outputs.append(score_output)

    # classification returns the reordered documents first and the scores
    # second, so the document boxes are outputs as well as inputs.
    outputs = inputs[1:] + outputs

    # NOTE: this rebinds the module-level `examples` list to the Examples
    # helper once the list has been handed over.
    examples = gr.Examples(
        examples=examples,
        fn=classification,
        inputs=inputs,
        outputs=outputs,
        cache_examples=True,
    )

    submit_btn.click(fn=classification, inputs=inputs, outputs=outputs)