Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -27,26 +27,24 @@ class VectorStore:
|
|
27 |
self.chroma_client = chromadb.Client()
|
28 |
self.collection = self.chroma_client.create_collection(name=collection_name)
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
for text, embedding, doc_id in zip(texts, embeddings, ids):
|
35 |
-
self.collection.add(embeddings=[embedding], documents=[text], ids=[doc_id])
|
36 |
|
37 |
# Method to populate the vector store with embeddings from a dataset
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
def search_context(self, query, n_results=1):
|
52 |
query_embedding = self.embedding_model.encode([query]).tolist()
|
|
|
27 |
self.chroma_client = chromadb.Client()
|
28 |
self.collection = self.chroma_client.create_collection(name=collection_name)
|
29 |
|
30 |
+
# def populate_vectors(self, texts):
|
31 |
+
# embeddings = self.embedding_model.encode(texts, batch_size=32).tolist()
|
32 |
+
# for text, embedding in zip(texts, embeddings, ids):
|
33 |
+
# self.collection.add(embeddings=[embedding], documents=[text], ids=[doc_id])
|
|
|
|
|
34 |
|
35 |
# Method to populate the vector store with embeddings from a dataset
|
36 |
+
def populate_vectors(self, dataset):
|
37 |
+
# Select the text columns to concatenate
|
38 |
+
# title = dataset['train']['title_cleaned'][:1000] # Limiting to 100 examples for the demo
|
39 |
+
recipe = dataset['train']['recipe_new'][:1000]
|
40 |
+
allergy = dataset['train']['allergy_type'][:1000]
|
41 |
+
ingredients = dataset['train']['ingredients_alternatives'][:1000]
|
42 |
|
43 |
+
# Concatenate the text from both columns
|
44 |
+
texts = [f"{rep} {ingr} {alle}" for rep, ingr,alle in zip(recipe, ingredients,allergy)]
|
45 |
+
for i, item in enumerate(texts):
|
46 |
+
embeddings = self.embedding_model.encode(item).tolist()
|
47 |
+
self.collection.add(embeddings=[embeddings], documents=[item], ids=[str(i)])
|
48 |
|
49 |
def search_context(self, query, n_results=1):
|
50 |
query_embedding = self.embedding_model.encode([query]).tolist()
|