Spaces:
Runtime error
Runtime error
Charles Kabui
committed on
Commit
·
661ce73
1
Parent(s):
46d9ca4
formatting
Browse files
app.py
CHANGED
@@ -12,42 +12,56 @@ bi_encoder = 'Bi-Encoder'
|
|
12 |
cross_encoder = 'Cross-Encoder'
|
13 |
levenshtein_distance = 'Levenshtein Distance'
|
14 |
tf_idf = 'TF-IDF'
|
15 |
-
random_forest = '
|
16 |
title = 'Sentence Similarity with Transformers'
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
21 |
@st.cache_data
|
22 |
def compute_similarity(sentence_1, sentence_2, comparison):
|
23 |
if comparison == bi_encoder:
|
24 |
return cosine_similarity([bi_encoder_trasformer.encode(sentence_1)], [bi_encoder_trasformer.encode(sentence_2)])[0][0]
|
25 |
return cross_encoder_trasformer.predict([sentence_1, sentence_2])
|
26 |
|
27 |
-
|
|
|
28 |
st.title(title)
|
29 |
st.write("This app takes two sentences and outputs their similarity score using a fine-tuned transformer model.")
|
30 |
|
31 |
# Example sentences section
|
32 |
test_samples = get_samples()
|
33 |
st.sidebar.header("Example Sentences")
|
34 |
-
example_1 = st.sidebar.radio(
|
35 |
-
|
|
|
|
|
36 |
|
37 |
# Input fields
|
38 |
sentence_1 = st.text_input("Enter Sentence 1:", example_1)
|
39 |
sentence_2 = st.text_input("Enter Sentence 2:", example_2)
|
40 |
-
comparison = st.selectbox("Comparicon:", [
|
|
|
41 |
|
42 |
if st.button("Compare"):
|
43 |
# Compute similarity
|
44 |
if comparison in [bi_encoder, cross_encoder]:
|
45 |
similarity = compute_similarity(sentence_1, sentence_2, comparison)
|
46 |
elif comparison == levenshtein_distance:
|
47 |
-
similarity = textdistance.levenshtein.normalized_similarity(
|
|
|
48 |
elif comparison == tf_idf:
|
49 |
-
similarity = cosine_similarity(
|
|
|
50 |
elif comparison == random_forest:
|
51 |
-
similarity = random_forest_model.predict(encode_sentences(
|
52 |
-
|
53 |
-
st.
|
|
|
|
|
|
|
|
12 |
cross_encoder = 'Cross-Encoder'
|
13 |
levenshtein_distance = 'Levenshtein Distance'
|
14 |
tf_idf = 'TF-IDF'
|
15 |
+
random_forest = 'Random Forest'
|
16 |
title = 'Sentence Similarity with Transformers'
|
17 |
+
st.set_page_config(page_title=title, layout='wide', initial_sidebar_state='auto')
|
18 |
+
@st.cache_data
|
19 |
+
def cache_variables():
|
20 |
+
tfidf_vectorizer = TfidfVectorizer()
|
21 |
+
cross_encoder_trasformer = CrossEncoder(model_save_path)
|
22 |
+
bi_encoder_trasformer = SentenceTransformer(model_save_path)
|
23 |
+
random_forest_model = joblib.load('trained_model_random_forest.joblib')
|
24 |
+
return tfidf_vectorizer, cross_encoder_trasformer, bi_encoder_trasformer, random_forest_model
|
25 |
+
|
26 |
@st.cache_data
|
27 |
def compute_similarity(sentence_1, sentence_2, comparison):
|
28 |
if comparison == bi_encoder:
|
29 |
return cosine_similarity([bi_encoder_trasformer.encode(sentence_1)], [bi_encoder_trasformer.encode(sentence_2)])[0][0]
|
30 |
return cross_encoder_trasformer.predict([sentence_1, sentence_2])
|
31 |
|
32 |
+
tfidf_vectorizer, cross_encoder_trasformer, bi_encoder_trasformer, random_forest_model = cache_variables()
|
33 |
+
|
34 |
st.title(title)
|
35 |
st.write("This app takes two sentences and outputs their similarity score using a fine-tuned transformer model.")
|
36 |
|
37 |
# Example sentences section
|
38 |
test_samples = get_samples()
|
39 |
st.sidebar.header("Example Sentences")
|
40 |
+
example_1 = st.sidebar.radio(
|
41 |
+
"Sentence 1", test_samples['sentence1'].values.tolist())
|
42 |
+
example_2 = st.sidebar.radio(
|
43 |
+
"Sentence 2", test_samples['sentence2'].values.tolist())
|
44 |
|
45 |
# Input fields
|
46 |
sentence_1 = st.text_input("Enter Sentence 1:", example_1)
|
47 |
sentence_2 = st.text_input("Enter Sentence 2:", example_2)
|
48 |
+
comparison = st.selectbox("Comparicon:", [
|
49 |
+
bi_encoder, cross_encoder, levenshtein_distance, tf_idf, random_forest])
|
50 |
|
51 |
if st.button("Compare"):
|
52 |
# Compute similarity
|
53 |
if comparison in [bi_encoder, cross_encoder]:
|
54 |
similarity = compute_similarity(sentence_1, sentence_2, comparison)
|
55 |
elif comparison == levenshtein_distance:
|
56 |
+
similarity = textdistance.levenshtein.normalized_similarity(
|
57 |
+
sentence_1, sentence_2)
|
58 |
elif comparison == tf_idf:
|
59 |
+
similarity = cosine_similarity(
|
60 |
+
tfidf_vectorizer.fit_transform([sentence_1, sentence_2]))[0][1]
|
61 |
elif comparison == random_forest:
|
62 |
+
similarity = random_forest_model.predict(encode_sentences(
|
63 |
+
bi_encoder_trasformer, sentence_1, sentence_2))[0]
|
64 |
+
st.markdown(
|
65 |
+
f"<b style='font-size: 1.5em'>{comparison}</b> similarity score: <b style='font-size: 1.5em'>:red[{similarity:.4f}]</b>", unsafe_allow_html=True)
|
66 |
+
st.write(
|
67 |
+
"A higher score indicates greater similarity. The score ranges from 0 to 1.")
|