Spaces:
Sleeping
Sleeping
Update app.py
Browse files — made an embeddings class, switched to RoBERTa, and fixed the semantic comparison function
app.py
CHANGED
@@ -15,72 +15,100 @@ writing_tones = ["Formal","Informal","Humorous","Serious","Sarcastic","Satirical
|
|
15 |
|
16 |
# initialize client
|
17 |
# we could try something larger, I need to check the models
|
|
|
18 |
client = InferenceClient(
|
19 |
-
|
|
|
|
|
20 |
)
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
def get_embeddings(text_items):
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
#
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
# helper function to format the prompt appropriately.
|
73 |
# For this creative writing tool, the user doesn't design the prompt itself
|
74 |
#but rather genres, tones, & themes of a book to include
|
75 |
-
def format_prompt(
|
76 |
-
#
|
|
|
|
|
|
|
77 |
if not genres:
|
78 |
-
|
79 |
if not tones:
|
80 |
-
|
81 |
if not themes:
|
82 |
-
|
83 |
|
|
|
|
|
|
|
84 |
#Alpaca format since we can't use mixtral on free CPU settings
|
85 |
prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
|
86 |
#prompt we are using for now
|
|
|
15 |
|
16 |
# initialize client
# Hugging Face Inference API client used for text generation further down the file.
# we could try something larger, I need to check the models
# using zephyr for now because its pretty quick
client = InferenceClient(
    'HuggingFaceH4/zephyr-7b-beta'
    # "v1olet/v1olet_marcoroni-go-bruins-merge-7B"  # candidate alternative, untested
)
|
24 |
+
######################################
########## Embeddings Class ##########
######################################
class EmbeddingGenerator:
    """Embed short text items with a Hugging Face encoder and select a
    subset of them whose mutual cosine similarity falls in a "medium" band
    (neither near-synonyms nor near-antonyms of the seed item)."""

    def __init__(self, model_id):
        """Load tokenizer and encoder for ``model_id`` (any HF hub id
        accepted by AutoTokenizer/AutoModel, e.g. 'roberta-base')."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id)

    def calculate_cosine_similarity(self, embedding1, embedding2):
        """Return the scalar cosine similarity between two (1, dim) embeddings."""
        # sklearn returns a (1, 1) matrix for single-row inputs; unwrap it.
        return cosine_similarity(embedding1, embedding2)[0][0]

    def get_embeddings(self, text_items):
        """Embed each string in ``text_items``.

        Returns a list of ``(pooled_embedding, original_string)`` tuples so
        similarity scores can be mapped back to the source text.
        """
        embeddings = []
        for item in text_items:
            inputs = self.tokenizer(item, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():  # inference only -- no gradients needed
                outputs = self.model(**inputs)
            # NOTE(review): assumes the encoder exposes a pooler head
            # ('pooler_output'); true for roberta-base, not for every model.
            pooled_output = outputs['pooler_output']
            embeddings.append((pooled_output, item))  # keep the original string alongside
        return embeddings

    def select_values_with_medium_similarity(self, embeddings, num_values_to_select, min_similarity, max_similarity):
        """Pick up to ``num_values_to_select`` items whose similarity to one
        randomly chosen seed item lies strictly between ``min_similarity``
        and ``max_similarity``.

        Returns the selected strings joined with ', '.  Stops early when no
        candidate remains in the band.  Fix: returns '' for an empty
        ``embeddings`` list instead of raising (random.randint(0, -1)
        raised ValueError in the original).
        """
        if not embeddings:  # fix: guard the randint call below
            return ''

        selected_values = []
        selected_indices = set()

        # Randomly select an initial (seed) embedding
        initial_index = random.randint(0, len(embeddings) - 1)
        initial_embedding, initial_item = embeddings[initial_index]
        selected_values.append(initial_item)
        selected_indices.add(initial_index)

        while len(selected_values) < num_values_to_select:
            # Filter embeddings whose similarity to the seed is within the band
            candidate_indices = [
                i for i, (embedding, _) in enumerate(embeddings)
                if i not in selected_indices and min_similarity < self.calculate_cosine_similarity(embedding, initial_embedding) < max_similarity
            ]
            if not candidate_indices:
                break  # nothing left in the band -- stop early

            # Randomly select an embedding from the filtered candidates
            index_to_select = random.choice(candidate_indices)
            # fix: the embedding half of the tuple was bound to an unused local
            selected_values.append(embeddings[index_to_select][1])
            selected_indices.add(index_to_select)

        # Concatenate the selected values into a single string
        return ', '.join(selected_values)
|
74 |
+
|
75 |
+
# testing different embeddings models that can fit in colab,
# need something smallish but also one that can create good semantic word embeddings for cosine similarity to work well
#model_id = "sentence-transformers/all-MiniLM-L6-v2"
#model_id = "BAAI/bge-small-en-v1.5"
# idk if this will work with CPUs, will either be too slow or too big
model_id = 'roberta-base'

#instantiate our class
embedding_generator = EmbeddingGenerator(model_id)

#generate embeddings
# NOTE(review): book_genres / book_themes / writing_tones appear to be
# module-level lists defined earlier in the file (outside this chunk) --
# each call embeds one list once, at import time.
genre_embeddings = embedding_generator.get_embeddings(book_genres)
theme_embeddings = embedding_generator.get_embeddings(book_themes)
tone_embeddings = embedding_generator.get_embeddings(writing_tones)

# Clear memory
# the generator holds the tokenizer + model; drop it once embeddings are cached
del embedding_generator
# torch.cuda.empty_cache()
|
93 |
|
94 |
# helper function to format the prompt appropriately.
|
95 |
# For this creative writing tool, the user doesn't design the prompt itself
|
96 |
#but rather genres, tones, & themes of a book to include
|
97 |
+
def format_prompt(genres, tones, themes):
|
98 |
+
#reinstantiate our embeddings class so we can compare the embeddings
|
99 |
+
embedding_generator = EmbeddingGenerator("roberta-base")
|
100 |
+
# pick 2-5 random ones if user leaves the field blank
|
101 |
+
# lower threshold is to avoid selecting synonyms while upper threshold is to avoid antonyms
|
102 |
if not genres:
|
103 |
+
genres = embedding_generator.select_values_with_medium_similarity(genre_embeddings, random.randint(3, 5), 0.01, 0.7) # Adjust thresholds as needed
|
104 |
if not tones:
|
105 |
+
tones = embedding_generator.select_values_with_medium_similarity(tone_embeddings, random.randint(3, 5), 0.01, 0.7) # Adjust thresholds as needed
|
106 |
if not themes:
|
107 |
+
themes = embedding_generator.select_values_with_medium_similarity(theme_embeddings, random.randint(3, 5), 0.01, 0.7) # Adjust thresholds as needed
|
108 |
|
109 |
+
# we won't need our embeddings generator after this step
|
110 |
+
del embedding_generator
|
111 |
+
|
112 |
#Alpaca format since we can't use mixtral on free CPU settings
|
113 |
prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
|
114 |
#prompt we are using for now
|