KemmerEdition commited on
Commit
d81b5f5
·
1 Parent(s): 5721366

add lab materials

Browse files
Files changed (3) hide show
  1. app.py +131 -0
  2. categories.csv +10 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+
8
+ @st.cache_resource
9
+ def pipeline_getter():
10
+ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
11
+ model = AutoModelForSequenceClassification.from_pretrained('KemmerEdition/my-distill-classifier')
12
+ mapping = pd.read_csv('./categories.csv').values.squeeze()
13
+ return tokenizer, model, mapping
14
+
15
+
16
+ tokenizer, model, mapping = pipeline_getter()
17
+
18
+
19
+ def predict_article_categories_with_confidence(
20
+ text_data,
21
+ abstract_text=None,
22
+ confidence_level=0.95,
23
+ max_categories=9
24
+ ):
25
+ tokenized_input = tokenizer(
26
+ text=text_data,
27
+ text_pair=abstract_text,
28
+ padding=True,
29
+ truncation=True,
30
+ return_tensors='pt'
31
+ )
32
+
33
+ model_output = model(**tokenized_input)
34
+ logits = model_output.logits
35
+ probs = torch.sigmoid(logits).detach().numpy().flatten()
36
+
37
+ sorted_indices = np.argsort(probs)[::-1]
38
+ sorted_probs = probs[sorted_indices]
39
+
40
+ cumulative_probs = np.cumsum(sorted_probs)
41
+
42
+ selected_indices = []
43
+ for i, cum_prob in enumerate(cumulative_probs):
44
+ if cum_prob >= confidence_level or i >= max_categories - 1:
45
+ selected_indices = sorted_indices[:i+1]
46
+ break
47
+
48
+ result = {
49
+ 'probabilities': probs,
50
+ 'predicted_categories': [mapping[idx] for idx in selected_indices],
51
+ 'confidence': cumulative_probs[len(selected_indices)-1],
52
+ 'top_category': mapping[sorted_indices[0]],
53
+ 'used_categories': len(selected_indices)
54
+ }
55
+
56
+ return result
57
+
58
+
59
+ st.markdown("""
60
+ <style>
61
+ .header {
62
+ font-size: 36px !important;
63
+ color: #1f77b4;
64
+ margin-bottom: 20px;
65
+ }
66
+ .input-box {
67
+ background-color: #f0f2f6;
68
+ padding: 20px;
69
+ border-radius: 10px;
70
+ margin-bottom: 20px;
71
+ }
72
+ .result-box {
73
+ background-color: #e6f3ff;
74
+ padding: 20px;
75
+ border-radius: 10px;
76
+ margin-top: 20px;
77
+ }
78
+ .category-badge {
79
+ display: inline-block;
80
+ background-color: #1f77b4;
81
+ color: white;
82
+ padding: 5px 10px;
83
+ margin: 5px;
84
+ border-radius: 15px;
85
+ font-size: 14px;
86
+ }
87
+ </style>
88
+ """, unsafe_allow_html=True)
89
+
90
+ st.markdown('<div class="header">Classificator of Paper from arxiv</div>', unsafe_allow_html=True)
91
+
92
+ with st.container():
93
+ st.markdown('<div class="input-box">', unsafe_allow_html=True)
94
+ title_input = st.text_input('**Here you can write title:**', placeholder="e.g. Quantum Machine Learning Approaches")
95
+ abstract_input = st.text_area('**Here you can write summary from arxiv:**',
96
+ placeholder="Paste the abstract here for more accurate categorization...",
97
+ height=150)
98
+ st.markdown('</div>', unsafe_allow_html=True)
99
+
100
+ col1, col2 = st.columns(2)
101
+ with col1:
102
+ confidence_level = st.slider('**Confidence level (%)**', 80, 100, 95)
103
+ with col2:
104
+ max_categories = st.slider('**Maximum categories**', 1, 10, 3)
105
+
106
+ if st.button('**Press F (just press)**', type="primary"):
107
+ if len(title_input) > 0:
108
+ with st.spinner('Analyzing paper content...'):
109
+ result = predict_article_categories_with_confidence(
110
+ title_input,
111
+ abstract_input if abstract_input else None,
112
+ confidence_level=confidence_level/100,
113
+ max_categories=max_categories
114
+ )
115
+
116
+ with st.container():
117
+ st.markdown('<div class="result-box">', unsafe_allow_html=True)
118
+ st.subheader("Categorization Results")
119
+
120
+ st.markdown(f"**Most likely category:**")
121
+ st.markdown(f'<div class="category-badge">{result["top_category"]} (p={result["probabilities"][np.argmax(result["probabilities"])]:.3f})</div>',
122
+ unsafe_allow_html=True)
123
+
124
+ if len(result["predicted_categories"]) > 1:
125
+ st.markdown(f"Additional categories:")
126
+ for category in result["predicted_categories"][1:]:
127
+ st.markdown(f'<div class="category-badge">{category}</div>', unsafe_allow_html=True)
128
+
129
+ st.markdown("---")
130
+ else:
131
+ st.warning("Please enter at least the paper title")
categories.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ category
2
+ cs
3
+ econ
4
+ eess
5
+ math
6
+ nlin
7
+ physics
8
+ q-bio
9
+ q-fin
10
+ stat
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.47.0
2
+ torch==2.2.2
3
+ pandas==2.2.2
4
+ numpy==1.26.4