import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import numpy as np

# Map arXiv category prefixes to human-readable names.
MAPPING = {
    'cs': 'Computer Science',
    'stat': 'Statistics',
    'math': 'Mathematics',
    'q-bio': 'Quantitative Biology',
    'physics': 'Physics',
    'cmp-lg': 'Computation and Language',
    'eess': 'Electrical Engineering and Systems Science',
    'quant-ph': 'Quantum Physics',
    'cond-mat': 'Condensed Matter',
    'astro-ph': 'Astrophysics',
    'nlin': 'Nonlinear Sciences',
    'q-fin': 'Quantitative Finance',
    ':)': 'Something else'
}


@st.cache_resource
def load_model():
    """Load the tokenizer and fine-tuned classifier once per session."""
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained('model/')
    return tokenizer, model


tokenizer, model = load_model()

st.title('arXiv Article Classifier')
title = st.text_input('Title')
abstract = st.text_area('Abstract')

# The abstract is optional; classify on the title alone if it is empty.
text = f'{title} {abstract}' if abstract else title

if st.button('Predict'):
    if not text.strip():
        st.error('Please enter at least a title.')
    else:
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors='pt'
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]

        # Report the most likely categories, in descending order of
        # probability, until their combined mass reaches 95%.
        sorted_indices = np.argsort(-probs)
        cumulative = 0.0
        result = []
        for idx in sorted_indices:
            cumulative += probs[idx]
            result.append((model.config.id2label[int(idx)], probs[idx]))
            if cumulative >= 0.95:
                break

        # Show the friendly category name where we have one,
        # otherwise fall back to the raw label.
        for tag, prob in result:
            st.write(f'{MAPPING.get(tag, tag)}: {prob:.2%}')
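
# A minimal way to run the app locally, assuming this script is saved as
# app.py (the filename is an assumption) and the fine-tuned checkpoint
# (config.json plus weights) sits in ./model/ next to it:
#
#   pip install streamlit transformers torch numpy
#   streamlit run app.py
#
# Streamlit then serves the UI at http://localhost:8501 by default.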