Spaces:
Running
Running
import os | |
import torch | |
import gradio as gr | |
from typing import List | |
from transformers import AutoTokenizer | |
from gliclass import GLiClassModel, ZeroShotClassificationPipeline | |
# Run inference on the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model and its tokenizer are both loaded from the location named by the
# GLICLASS_MODEL_PATH environment variable.
_model_path = os.getenv("GLICLASS_MODEL_PATH")
model = (
    GLiClassModel.from_pretrained(_model_path)
    .eval()
    .to(device)
)
tokenizer = AutoTokenizer.from_pretrained(_model_path)

# Multi-label mode: each candidate label is scored against the input text.
multi_label_pipeline = ZeroShotClassificationPipeline(
    model,
    tokenizer,
    classification_type='multi-label',
    device=device,
)
# Demo inputs for the reranker UI.
# Every example is a flat list: the first entry is the query, the remaining
# entries are the candidate documents to be reranked against it.

# Example 1: disambiguating "New York" (city vs. albums, bands, songs).
example_1 = [
    # query
    "I want to live in New York.",
    # candidates
    "York is a cathedral city in North Yorkshire, England, with Roman origins",
    "San Francisco,[23] officially the City and County of San Francisco, is a commercial, financial, and cultural center within Northern California, United States.",
    "New York, often called New York City (NYC),[b] is the most populous city in the United States",
    "New York City is the third album by electronica group Brazilian Girls, released in 2008.",
    "New York City was an American R&B vocal group.",
    "New York City is an album by the Peter Malick Group featuring Norah Jones.",
    "New York City: The Album is the debut studio album by American rapper Troy Ave. ",
    '"New York City" is a song by British new wave band The Armoury Show',
]

# Example 2: product search for cold-weather hiking boots.
example_2 = [
    # query
    "Looking for waterproof hiking boots that can handle freezing temperatures and rugged terrain.",
    # candidates
    "TrailMaster X200 – waterproof boots with Vibram Arctic Grip soles, rated for -20°C and rocky paths.",
    "UrbanStep Sneakers – stylish and breathable, not designed for rugged use or cold weather.",
    "AlpineShield GTX – Gore-Tex lining, insulated to -15°C, ideal for mountain hiking.",
    "Desert Trek Sandals – open-toe design, breathable and lightweight, not waterproof.",
    "SummitPro Winter Boots – fleece-lined, waterproof up to ankle depth, tested to -5°C.",
    "Marathon Lite – road-running shoes with shock-absorbing soles, non-waterproof.",
    "TrailMaster X100 – waterproof boots with basic insulation, effective down to 0°C.",
    "Climber Pro GTX – reinforced toe cap, Gore-Tex membrane, insulated to -20°C, certified for alpine routes.",
]

# Example 3: troubleshooting HTTP 504 errors.
example_3 = [
    # query
    "Our users are reporting 504 Gateway Timeout errors when accessing the app during peak hours.",
    # candidates
    "A 504 Gateway Timeout indicates that a server did not receive a timely response from another server upstream.",
    "A 502 Bad Gateway occurs when the server, acting as a gateway, receives an invalid response from the upstream server.",
    "Common causes of 504 errors include high server load, network congestion, or misconfigured backend timeouts.",
    "A 403 Forbidden error suggests that the server is refusing to authorize the request, often due to permissions.",
    "To resolve 504 errors, check server logs, backend service availability, and increase timeout settings if necessary.",
    "A 408 Request Timeout is returned when the client fails to send a complete request in time.",
    "A 500 Internal Server Error is a generic error indicating that the server encountered an unexpected condition.",
    "Network latency monitoring tools can help identify bottlenecks that may cause 504 errors during high traffic periods.",
]

# Example 4: differential diagnosis for a chronic cough presentation.
example_4 = [
    # query
    "A 45-year-old male presents with persistent cough, night sweats, low-grade fever, and weight loss over 3 months.",
    # candidates
    "Lung cancer can cause cough and weight loss; however, it often includes hemoptysis and may show a solitary mass on imaging.",
    "Bronchiectasis is characterized by chronic productive cough and recurrent infections but usually lacks significant weight loss.",
    "Pneumonia presents acutely with high fever, productive cough, and may show lobar consolidation on imaging.",
    "Sarcoidosis may cause cough and weight loss, with bilateral hilar lymphadenopathy seen on chest X-ray.",
    "Tuberculosis typically presents with chronic cough, night sweats, weight loss, and may show upper lobe infiltrates on chest X-ray.",
    "Chronic obstructive pulmonary disease (COPD) often involves chronic cough and dyspnea but is less associated with night sweats.",
    "Fungal lung infections like histoplasmosis can mimic TB symptoms but are more common in specific endemic regions.",
    "Gastroesophageal reflux disease (GERD) can cause chronic cough, but without systemic symptoms like weight loss or fever.",
]

# Example 5: banking FAQ about recurring rent payments.
example_5 = [
    # query
    "How can I set up a recurring payment for my monthly rent via online banking?",
    # candidates
    "A standing order allows you to set up automatic fixed-amount payments on a regular schedule (e.g., monthly rent) through your bank.",
    "A direct debit authorizes a third party to withdraw variable amounts from your account, typically used for utility bills.",
    "Wire transfers are typically one-off payments that do not recur automatically.",
    "You can schedule a one-time payment for a future date using the online banking portal, but it won’t repeat monthly.",
    "Bank-issued cashier’s checks are used for large payments but require manual setup each time.",
    "To set up recurring credit card payments, navigate to your card provider’s auto-pay settings (note: for card bills only).",
    "Standing orders can be modified or canceled at any time via your online banking dashboard.",
    "International transfers may incur additional fees and are not ideal for domestic rent payments.",
]
def _fit_example(example: List[str], width: int) -> List[str]:
    """Return *example* trimmed or right-padded with "" to exactly *width* items."""
    row = example[:width]
    return row + [""] * (width - len(row))


# Each example row feeds the query textbox plus MAX_DOCS document textboxes,
# so a row must hold exactly MAX_DOCS + 1 values. The previous padding
# (MAX_DOCS - len(example) - 1) produced rows two entries short whenever
# padding was needed, and never trimmed rows longer than the UI.
_example_row_width = int(os.getenv("MAX_DOCS")) + 1
examples = [
    _fit_example(example, _example_row_width)
    for example in [example_1, example_2, example_3, example_4, example_5]
]
def classification(*args) -> List[str]:
    """Rerank candidate documents against a query and return them with scores.

    Args:
        *args: The query text first, followed by one value per document
            textbox. Falsy document values (empty strings, ``None``) are
            dropped before scoring.

    Returns:
        A flat list: the documents ordered by descending score and padded
        with ``""`` up to ``MAX_DOCS`` entries, followed by their scores
        (floats rounded to two decimals) padded the same way. The padding
        size comes from the ``MAX_DOCS`` environment variable.
    """
    max_docs = int(os.getenv("MAX_DOCS"))

    labels = [doc for doc in args[1:] if doc]
    if not labels:
        # Nothing to rank: skip the model call and return blank outputs.
        return [""] * (2 * max_docs)

    query = args[0]
    # threshold=0.0 is passed so that no candidate is filtered out by score.
    predictions = multi_label_pipeline(query, labels, threshold=0.0)[0]
    ranked = sorted(predictions, key=lambda item: item["score"], reverse=True)

    docs = [item["label"] for item in ranked]
    scores = [round(item["score"], 2) for item in ranked]

    # A non-positive count yields an empty list, so no padding is added
    # when max_docs or more documents were ranked.
    padding = [""] * (max_docs - len(docs))
    return docs + padding + scores + padding
# Reranker demo page: a query box, MAX_DOCS document/score textbox pairs,
# pre-filled examples, and a button that runs `classification`.
with gr.Blocks(title="GLiClass-Reranker") as main_pipeline:
    # The first example pre-populates the form.
    first_example = examples[0]

    query = gr.Textbox(
        value=first_example[0],
        label="Text query",
        placeholder="Enter your query here",
        lines=10,
    )
    submit_btn = gr.Button("Rerank")

    doc_boxes = []
    score_boxes = []
    for i in range(int(os.getenv("MAX_DOCS"))):
        with gr.Group():
            doc_input = gr.Textbox(
                value=first_example[i + 1],  # index 0 holds the query
                label=f"Document {i}",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            score_output = gr.Textbox(
                label=f"Score {i}",
                placeholder="Score will appear here",
                scale=2,
            )
        doc_boxes.append(doc_input)
        score_boxes.append(score_output)

    # `classification` receives the query followed by every document box.
    inputs = [query, *doc_boxes]
    # It returns the reordered documents followed by their scores, so the
    # document boxes are outputs as well and get rewritten in ranked order.
    outputs = doc_boxes + score_boxes

    examples = gr.Examples(
        examples=examples,
        fn=classification,
        inputs=inputs,
        outputs=outputs,
        cache_examples=True,
    )

    submit_btn.click(
        fn=classification,
        inputs=inputs,
        outputs=outputs,
    )