# Spaces: Running
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
@st.cache_resource
def pipeline_getter():
    """Load the tokenizer, the classifier and the label table.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded objects are cached with ``st.cache_resource`` and the
    pretrained weights and the CSV are read only once per server process
    instead of on every rerun.

    Returns:
        A ``(tokenizer, model, mapping)`` tuple. ``mapping`` is the squeezed
        value array read from ``./categories.csv``; it is indexed by label
        position when predictions are turned into category names.
    """
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
    model = AutoModelForSequenceClassification.from_pretrained('KemmerEdition/my-distill-classifier')
    # NOTE(review): presumably one category name per row, in the same order
    # as the model's output logits -- confirm against the training code.
    mapping = pd.read_csv('./categories.csv').values.squeeze()
    return tokenizer, model, mapping


# Module-level handles used by the prediction function below.
tokenizer, model, mapping = pipeline_getter()
def predict_article_categories_with_confidence(
    text_data,
    abstract_text=None,
    confidence_level=0.95,
    max_categories=9
):
    """Score a paper against every category and pick the most likely ones.

    The text is tokenized (optionally together with an abstract as the
    second segment), run through the module-level ``model``, and each logit
    is passed through a sigmoid. Categories are then taken in descending
    score order until the running sum of their scores reaches
    ``confidence_level`` or ``max_categories`` categories have been taken,
    whichever happens first. If neither limit is ever hit, every category is
    returned.

    Args:
        text_data: Text passed to the tokenizer as ``text`` (the UI passes
            the paper title). The scores are flattened, so this is meant for
            a single example rather than a batch.
        abstract_text: Optional text passed to the tokenizer as
            ``text_pair``.
        confidence_level: Threshold on the cumulative score. Sigmoid scores
            are independent per label and need not sum to 1, so the
            cumulative value may exceed 1.
        max_categories: Upper bound on the number of returned categories.
            At least one category is always returned.

    Returns:
        A dict with keys:
            ``probabilities``: per-label scores in label order,
            ``predicted_categories``: selected category names, best first,
            ``confidence``: cumulative score of the selected categories,
            ``top_category``: name of the highest-scoring category,
            ``used_categories``: number of selected categories.
    """
    tokenized_input = tokenizer(
        text=text_data,
        text_pair=abstract_text,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**tokenized_input).logits
        probs = torch.sigmoid(logits).numpy().flatten()

    sorted_indices = np.argsort(probs)[::-1]
    cumulative_probs = np.cumsum(probs[sorted_indices])

    # Number of categories needed to reach the threshold; when the threshold
    # is never reached, fall back to all categories instead of none.
    reached = np.flatnonzero(cumulative_probs >= confidence_level)
    if reached.size:
        needed = int(reached[0]) + 1
    else:
        needed = len(sorted_indices)

    # Cap by max_categories, but always keep at least the top category.
    n_selected = min(needed, max(max_categories, 1))
    selected_indices = sorted_indices[:n_selected]

    return {
        'probabilities': probs,
        'predicted_categories': [mapping[idx] for idx in selected_indices],
        'confidence': cumulative_probs[n_selected - 1],
        'top_category': mapping[sorted_indices[0]],
        'used_categories': len(selected_indices)
    }
# ---------------------------------------------------------------------------
# Streamlit page: custom CSS, input form, and result rendering.
# ---------------------------------------------------------------------------

# CSS classes referenced by the raw-HTML snippets emitted further down.
st.markdown("""
<style>
.header {
    font-size: 36px !important;
    color: #1f77b4;
    margin-bottom: 20px;
}
.input-box {
    background-color: #f0f2f6;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.result-box {
    background-color: #e6f3ff;
    padding: 20px;
    border-radius: 10px;
    margin-top: 20px;
}
.category-badge {
    display: inline-block;
    background-color: #1f77b4;
    color: white;
    padding: 5px 10px;
    margin: 5px;
    border-radius: 15px;
    font-size: 14px;
}
</style>
""", unsafe_allow_html=True)

st.markdown('<div class="header">Classificator of Paper from arxiv</div>', unsafe_allow_html=True)

# Text inputs: the title is required, the abstract is optional.
with st.container():
    st.markdown('<div class="input-box">', unsafe_allow_html=True)
    title_input = st.text_input(
        '**Here you can write title:**',
        placeholder="e.g. Quantum Machine Learning Approaches",
    )
    abstract_input = st.text_area(
        '**Here you can write summary from arxiv:**',
        placeholder="Paste the abstract here for more accurate categorization...",
        height=150,
    )
    st.markdown('</div>', unsafe_allow_html=True)

# Selection settings forwarded to the prediction function.
col1, col2 = st.columns(2)
with col1:
    confidence_level = st.slider('**Confidence level (%)**', 80, 100, 95)
with col2:
    max_categories = st.slider('**Maximum categories**', 1, 10, 3)

if st.button('**Press F (just press)**', type="primary"):
    # Whitespace-only input carries no information; treat it as missing.
    title = title_input.strip()
    abstract = abstract_input.strip()

    if title:
        with st.spinner('Analyzing paper content...'):
            result = predict_article_categories_with_confidence(
                title,
                abstract or None,
                # The slider is in percent; the function expects a fraction.
                confidence_level=confidence_level / 100,
                max_categories=max_categories,
            )

        with st.container():
            st.markdown('<div class="result-box">', unsafe_allow_html=True)
            st.subheader("Categorization Results")

            st.markdown("**Most likely category:**")
            top_probability = result["probabilities"].max()
            st.markdown(
                f'<div class="category-badge">{result["top_category"]} (p={top_probability:.3f})</div>',
                unsafe_allow_html=True,
            )

            # The first predicted category is the top one shown above.
            extra_categories = result["predicted_categories"][1:]
            if extra_categories:
                st.markdown("Additional categories:")
                for category in extra_categories:
                    st.markdown(
                        f'<div class="category-badge">{category}</div>',
                        unsafe_allow_html=True,
                    )

            st.markdown("---")
    else:
        st.warning("Please enter at least the paper title")