File size: 3,044 Bytes
b9f69ed
f37dabc
b9f69ed
 
f37dabc
b9f69ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf8e5ea
b9f69ed
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
from tokenizers import Tokenizer


def fake_hash(x):
    """Return a constant hash value, ignoring ``x``.

    Passed to Streamlit's cache as the hash function for ``Tokenizer``
    objects, so every tokenizer hashes to the same value.
    """
    constant_hash = 0
    return constant_hash


@st.cache(hash_funcs={Tokenizer: fake_hash}, suppress_st_warning=True, allow_output_mutation=True)
def initialize():
    """Build the classification pipeline and load the category mappings.

    The function is wrapped in ``st.cache`` so the model and the JSON files
    are loaded once and reused across Streamlit reruns.

    Returns:
        A tuple ``(the_pipeline, cat_mapping, cat_name_mapping)`` where
        ``the_pipeline`` is a ``TextClassificationPipeline`` over the model
        stored in ``./final_model``, ``cat_mapping`` is the parsed content of
        ``cat_mapping.json`` and ``cat_name_mapping`` is the parsed content
        of ``cat_name_mapping.json``.
    """
    # The tokenizer comes from the base checkpoint; the weights come from
    # the local fine-tuned directory.
    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained('./final_model')

    the_pipeline = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,  # get_top needs the score of every label
        device=-1  # -1 selects the CPU in transformers pipelines
    )

    # Context managers close the files even if JSON parsing fails; the
    # original left both handles open for the lifetime of the process.
    with open('cat_mapping.json', 'r', encoding='utf-8') as cat_mapping_file:
        cat_mapping = json.load(cat_mapping_file)
    with open('cat_name_mapping.json', 'r', encoding='utf-8') as cat_name_mapping_file:
        cat_name_mapping = json.load(cat_name_mapping_file)

    return the_pipeline, cat_mapping, cat_name_mapping


def get_top(the_pipeline, cat_mapping, title, summary, thresh=0.95):
    """Return the most probable categories for an article.

    The title and summary are joined with ``' || '`` and passed to
    ``the_pipeline``. The per-label scores of the first result are ranked in
    descending order and the shortest prefix whose cumulative score reaches
    ``thresh`` is kept (all labels if the threshold is never reached).

    Args:
        the_pipeline: Callable taking the text and returning a sequence whose
            first item is a list of ``{'label': ..., 'score': ...}`` dicts.
        cat_mapping: Mapping from the label with its first six characters
            removed (presumably a ``LABEL_`` prefix - confirm against the
            model config) to the short category name.
        title: Article title.
        summary: Article summary.
        thresh: Cumulative score at which the selection stops.

    Returns:
        A list of new dicts (the pipeline output is not mutated) with
        ``'label'`` replaced by the mapped category, best first; or a
        human-readable ``str`` when the input is missing, too long, or
        classification failed.
    """
    if not title or not summary:
        return 'Not enough data to compute.'

    question = title + ' || ' + summary
    if len(question) > 4000:
        return 'Your input is supsiciously long, try something shorter.'

    try:
        # sorted(..., reverse=True) is stable, so ties keep pipeline order.
        ranked = sorted(
            the_pipeline(question)[0],
            key=lambda entry: entry['score'],
            reverse=True,
        )

        selected = []
        cumulative = 0
        for entry in ranked:
            selected.append(entry)
            cumulative += entry['score']
            if cumulative >= thresh:
                break

        # Map only the labels that are actually returned: a label missing
        # from cat_mapping in the discarded tail must not fail the request.
        return [
            {**entry, 'label': cat_mapping[entry['label'][6:]]}
            for entry in selected
        ]

    except Exception:
        # Exception (not BaseException) so that KeyboardInterrupt and
        # SystemExit still propagate.
        return 'Something unexpected happened, I\'m sorry. Try again.'


# --- Page header -----------------------------------------------------------
st.markdown('## Welcome to the CS article classification page!')
st.markdown("### What's below is pretty much self-explanatory.")

# --- Banner image ----------------------------------------------------------
# The URL is assembled from its base and query string and embedded as raw
# HTML so that the width can be set.
img_source = 'https://sun9-55.userapi.com/impg/azBQ_VTvbgEVonbL9hhFEpwyKAhjAtpVl4H2GQ/I4Vq0H6c3UM.jpg'
img_params = 'size=1200x900&quality=96&sign=f42419d9cdbf6fe55016fb002e4e85ae&type=album'
st.markdown(
    '<img src="' + img_source + '?' + img_params + '" width="70%"><br>',
    unsafe_allow_html=True,
)

# --- User input ------------------------------------------------------------
# Both values are read further down when the classifier is invoked.
title = st.text_input(
    'Please, insert the title of the CS article you are interested in.',
    placeholder='The title (e. g. Incorporating alien technologies in CV)',
)

summary = st.text_area(
    'Now, please, insert the summary of the CS article you are interested in.',
    placeholder='The summary itself.',
    height=240,
)

# Load the model and mappings, then classify whatever is currently typed
# into the title and summary widgets.
the_pipeline, cat_mapping, cat_name_mapping = initialize()
scores = get_top(the_pipeline, cat_mapping, title, summary)

if isinstance(scores, str):
    # get_top returns a plain message instead of a list when it has no
    # ranking to report.
    st.markdown(scores)
else:
    # One line per selected category: percentage, short code, full name.
    for score in scores:
        percent = round(score['score'] * 100, 2)
        category_short = score['label']
        category_full = cat_name_mapping[category_short]
        # '\\%' yields the same backslash-percent text as before without
        # relying on the invalid escape sequence '\%'.
        st.markdown(f'I\'m {percent}\\% certain that the article is from the {category_short} category, which is "{category_full}"')