lu-ny committed on
Commit b7fa6f0 · 1 Parent(s): f007a0a

Update app.py


made an embeddings class, switched to RoBERTa, and fixed the semantic comparison function

Files changed (1)
  1. app.py +84 -56
app.py CHANGED
@@ -15,72 +15,100 @@ writing_tones = ["Formal","Informal","Humorous","Serious","Sarcastic","Satirical
 
 # initialize client
 # we could try something larger, I need to check the models
+# using zephyr for now because it's pretty quick
 client = InferenceClient(
-    "v1olet/v1olet_marcoroni-go-bruins-merge-7B"
+    'HuggingFaceH4/zephyr-7b-beta'
+    # "v1olet/v1olet_marcoroni-go-bruins-merge-7B"
+
 )
-
-# Load pre-trained tokenizer and model (replace with your desired model if you want, but it needs to be small)
-model_id = "sentence-transformers/all-MiniLM-L6-v2"  # small embeddings model
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModel.from_pretrained(model_id)  # can't use load_in_4bit=True since HF free tier runs on CPU
-
-# Function to calculate cosine similarity between two embeddings
-def calculate_cosine_similarity(embedding1, embedding2):
-    return cosine_similarity(embedding1, embedding2)[0][0]
-
-# Function to convert text items into embeddings
-def get_embeddings(text_items):
-    embeddings = []
-    for item in text_items:
-        inputs = tokenizer(item, return_tensors="pt", padding=True, truncation=True)
-        with torch.no_grad():
-            outputs = model(**inputs)
-        pooled_output = outputs['pooler_output']
-        embeddings.append(pooled_output)
-    return embeddings
-
-# Helper function to select values with low enough cosine similarity and concatenate them into a string
-def select_values_with_low_similarity(embeddings, original_values, num_values_to_select, max_similarity):
-    selected_values = []
-    selected_indices = set()
-
-    while len(selected_values) < num_values_to_select:
-        index1, index2 = random.sample(range(len(embeddings)), 2)
-        embedding1, embedding2 = embeddings[index1], embeddings[index2]
-
-        if index1 != index2 and calculate_cosine_similarity(embedding1, embedding2) < max_similarity:
-            if index1 not in selected_indices:
-                selected_values.append(original_values[index1])
-                selected_indices.add(index1)
-            if index2 not in selected_indices:
-                selected_values.append(original_values[index2])
-                selected_indices.add(index2)
-
-    # Concatenate the selected values into a single string
-    selected_string = ', '.join(selected_values)
-    return selected_string
-
-
-# Convert text items into embeddings
-genre_embeddings = get_embeddings(book_genres)
-theme_embeddings = get_embeddings(book_themes)
-tone_embeddings = get_embeddings(writing_tones)
-# clear memory
-del model
-#torch.cuda.empty_cache()
+######################################
+########## Embeddings Class ##########
+######################################
+class EmbeddingGenerator:
+    def __init__(self, model_id):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModel.from_pretrained(model_id)
+
+    def calculate_cosine_similarity(self, embedding1, embedding2):
+        return cosine_similarity(embedding1, embedding2)[0][0]
+
+    def get_embeddings(self, text_items):
+        embeddings = []
+        for item in text_items:
+            inputs = self.tokenizer(item, return_tensors="pt", padding=True, truncation=True)
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+            pooled_output = outputs['pooler_output']
+            embeddings.append((pooled_output, item))  # Store the embedding along with the original string
+        return embeddings
+
+    def select_values_with_medium_similarity(self, embeddings, num_values_to_select, min_similarity, max_similarity):
+        selected_values = []
+        selected_indices = set()
+
+        # Randomly select an initial embedding
+        initial_index = random.randint(0, len(embeddings) - 1)
+        initial_embedding, initial_item = embeddings[initial_index]
+        selected_values.append(initial_item)
+        selected_indices.add(initial_index)
+
+        while len(selected_values) < num_values_to_select:
+            # Keep only the embeddings whose similarity to the initial one falls within the desired range
+            candidate_indices = [
+                i for i, (embedding, _) in enumerate(embeddings)
+                if i not in selected_indices and min_similarity < self.calculate_cosine_similarity(embedding, initial_embedding) < max_similarity
+            ]
+
+            if candidate_indices:
+                # Randomly select an embedding from the filtered candidates
+                index_to_select = random.choice(candidate_indices)
+                selected_embedding, selected_item = embeddings[index_to_select]
+                selected_values.append(selected_item)
+                selected_indices.add(index_to_select)
+            else:
+                break
+
+        # Concatenate the selected values into a single string
+        selected_string = ', '.join(selected_values)
+        return selected_string
+
+# testing different embedding models that can fit in Colab;
+# we need something smallish that still produces semantic embeddings good enough for cosine similarity to work well
+#model_id = "sentence-transformers/all-MiniLM-L6-v2"
+#model_id = "BAAI/bge-small-en-v1.5"
+# not sure this will work on CPU; it may be too slow or too big
+model_id = 'roberta-base'
+
+# instantiate our class
+embedding_generator = EmbeddingGenerator(model_id)
+
+# generate embeddings
+genre_embeddings = embedding_generator.get_embeddings(book_genres)
+theme_embeddings = embedding_generator.get_embeddings(book_themes)
+tone_embeddings = embedding_generator.get_embeddings(writing_tones)
+
+# Clear memory
+del embedding_generator
+# torch.cuda.empty_cache()
 
 # helper function to format the prompt appropriately.
 # For this creative writing tool, the user doesn't design the prompt itself
 # but rather genres, tones, & themes of a book to include
-def format_prompt(message, genres, tones, themes):
-    # pick random ones if user leaves it blank but make sure they aren't opposites
+def format_prompt(genres, tones, themes):
+    # reinstantiate our embeddings class so we can compare the embeddings
+    embedding_generator = EmbeddingGenerator("roberta-base")
+    # pick 3-5 random ones if the user leaves the field blank;
+    # the lower threshold avoids selecting synonyms while the upper threshold avoids antonyms
     if not genres:
-        selected_genres = select_values_with_low_similarity(genre_embeddings, book_genres, random.randint(3, 5), 0.2)  # Adjust threshold as needed
+        genres = embedding_generator.select_values_with_medium_similarity(genre_embeddings, random.randint(3, 5), 0.01, 0.7)  # Adjust thresholds as needed
     if not tones:
-        selected_tones = select_values_with_low_similarity(tone_embeddings, writing_tones, random.randint(3, 5), 0.2)  # Adjust threshold as needed
+        tones = embedding_generator.select_values_with_medium_similarity(tone_embeddings, random.randint(3, 5), 0.01, 0.7)  # Adjust thresholds as needed
     if not themes:
-        selected_themes = select_values_with_low_similarity(theme_embeddings, book_themes, random.randint(3, 5), 0.2)  # Adjust threshold as needed
+        themes = embedding_generator.select_values_with_medium_similarity(theme_embeddings, random.randint(3, 5), 0.01, 0.7)  # Adjust thresholds as needed
 
+    # we won't need our embeddings generator after this step
+    del embedding_generator
+
     #Alpaca format since we can't use mixtral on free CPU settings
     prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
     #prompt we are using for now
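
Reviewer note: here is a minimal standalone sketch (not part of the commit) of what `get_embeddings` plus `calculate_cosine_similarity` compute, for sanity-checking the new pooling path outside the Space. It assumes the dependencies app.py already imports (`torch`, `transformers`, scikit-learn); the four tone words are illustrative stand-ins for the real `writing_tones` list.

```python
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

model_id = "roberta-base"  # the model this commit switches to
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
model.eval()  # inference only

def embed(text):
    # Mirrors EmbeddingGenerator.get_embeddings for a single string:
    # tokenize, forward pass without gradients, keep the pooler output.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output.numpy()  # shape (1, hidden_size)

# Illustrative stand-ins; the real lists live earlier in app.py
tones = ["Formal", "Informal", "Humorous", "Serious"]
vectors = {tone: embed(tone) for tone in tones}
for i, a in enumerate(tones):
    for b in tones[i + 1:]:
        print(f"{a} vs {b}: {cosine_similarity(vectors[a], vectors[b])[0][0]:.3f}")
```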
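And a toy check of the selection band itself, on hand-made 2-D vectors so the 0.01/0.7 thresholds can be verified by hand. `select_medium` is a hypothetical rename that mirrors the loop in `select_values_with_medium_similarity`, except the seed is fixed to the first item instead of chosen at random so the expected output is deterministic.

```python
import random

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def select_medium(embeddings, num_to_select, min_sim, max_sim):
    # Same filtering loop as select_values_with_medium_similarity, but the
    # seed is embeddings[0] so the similarity comments below stay accurate.
    seed_vec, seed_label = embeddings[0]
    selected, used = [seed_label], {0}
    while len(selected) < num_to_select:
        candidates = [
            i for i, (vec, _) in enumerate(embeddings)
            if i not in used and min_sim < cosine_similarity(vec, seed_vec)[0][0] < max_sim
        ]
        if not candidates:
            break  # nothing left inside the similarity band
        pick = random.choice(candidates)
        selected.append(embeddings[pick][1])
        used.add(pick)
    return ", ".join(selected)

items = [
    (np.array([[1.0, 0.0]]), "seed"),
    (np.array([[1.0, 0.5]]), "near-synonym"),  # sim to seed ~0.89: above 0.7, skipped
    (np.array([[0.5, 1.0]]), "related"),       # sim to seed ~0.45: inside the band, eligible
    (np.array([[-1.0, 0.0]]), "antonym"),      # sim to seed -1.0: below 0.01, skipped
]
print(select_medium(items, 3, 0.01, 0.7))  # -> "seed, related"
```

A trade-off worth flagging in review: `format_prompt` now re-instantiates `EmbeddingGenerator` (reloading roberta-base) on every call and deletes it right after, which keeps steady-state memory low on the free CPU tier at the cost of per-request latency.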