# File size: 3,812 Bytes
# fc67742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from datasets import load_dataset
import pandas as pd

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore

from pathlib import Path
from functools import reduce
import os

def build_db(openai_api_key):
  """Build (or reuse) the IMDB FAISS vector store, then run a smoke-test query.

  Args:
    openai_api_key: OpenAI API key forwarded to the embedding model.
  """
  documents = load_imdb_data()
  embedder, embedding_model = create_embedder(openai_api_key)
  build_vectore_store(documents, embedder)
  store = load_vector_store(embedder)
  run_test_query(embedding_model, store)

def load_imdb_data():
  """Download the IMDB movies dataset, cache it as a CSV, and load it as
  LangChain documents.

  Returns:
    list: one LangChain ``Document`` per CSV row.
  """
  print("Loading IMDB dataset")
  dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
  # Ensure the target directory exists before writing — on a fresh checkout
  # `data/` is missing and to_csv would fail.
  os.makedirs("data", exist_ok=True)
  dataset["train"].to_csv('data/imdb.csv')
  print("")

  print("Creating dataframe")
  movies_dataframe = pd.read_csv('data/imdb.csv')
  print(movies_dataframe.head())
  print("")

  print("Loading data from CSV")
  loader = CSVLoader(file_path='data/imdb.csv')
  data = loader.load()
  print("Done loading data...")
  # Sanity checks: confirm we actually loaded data in a format LangChain
  # recognizes (a list of Document objects).
  print("Length: " + str(len(data)))
  print("Data list type: " + str(type(data)))
  print("Data type: " + str(type(data[0])))
  print(data[0])
  print("")

  print("Calculating total length of data")
  # sum() over a generator is the idiomatic form of the reduce+lambda original.
  total_length = sum(len(doc.page_content) for doc in data)
  print("Total number of characters in dataset: " + str(total_length))
  print("Total divided by 1,000: " + str(total_length / 1000))
  print("")

  return data

def create_embedder(openai_api_key):
  """Create an OpenAI embedding model wrapped in a file-backed cache.

  Args:
    openai_api_key: OpenAI API key for ``OpenAIEmbeddings``.

  Returns:
    tuple: ``(embedder, embedding_model)`` — the cache-backed embedder and
    the raw embedding model (used directly for query embedding).
  """
  embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)

  # LocalFileStore creates its directory on demand, so no need to pre-create
  # the path (the commented-out touch logic in the original was dead code).
  store_path = str(Path.cwd() / 'data/embedding-store')
  store = LocalFileStore(store_path)

  # namespace keys cached vectors by model name, so switching embedding
  # models later cannot return stale vectors from this cache (per LangChain
  # CacheBackedEmbeddings documentation).
  embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model,
    store,
    namespace=embedding_model.model,
  )

  return (embedder, embedding_model)

def build_vectore_store(data, embedder):
  """Load the FAISS vector store from disk, or build it from *data* and save it.

  Args:
    data: list of LangChain documents to index.
    embedder: embedding object used both for loading and for indexing.
  """
  # Chunk the documents so each embedded passage stays within a bounded size.
  text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
  )
  chunked_documents = text_splitter.split_documents(data)
  # Report the split so we can see the chunking actually happened (the
  # original evaluated len() as a bare no-op expression).
  print("Number of chunks: " + str(len(chunked_documents)))

  print("Trying to load vector store from file...")
  vector_store = None
  try:
    vector_store = FAISS.load_local(
      "data/week2-movies", embedder, allow_dangerous_deserialization=True
    )
  except Exception:
    # Any load failure (missing files, corrupt index) triggers a rebuild below.
    vector_store = None

  if vector_store is None:
    print("No local vector store found - computing a new one...")
    # BUG FIX: index the *chunked* documents. The original passed the raw
    # `data`, which made the splitter above dead code.
    vector_store = FAISS.from_documents(chunked_documents, embedder)
    print("Done computing new vectore store. Saving to local file.")
    os.makedirs("data/week2-movies", exist_ok=True)
    vector_store.save_local("data/week2-movies")
  else:
    print("Found vector store in local file. Using that.")
  print("")

def load_vector_store(embedder):
  """Read the previously saved FAISS index back from disk and return it."""
  return FAISS.load_local(
    "data/week2-movies", embedder, allow_dangerous_deserialization=True
  )

def run_test_query(embedding_model, vector_store):
  """Smoke-test the vector store: embed a fixed query and print the hits."""
  print("Verifying that we can query the vectore dB...")
  query = "I have a need. A need for speed."
  query_vector = embedding_model.embed_query(query)
  hits = vector_store.similarity_search_by_vector(query_vector)
  for hit in hits:
    print(str(hit.page_content))
  print("")


if __name__ == "__main__":
  # Keep the original lowercase env var for backward compatibility, but also
  # accept the conventional OPENAI_API_KEY, and fail fast with a clear message
  # instead of passing None down to the OpenAI client.
  openai_api_key = os.getenv("openai_api_key") or os.getenv("OPENAI_API_KEY")
  if not openai_api_key:
    raise SystemExit("Set the openai_api_key (or OPENAI_API_KEY) environment variable.")
  build_db(openai_api_key)