dfedukov's picture
Update app.py
1fbf315 verified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from safetensors.torch import load_file as safe_load
# Output index of each arXiv top-level category code; get_predict maps logit
# position i back to a code through ind_to_target, so this order matters.
target_to_ind = {
    code: position
    for position, code in enumerate(
        ('cs', 'econ', 'eess', 'math', 'phys', 'q-bio', 'q-fin', 'stat')
    )
}

# Human-readable name shown in the UI for each category code.
target_to_label = {
    'cs': 'Computer Science',
    'econ': 'Economics',
    'eess': 'Electrical Engineering and Systems Science',
    'math': 'Mathematics',
    'phys': 'Physics',
    'q-bio': 'Quantitative Biology',
    'q-fin': 'Quantitative Finance',
    'stat': 'Statistics',
}

# Inverse of target_to_ind (logit index -> category code). The indices above are
# 0..7 in insertion order, so enumerating the keys reproduces the inversion.
ind_to_target = dict(enumerate(target_to_ind))
# Page title. The emoji is spelled by its Unicode name so the source stays
# ASCII-safe (a mis-decoded copy had turned it into the mojibake "πŸ€“").
st.title('papers_classifier \N{NERD FACE}')
@st.cache_data
def display_intro():
    """Render the app's introductory Markdown: supported fields and usage steps.

    Returns nothing; the only effect is the ``st.markdown`` element it emits.
    NOTE(review): this is wrapped in ``st.cache_data``, which relies on Streamlit
    replaying elements emitted inside cached functions. Caching a static string
    buys little, and the decorator could probably be dropped.
    """
    import textwrap  # local import: only used to strip the source indentation

    # The blank lines around the lists are deliberate. Without the one after
    # "- Statistics", Markdown folds the next sentence into that bullet.
    # The sub-list is indented under item 2 so it renders nested.
    intro_text = textwrap.dedent(
        """
        Hey! I'm papers_classifier and I'm here to help you with answering the question 'WTF is this paper about?'

        According to arXiv there are 8 different fields of study:

        - Computer Science
        - Economics
        - Electrical Engineering and Systems Science
        - Mathematics
        - Physics
        - Quantitative Biology
        - Quantitative Finance
        - Statistics

        Everything I'll tell you will be about these eight fields.

        How to use me:

        1. Give me the paper's title and (if you have one) its abstract
        2. Choose one of two classification modes:
           - Best prediction: Shows the most likely to be true field
           - Top 95%: Shows multiple fields until I'm at least 95% confident that the correct one is among them
        3. Press the 'Get prediction' button
        4. Wait for me to tell you which fields of study this paper relates to
        """
    )
    st.markdown(intro_text)


# Render the introduction at the top of every script run.
display_intro()
@st.cache_resource
def load_model_and_tokenizer():
    """Return ``(model, tokenizer)`` for the fine-tuned DistilBERT classifier.

    The tokenizer and architecture come from the public
    ``distilbert/distilbert-base-cased`` checkpoint. The model gets one output
    per entry of ``target_to_ind``, and its weights are then replaced by those
    in the local ``model (2).safetensors`` file. ``st.cache_resource`` keeps the
    result so reruns of the script do not repeat the load.
    """
    base_checkpoint = 'distilbert/distilbert-base-cased'
    tok = AutoTokenizer.from_pretrained(base_checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(
        base_checkpoint,
        num_labels=len(target_to_ind),
    )
    # load_state_dict is strict by default, so the file must supply exactly the
    # parameter names of this architecture.
    fine_tuned_weights = safe_load("model (2).safetensors")
    clf.load_state_dict(fine_tuned_weights)
    return clf, tok


# Module-level handles that get_predict reads as globals.
model, tokenizer = load_model_and_tokenizer()
def get_predict(title: str, abstract: str) -> "list[tuple[float, str]]":
    """Classify a paper and rank every arXiv field by predicted probability.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty. Only its first 128 characters
            are used.

    Returns:
        A list of ``(probability, target_code)`` pairs, one per logit of the
        module-level ``model``, sorted from most to least probable. The codes
        come from ``ind_to_target``.
    """
    # Title and abstract are joined with the tokenizer's separator token.
    # NOTE(review): the 128-character cut of the abstract presumably mirrors how
    # the model was fine-tuned; confirm before changing it.
    text = [title + tokenizer.sep_token + abstract[:128]]
    tokens_info = tokenizer(
        text,
        padding=True,
        truncation=True,  # clip to the tokenizer's maximum input length
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**tokens_info).logits
    # Batch of one: keep row 0 as a plain list of floats.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    # Target codes are unique, so a descending tuple sort is the exact reverse of
    # an ascending one: highest probability first.
    return sorted(
        ((p, ind_to_target[i]) for i, p in enumerate(probs)),
        reverse=True,
    )
# --- Input widgets -----------------------------------------------------------
title = st.text_area("Title ", "", height=100)
abstract = st.text_area("Abstract ", "", height=150)
mode = st.radio("Mode: ", ("Best prediction", "Top 95%"))

# --- Prediction --------------------------------------------------------------
if st.button("Get prediction", key="manual"):
    # A whitespace-only title carries no information, so treat it as missing
    # instead of sending it to the model.
    if not title.strip():
        st.error("Please, provide paper's title")
    else:
        with st.spinner("Be patient, I'm doing my best"):
            predict = get_predict(title, abstract)

        # get_predict returns fields sorted by descending probability.
        # "Best prediction": a threshold of 0 is met by the first field, so
        # only the top field is shown.
        # "Top 95%": fields are added until their cumulative probability
        # reaches 0.95.
        threshold = 0 if mode == "Best prediction" else 0.95
        labels = []
        cumulative = 0.0
        for p, tag in predict:
            cumulative += p
            labels.append(target_to_label[tag])
            if cumulative >= threshold:
                break

        # Blank-line separators make each field its own Markdown paragraph.
        st.success('\n\n'.join(labels))