# GLiClassReranker / scores_pipeline.py
# Commit 9127367 by BioMike: "fixed (#4)" (verified)
import os
from typing import List, Union

import gradio as gr

from .initialize_models import multi_label_pipeline, st
# Bundled demo examples. In every list the first entry is the query and the
# remaining entries are the candidate documents to rerank against it.
example_1 = [
    "I want to live in New York.",
    'York is a cathedral city in North Yorkshire, England, with Roman origins',
    'San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.',
    'New York, often called New York City (NYC),[b] is the most populous city in the United States',
    "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
    "New York City was an American R&B vocal group.",
    "New York City is an album by the Peter Malick Group featuring Norah Jones.",
    "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
    '"New York City" is a song by British new wave band The Armoury Show',
]
example_2 = [
    "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
    "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
    "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
    "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
    "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
    "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
    "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
    "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
    "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes.",
]
example_3 = [
    "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
    "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
    "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
    "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
    "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
    "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
    "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
    "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
    "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods.",
]
example_4 = [
    "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
    "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
    "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
    "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
    "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
    "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
    "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
    "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
    "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever.",
]
example_5 = [
    "How can I set up a recurring payment for my monthly rent via online banking?",
    "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
    "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
    "Wire transfers are typically one-off payments that do not recur automatically.",
    "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
    "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
    "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
    "Standing orders can be modified or canceled at any time via your online banking dashboard.",
    "International transfers may incur additional fees and are not ideal for domestic rent payments.",
]

# Number of document slots in the UI. Falls back to 8, the number of candidate
# documents in each bundled example, when MAX_DOCS is not set.
_max_docs = int(os.getenv("MAX_DOCS", "8"))

# Each example row must supply one value per UI input: the query plus
# _max_docs document boxes, i.e. _max_docs + 1 values. Shorter rows are padded
# with empty documents. (Padding to _max_docs - 1, as before, left rows too
# short and made pre-filling the form index past the end of the row.)
examples = [
    example + [""] * (_max_docs + 1 - len(example))
    for example in [example_1, example_2, example_3, example_4, example_5]
]
def classification(*args) -> List[Union[str, float]]:
    """Rerank candidate documents against a query with the GLiClass pipeline.

    Args:
        *args: Values as wired from the UI. ``args[0]`` is the query text and
            ``args[1:]`` are the candidate documents. Empty/falsy documents are
            ignored; the rest are passed to ``multi_label_pipeline`` as labels.

    Returns:
        A flat list of ``2 * MAX_DOCS`` values. The first half is the
        documents ordered by descending score; the second half is their scores
        rounded to 2 decimals. Each half is padded with ``""`` up to
        ``MAX_DOCS`` entries. MAX_DOCS is read from the environment and falls
        back to 8 when unset.
    """
    max_docs = int(os.getenv("MAX_DOCS", "8"))
    query, *candidates = args
    labels = [doc for doc in candidates if doc]

    docs: List[str] = []
    scores: List[Union[str, float]] = []
    # Skip the model entirely when there is nothing to rank.
    if labels:
        # threshold=0.0 asks for a score for every label, not only the
        # "positive" ones, so that every document gets ranked.
        predictions = multi_label_pipeline(query, labels, threshold=0.0)[0]
        for predict in sorted(predictions, key=lambda p: p["score"], reverse=True):
            docs.append(predict["label"])
            scores.append(round(predict["score"], 2))

    # Blank out the unused UI slots so stale values from a previous run are cleared.
    padding = [""] * (max_docs - len(docs))
    return docs + padding + scores + padding
with gr.Blocks(title="GLiClass-Reranker") as scores_pipeline:
    # Number of document/score slot pairs to render. `classification` pads its
    # output to this many documents and this many scores. Falls back to 8 when
    # MAX_DOCS is unset.
    _max_docs = int(os.getenv("MAX_DOCS", "8"))
    # The first bundled example pre-fills the form. Slots beyond the end of
    # that row start empty instead of raising IndexError.
    first_example = examples[0]

    inputs = []
    outputs = []
    query = gr.Textbox(
        value=first_example[0], label="Text query", placeholder="Enter your query here", lines=10
    )
    submit_btn = gr.Button("Rerank")
    inputs.append(query)
    for i in range(_max_docs):
        with gr.Group():
            doc_input = gr.Textbox(
                value=first_example[1 + i] if 1 + i < len(first_example) else "",
                label=f"Document {i}",
                placeholder="Enter your document here",
                scale=2,
            )
            score_output = gr.Textbox(
                label=f"Score {i}",
                placeholder="Score will appear here",
                scale=2,
            )
        inputs.append(doc_input)
        outputs.append(score_output)

    # `classification` returns the reordered documents followed by their
    # scores, so the document boxes are also outputs and come first.
    outputs = inputs[1:] + outputs

    # NOTE: this rebinds the module-level name `examples` from the list of
    # example rows to the gr.Examples component.
    examples = gr.Examples(
        examples=examples,
        fn=classification,
        inputs=inputs,
        outputs=outputs,
        cache_examples=True,
    )
    submit_btn.click(fn=classification, inputs=inputs, outputs=outputs)