KemmerEdition's picture
add lab materials
d81b5f5
raw
history blame
4.38 kB
import streamlit as st
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource
def pipeline_getter():
    """Load the tokenizer, the fine-tuned classifier and the label lookup.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    objects across script reruns instead of reloading them each time.

    Returns:
        A ``(tokenizer, model, mapping)`` tuple. ``mapping`` is the squeezed
        value array read from ``./categories.csv``; callers index it with a
        class id to get the category name.
    """
    # NOTE(review): read_csv treats the first CSV row as a header, and the
    # relative path is resolved against the current working directory --
    # confirm both match the layout of categories.csv and how the app is run.
    return (
        AutoTokenizer.from_pretrained('distilbert-base-cased'),
        AutoModelForSequenceClassification.from_pretrained(
            'KemmerEdition/my-distill-classifier'
        ),
        pd.read_csv('./categories.csv').values.squeeze(),
    )


# Module-level handles used by the prediction function and the UI code.
tokenizer, model, mapping = pipeline_getter()
def predict_article_categories_with_confidence(
    text_data,
    abstract_text=None,
    confidence_level=0.95,
    max_categories=9
):
    """Classify one paper and choose how many top categories to report.

    The title (and optional abstract, passed as the tokenizer's ``text_pair``)
    is encoded with the module-level ``tokenizer`` and scored by the
    module-level ``model``. Per-class scores are a sigmoid over the logits.
    Classes are ranked by score, and the shortest prefix whose running score
    total reaches ``confidence_level`` is kept, capped at ``max_categories``
    entries. At least one category is always kept.

    Args:
        text_data: Paper title. Intended to be a single string: the logits are
            flattened, so a batch of several texts would be mixed together.
        abstract_text: Optional abstract, encoded as the second segment.
        confidence_level: Running-total threshold at which selection stops.
        max_categories: Upper bound on the number of selected categories.

    Returns:
        dict with keys:
            'probabilities': 1-D numpy array of sigmoid scores in label order.
            'predicted_categories': selected category names, best first.
            'confidence': running score total over the selected categories.
            'top_category': name of the highest-scoring category.
            'used_categories': number of selected categories.
    """
    tokenized_input = tokenizer(
        text=text_data,
        text_pair=abstract_text,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**tokenized_input).logits
    # NOTE(review): sigmoid scores are independent per class and need not sum
    # to 1, so the running total (and 'confidence') can exceed 1. If the model
    # was trained single-label, torch.softmax would fit the cumulative
    # threshold better -- confirm against the training setup.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    sorted_indices = np.argsort(probs)[::-1]
    cumulative_probs = np.cumsum(probs[sorted_indices])

    # Fallback used only when the loop below never stops. This happens when the
    # threshold is never reached and max_categories exceeds the number of
    # classes. Previously that case selected nothing.
    n_selected = len(sorted_indices)
    for i, cum_prob in enumerate(cumulative_probs):
        if cum_prob >= confidence_level or i >= max_categories - 1:
            n_selected = i + 1
            break
    selected_indices = sorted_indices[:n_selected]

    return {
        'probabilities': probs,
        'predicted_categories': [mapping[idx] for idx in selected_indices],
        'confidence': cumulative_probs[n_selected - 1],
        'top_category': mapping[sorted_indices[0]],
        'used_categories': n_selected,
    }
# Page-wide CSS injected at the top of the page. It defines the classes that
# later st.markdown calls reference: .header for the title line, .input-box and
# .result-box panels, and the rounded .category-badge used for categories.
_PAGE_STYLE = """
<style>
.header {
font-size: 36px !important;
color: #1f77b4;
margin-bottom: 20px;
}
.input-box {
background-color: #f0f2f6;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.result-box {
background-color: #e6f3ff;
padding: 20px;
border-radius: 10px;
margin-top: 20px;
}
.category-badge {
display: inline-block;
background-color: #1f77b4;
color: white;
padding: 5px 10px;
margin: 5px;
border-radius: 15px;
font-size: 14px;
}
</style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
st.markdown('<div class="header">Classificator of Paper from arxiv</div>', unsafe_allow_html=True)

# --- Inputs ------------------------------------------------------------------
# NOTE(review): Streamlit appears to render each st.markdown call as a separate
# element, so the opening and closing <div> calls below likely do not wrap the
# widgets between them. The .input-box / .result-box styles would then not
# apply to the inputs or results -- verify in the browser.
with st.container():
    st.markdown('<div class="input-box">', unsafe_allow_html=True)
    title_input = st.text_input('**Here you can write title:**', placeholder="e.g. Quantum Machine Learning Approaches")
    abstract_input = st.text_area(
        '**Here you can write summary from arxiv:**',
        placeholder="Paste the abstract here for more accurate categorization...",
        height=150,
    )
    st.markdown('</div>', unsafe_allow_html=True)

col1, col2 = st.columns(2)
with col1:
    confidence_level = st.slider('**Confidence level (%)**', 80, 100, 95)
with col2:
    max_categories = st.slider('**Maximum categories**', 1, 10, 3)

# --- Prediction and results ------------------------------------------------------
if st.button('**Press F (just press)**', type="primary"):
    # A whitespace-only title carries no signal, so treat it as missing.
    title_text = title_input.strip()
    if title_text:
        abstract_text = abstract_input.strip() or None
        with st.spinner('Analyzing paper content...'):
            result = predict_article_categories_with_confidence(
                title_text,
                abstract_text,
                confidence_level=confidence_level / 100,
                max_categories=max_categories,
            )
        with st.container():
            st.markdown('<div class="result-box">', unsafe_allow_html=True)
            st.subheader("Categorization Results")
            st.markdown("**Most likely category:**")
            # The top category is the arg-max class, so its score is the max.
            top_probability = result["probabilities"].max()
            st.markdown(
                f'<div class="category-badge">{result["top_category"]} (p={top_probability:.3f})</div>',
                unsafe_allow_html=True,
            )
            # predicted_categories is ordered best-first; entry 0 is the top one.
            extra_categories = result["predicted_categories"][1:]
            if extra_categories:
                st.markdown("Additional categories:")
                # Emit all badges in one markdown call so the inline-block
                # badges can sit side by side instead of one element per row.
                st.markdown(
                    "".join(
                        f'<div class="category-badge">{category}</div>'
                        for category in extra_categories
                    ),
                    unsafe_allow_html=True,
                )
            st.markdown("---")
    else:
        st.warning("Please enter at least the paper title")