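# Spaces: Running
# (The line above appears to be status text captured from the hosting page
# together with the source; it is not part of the program.)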
import os

import gradio as gr
import pandas as pd

from .initialize_models import multi_label_pipeline, st
# Sample input about the ambiguous name "New York". Element 0 is the query;
# the remaining elements are the candidate documents. This list is the one
# whose values pre-fill the query and document text boxes in the UI.
example_1 = [
    "I want to live in New York.",
    'York is a cathedral city in North Yorkshire, England, with Roman origins',
    'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
    'New York, often called New York City (NYC),[b] is the most populous city in the United States',
    "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
    "New York City was an American R&B vocal group.",
    "New York City is an album by the Peter Malick Group featuring Norah Jones.",
    "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
    '"New York City" is a song by British new wave band The Armoury Show',
]
# Sample input for a product search: a query string followed by candidate
# product descriptions.
example_2 = [
    "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
    "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
    "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
    "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
    "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
    "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
    "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
    "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
    "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes.",
]
# Sample input for a troubleshooting lookup: a query string followed by
# candidate explanations of HTTP error codes.
example_3 = [
    "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
    "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
    "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
    "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
    "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
    "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
    "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
    "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
    "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods.",
]
# Sample input for a clinical lookup: a patient description as the query,
# followed by candidate condition descriptions.
example_4 = [
    "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
    "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
    "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
    "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
    "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
    "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
    "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
    "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
    "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever.",
]
# Sample input for a banking help lookup: a customer question as the query,
# followed by candidate help-article snippets.
example_5 = [
    "How can I set up a recurring payment for my monthly rent via online banking?",
    "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
    "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
    "Wire transfers are typically one-off payments that do not recur automatically.",
    "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
    "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
    "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
    "Standing orders can be modified or canceled at any time via your online banking dashboard.",
    "International transfers may incur additional fees and are not ideal for domestic rent payments.",
]
def compute_scores(*args):
    """Rank the documents against the query with both models.

    Args:
        *args: The query string followed by the document strings. Falsy
            documents (empty strings, ``None``) are dropped before ranking.

    Returns:
        A 2-tuple ``(gliclass, st)``. Each element is one flat list: the
        ranked documents first, then their scores rounded to two decimals in
        the same order. The GLiClass half is sorted by descending score; the
        ``st`` half keeps the order returned by ``st.rank``. Both the
        document part and the score part are padded with ``""`` up to
        ``MAX_DOCS`` entries (no padding when there are already that many).

    Raises:
        IndexError: If called without a query.
        TypeError: If the ``MAX_DOCS`` environment variable is unset.
        ValueError: If ``MAX_DOCS`` is not an integer.
    """
    query = args[0]
    labels = [doc for doc in args[1:] if doc]
    # Read the limit once; it drives the padding of both result lists.
    max_docs = int(os.getenv("MAX_DOCS"))

    def _padded(docs, scores):
        # A negative count multiplies out to an empty list, i.e. no padding.
        padding = [""] * (max_docs - len(docs))
        return docs + padding + scores + padding

    if not labels:
        # Nothing to rank: skip both model calls and return pure padding
        # instead of handing the models an empty candidate list.
        return _padded([], []), _padded([], [])

    ranks_st = st.rank(query, labels)
    ranks_gliclass = sorted(
        multi_label_pipeline(query, labels, threshold=0.0)[0],
        key=lambda x: x["score"],
        reverse=True,
    )

    docs_gliclass = [predict["label"] for predict in ranks_gliclass]
    scores_gliclass = [round(predict["score"], 2) for predict in ranks_gliclass]

    # st.rank identifies documents by "corpus_id"; map that id back to the
    # text, falling back to "" when the id does not match any position.
    label_to_text = {str(i): label for i, label in enumerate(labels)}
    docs_st = [
        label_to_text.get(str(predict["corpus_id"]), "") for predict in ranks_st
    ]
    scores_st = [round(predict["score"], 2) for predict in ranks_st]

    return (
        _padded(docs_gliclass, scores_gliclass),
        _padded(docs_st, scores_st),
    )
def compute_table(*args) -> pd.DataFrame:
    """Build the side-by-side rank table for the given query and documents.

    Args:
        *args: The query string followed by the document strings, exactly as
            passed on to ``compute_scores``.

    Returns:
        A DataFrame with one row per document argument (blank ones included)
        and the columns "Document", "GLiClass Rank" and "Cross-Encoder Rank".
        Ranks are 1-based positions in the respective ranked list; a document
        that does not appear in a ranking gets ``""``.
    """
    gliclass_results, st_results = compute_scores(*args)
    documents = list(args[1:])

    def _rank_by_document(results):
        # compute_scores returns documents followed by an equally long list
        # of scores, so the documents are exactly the first half. Splitting
        # at the midpoint (rather than at MAX_DOCS) stays correct even when
        # more than MAX_DOCS documents were ranked.
        ranked_docs = results[: len(results) // 2]
        # Padding entries are empty strings and carry no rank.
        return {doc: pos for pos, doc in enumerate(ranked_docs, start=1) if doc}

    label_rank_gliclass = _rank_by_document(gliclass_results)
    label_rank_st = _rank_by_document(st_results)

    return pd.DataFrame({
        "Document": documents,
        "GLiClass Rank": [label_rank_gliclass.get(doc, "") for doc in documents],
        "Cross-Encoder Rank": [label_rank_st.get(doc, "") for doc in documents],
    })
# Right-pad every example with empty strings up to MAX_DOCS - 1 entries in
# total (query included). Examples that are already that long or longer are
# left as they are, because a non-positive repeat count yields no padding.
# NOTE(review): this leaves MAX_DOCS - 2 document slots per example. If
# MAX_DOCS is meant to be the number of document boxes, the padding is two
# short — confirm the intended count.
examples = [
    [*sample, *([""] * (int(os.getenv("MAX_DOCS")) - len(sample) - 1))]
    for sample in (example_1, example_2, example_3, example_4, example_5)
]
# UI definition: one query box, one text box per document slot of the first
# example, a button, and a read-only table that shows both rankings.
# Everything, including the click wiring, is created inside the Blocks
# context.
with gr.Blocks(title="GLiClass-Reranker") as compare_pipeline:
    # Holds the padded examples; not read by any handler in this section.
    example_state = gr.State(value=examples)

    # The first example pre-fills the form: element 0 is the query and the
    # remaining elements are the documents.
    query = gr.Textbox(
        value=examples[0][0],
        label="Text query",
        placeholder="Enter your query here",
        lines=4,
    )
    labels = [
        gr.Textbox(value=label, label=f"Document {i+1}")
        for i, label in enumerate(examples[0][1:])
    ]
    submit_btn = gr.Button("Compare")
    result_table = gr.Dataframe(
        headers=["Document", "GLiClass Rank", "Cross-Encoder Rank"],
        label="Comparison Table",
        interactive=False,
    )

    # compute_table receives the query first, then every document box.
    inputs = [query] + labels
    submit_btn.click(fn=compute_table, inputs=inputs, outputs=result_table)