devjas1
committed on
Commit
·
eae02ee
1
Parent(s):
cede51b
(FEAT)[Implement FAISS Indexing]: enhance document embedding process by integrating FAISS indexing and saving metadata.
Browse files- src/embedder.py +52 -13
src/embedder.py
CHANGED
@@ -1,32 +1,71 @@
|
|
1 |
"""
|
2 |
This script handles document embedding using EmbeddingGemma.
|
3 |
This is the entry point for indexing documents.
|
4 |
-
TODO: Wire this to FAISS
|
5 |
"""
|
6 |
|
7 |
import os
|
|
|
|
|
|
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
|
10 |
|
11 |
def embed_documents(path: str, config: dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
try:
|
13 |
model = SentenceTransformer(config["embedding"]["model_path"])
|
14 |
-
|
15 |
-
|
|
|
|
|
16 |
|
17 |
-
model = SentenceTransformer(config["embedding"]["model_path"])
|
18 |
embeddings = []
|
|
|
|
|
19 |
|
|
|
20 |
for fname in os.listdir(path):
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
29 |
print(f"Total embeddings created: {len(embeddings)}")
|
30 |
-
return embeddings
|
31 |
|
32 |
-
|
|
|
1 |
"""
|
2 |
This script handles document embedding using EmbeddingGemma.
|
3 |
This is the entry point for indexing documents.
|
|
|
4 |
"""
|
5 |
|
6 |
import os
|
7 |
+
import pickle
|
8 |
+
import faiss
|
9 |
+
import numpy as np
|
10 |
from sentence_transformers import SentenceTransformer
|
11 |
|
12 |
|
13 |
def embed_documents(path: str, config: dict):
    """
    Embed documents from a directory and persist them to a FAISS index.

    Reads every regular file directly under ``path``, embeds the non-empty
    ones with the SentenceTransformer named in the config, builds a
    cosine-similarity index (inner product over L2-normalized vectors), and
    saves the index plus text/filename metadata under ``vector_cache/``.

    Args:
        path (str): Path to the directory containing the documents to embed.
        config (dict): Configuration dictionary; must provide
            ``config["embedding"]["model_path"]``.

    Returns:
        list: ``(filename, embedding)`` pairs for each embedded document, or
        ``[]`` when the model could not be loaded or nothing was embedded.
    """
    try:
        model = SentenceTransformer(config["embedding"]["model_path"])
        print(f"Initialized embedding model: {config['embedding']['model_path']}")
    except (ValueError, OSError) as e:
        # OSError covers a missing or corrupt model path, which the previous
        # ValueError-only handler let escape.
        print(f"Error initializing embedding model: {e}")
        return []

    texts = []
    filenames = []

    # Read all documents. sorted() makes the index order deterministic across
    # runs -- os.listdir order is filesystem-dependent.
    for fname in sorted(os.listdir(path)):
        fpath = os.path.join(path, fname)
        if not os.path.isfile(fpath):
            continue
        try:
            with open(fpath, "r", encoding="utf-8") as f:
                text = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # Best-effort: skip unreadable/non-text files, keep going.
            print(f"Error reading file {fpath}: {e}")
            continue
        if text.strip():  # Only process non-empty files
            texts.append(text)
            filenames.append(fname)

    if not texts:
        print("No documents were successfully embedded.")
        return []

    try:
        # One batched encode call instead of one model.encode() per file --
        # same vectors, far fewer model invocations.
        embeddings = list(model.encode(texts))
    except Exception as e:  # model failures stay best-effort, as before
        print(f"Error embedding documents: {e}")
        return []

    # Create FAISS index: inner product over normalized vectors == cosine.
    dimension = embeddings[0].shape[0]
    index = faiss.IndexFlatIP(dimension)

    # Normalize embeddings for cosine similarity
    embeddings_matrix = np.array(embeddings).astype("float32")
    faiss.normalize_L2(embeddings_matrix)
    # Add embeddings to index
    index.add(embeddings_matrix)

    # Save FAISS index and metadata
    os.makedirs("vector_cache", exist_ok=True)
    faiss.write_index(index, "vector_cache/faiss_index.bin")

    with open("vector_cache/metadata.pkl", "wb") as f:
        pickle.dump({"texts": texts, "filenames": filenames}, f)

    print(f"Saved FAISS index to vector_cache/ with {len(embeddings)} documents.")
    print(f"Total embeddings created: {len(embeddings)}")

    return list(zip(filenames, embeddings))