Charles Kabui committed on
Commit
661ce73
·
1 Parent(s): 46d9ca4
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -12,42 +12,56 @@ bi_encoder = 'Bi-Encoder'
12
  cross_encoder = 'Cross-Encoder'
13
  levenshtein_distance = 'Levenshtein Distance'
14
  tf_idf = 'TF-IDF'
15
- random_forest = 'RandomForest'
16
  title = 'Sentence Similarity with Transformers'
17
- tfidf_vectorizer = TfidfVectorizer()
18
- cross_encoder_trasformer = CrossEncoder(model_save_path)
19
- bi_encoder_trasformer = SentenceTransformer(model_save_path)
20
- random_forest_model = joblib.load('trained_model_random_forest.joblib')
 
 
 
 
 
21
  @st.cache_data
22
  def compute_similarity(sentence_1, sentence_2, comparison):
23
  if comparison == bi_encoder:
24
  return cosine_similarity([bi_encoder_trasformer.encode(sentence_1)], [bi_encoder_trasformer.encode(sentence_2)])[0][0]
25
  return cross_encoder_trasformer.predict([sentence_1, sentence_2])
26
 
27
- st.set_page_config(page_title=title, layout = 'wide', initial_sidebar_state = 'auto')
 
28
  st.title(title)
29
  st.write("This app takes two sentences and outputs their similarity score using a fine-tuned transformer model.")
30
 
31
  # Example sentences section
32
  test_samples = get_samples()
33
  st.sidebar.header("Example Sentences")
34
- example_1 = st.sidebar.radio("Sentence 1", test_samples['sentence1'].values.tolist())
35
- example_2 = st.sidebar.radio("Sentence 2", test_samples['sentence2'].values.tolist())
 
 
36
 
37
  # Input fields
38
  sentence_1 = st.text_input("Enter Sentence 1:", example_1)
39
  sentence_2 = st.text_input("Enter Sentence 2:", example_2)
40
- comparison = st.selectbox("Comparicon:", [bi_encoder, cross_encoder, levenshtein_distance, tf_idf, random_forest])
 
41
 
42
  if st.button("Compare"):
43
  # Compute similarity
44
  if comparison in [bi_encoder, cross_encoder]:
45
  similarity = compute_similarity(sentence_1, sentence_2, comparison)
46
  elif comparison == levenshtein_distance:
47
- similarity = textdistance.levenshtein.normalized_similarity(sentence_1, sentence_2)
 
48
  elif comparison == tf_idf:
49
- similarity = cosine_similarity(tfidf_vectorizer.fit_transform([sentence_1, sentence_2]))[0][1]
 
50
  elif comparison == random_forest:
51
- similarity = random_forest_model.predict(encode_sentences(bi_encoder_trasformer, sentence_1, sentence_2))[0]
52
- st.write(f"Similarity Score: {similarity:.4f}")
53
- st.write("A higher score indicates greater similarity. The score ranges from 0 to 1.")
 
 
 
 
12
  cross_encoder = 'Cross-Encoder'
13
  levenshtein_distance = 'Levenshtein Distance'
14
  tf_idf = 'TF-IDF'
15
+ random_forest = 'Random Forest'
16
  title = 'Sentence Similarity with Transformers'
17
+ st.set_page_config(page_title=title, layout='wide', initial_sidebar_state='auto')
18
+ @st.cache_data
19
+ def cache_variables():
20
+ tfidf_vectorizer = TfidfVectorizer()
21
+ cross_encoder_trasformer = CrossEncoder(model_save_path)
22
+ bi_encoder_trasformer = SentenceTransformer(model_save_path)
23
+ random_forest_model = joblib.load('trained_model_random_forest.joblib')
24
+ return tfidf_vectorizer, cross_encoder_trasformer, bi_encoder_trasformer, random_forest_model
25
+
26
  @st.cache_data
27
  def compute_similarity(sentence_1, sentence_2, comparison):
28
  if comparison == bi_encoder:
29
  return cosine_similarity([bi_encoder_trasformer.encode(sentence_1)], [bi_encoder_trasformer.encode(sentence_2)])[0][0]
30
  return cross_encoder_trasformer.predict([sentence_1, sentence_2])
31
 
32
+ tfidf_vectorizer, cross_encoder_trasformer, bi_encoder_trasformer, random_forest_model = cache_variables()
33
+
34
  st.title(title)
35
  st.write("This app takes two sentences and outputs their similarity score using a fine-tuned transformer model.")
36
 
37
  # Example sentences section
38
  test_samples = get_samples()
39
  st.sidebar.header("Example Sentences")
40
+ example_1 = st.sidebar.radio(
41
+ "Sentence 1", test_samples['sentence1'].values.tolist())
42
+ example_2 = st.sidebar.radio(
43
+ "Sentence 2", test_samples['sentence2'].values.tolist())
44
 
45
  # Input fields
46
  sentence_1 = st.text_input("Enter Sentence 1:", example_1)
47
  sentence_2 = st.text_input("Enter Sentence 2:", example_2)
48
+ comparison = st.selectbox("Comparicon:", [
49
+ bi_encoder, cross_encoder, levenshtein_distance, tf_idf, random_forest])
50
 
51
  if st.button("Compare"):
52
  # Compute similarity
53
  if comparison in [bi_encoder, cross_encoder]:
54
  similarity = compute_similarity(sentence_1, sentence_2, comparison)
55
  elif comparison == levenshtein_distance:
56
+ similarity = textdistance.levenshtein.normalized_similarity(
57
+ sentence_1, sentence_2)
58
  elif comparison == tf_idf:
59
+ similarity = cosine_similarity(
60
+ tfidf_vectorizer.fit_transform([sentence_1, sentence_2]))[0][1]
61
  elif comparison == random_forest:
62
+ similarity = random_forest_model.predict(encode_sentences(
63
+ bi_encoder_trasformer, sentence_1, sentence_2))[0]
64
+ st.markdown(
65
+ f"<b style='font-size: 1.5em'>{comparison}</b> similarity score: <b style='font-size: 1.5em'>:red[{similarity:.4f}]</b>", unsafe_allow_html=True)
66
+ st.write(
67
+ "A higher score indicates greater similarity. The score ranges from 0 to 1.")