Ankitajadhav commited on
Commit
ac63cbd
·
verified ·
1 Parent(s): a9d0935

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -27
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # import packages
2
  import shutil
3
  import os
4
  __import__('pysqlite3')
@@ -7,7 +6,6 @@ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
7
  from sentence_transformers import SentenceTransformer
8
  import chromadb
9
  from datasets import load_dataset
10
- # from transformers import AutoModelForCausalLM, AutoTokenizer
11
  import gradio as gr
12
  from transformers import GPT2Tokenizer, GPT2Model
13
 
@@ -19,7 +17,6 @@ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
19
  # Load the model with from_tf=True
20
  model = GPT2Model.from_pretrained(model_name, from_tf=True)
21
 
22
-
23
  # Function to clear the cache
24
  def clear_cache(model_name):
25
  cache_dir = os.path.expanduser(f'~/.cache/torch/sentence_transformers/{model_name.replace("/", "_")}')
@@ -29,12 +26,10 @@ def clear_cache(model_name):
29
  else:
30
  print(f"No cache directory found for: {cache_dir}")
31
 
32
-
33
  # Embedding vector
34
  class VectorStore:
35
  def __init__(self, collection_name):
36
- # Initialize the embedding model
37
- # Initialize the embedding model with try-except block for better error handling
38
  try:
39
  self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
40
  except Exception as e:
@@ -46,11 +41,11 @@ class VectorStore:
46
  # Method to populate the vector store with embeddings from a dataset
47
  def populate_vectors(self, dataset, batch_size=100):
48
  # Use dataset streaming
49
- dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train[:1500]')
50
 
51
- # Process in batches
52
  texts = []
53
- for i, example in enumerate(dataset):
 
54
  title = example['title_cleaned']
55
  recipe = example['recipe_new']
56
  meal_type = example['meal_type']
@@ -66,6 +61,8 @@ class VectorStore:
66
  self._process_batch(texts, i)
67
  texts = []
68
 
 
 
69
  # Process the remaining texts
70
  if texts:
71
  self._process_batch(texts, i)
@@ -79,24 +76,13 @@ class VectorStore:
79
  query_embeddings = self.embedding_model.encode(query).tolist()
80
  return self.collection.query(query_embeddings=query_embeddings, n_results=n_results)
81
 
82
- # create a vector embedding
83
  vector_store = VectorStore("embedding_vector")
84
  vector_store.populate_vectors(dataset=None)
85
 
86
-
87
- # Load the model and tokenizer
88
- # text generation model
89
- # model_name = "meta-llama/Meta-Llama-3-8B"
90
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
91
- # model = AutoModelForCausalLM.from_pretrained(model_name)
92
-
93
- # load model orca-mini general purpose model
94
- # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
95
- # model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
96
-
97
-
98
-
99
  # Define the chatbot response function
 
 
100
  def chatbot_response(user_input):
101
  global conversation_history
102
  results = vector_store.search_context(user_input, n_results=1)
@@ -108,13 +94,11 @@ def chatbot_response(user_input):
108
  conversation_history.append(response)
109
  return response
110
 
111
-
112
  # Gradio interface
113
  def chat(user_input):
114
  response = chatbot_response(user_input)
115
  return response
 
116
  css = ".gradio-container {background: url(https://upload.wikimedia.org/wikipedia/commons/f/f5/Spring_Kitchen_Line-Up_%28Unsplash%29.jpg)}"
117
- iface = gr.Interface(fn=chat, inputs="text", outputs="text",css=css)
118
  iface.launch()
119
-
120
-
 
 
1
  import shutil
2
  import os
3
  __import__('pysqlite3')
 
6
  from sentence_transformers import SentenceTransformer
7
  import chromadb
8
  from datasets import load_dataset
 
9
  import gradio as gr
10
  from transformers import GPT2Tokenizer, GPT2Model
11
 
 
17
  # Load the model with from_tf=True
18
  model = GPT2Model.from_pretrained(model_name, from_tf=True)
19
 
 
20
  # Function to clear the cache
21
  def clear_cache(model_name):
22
  cache_dir = os.path.expanduser(f'~/.cache/torch/sentence_transformers/{model_name.replace("/", "_")}')
 
26
  else:
27
  print(f"No cache directory found for: {cache_dir}")
28
 
 
29
  # Embedding vector
30
  class VectorStore:
31
  def __init__(self, collection_name):
32
+ # Initialize the embedding model
 
33
  try:
34
  self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
35
  except Exception as e:
 
41
  # Method to populate the vector store with embeddings from a dataset
42
  def populate_vectors(self, dataset, batch_size=100):
43
  # Use dataset streaming
44
+ dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train[:1500]', streaming=True)
45
 
 
46
  texts = []
47
+ i = 0 # Initialize index
48
+ for example in dataset:
49
  title = example['title_cleaned']
50
  recipe = example['recipe_new']
51
  meal_type = example['meal_type']
 
61
  self._process_batch(texts, i)
62
  texts = []
63
 
64
+ i += 1 # Increment index
65
+
66
  # Process the remaining texts
67
  if texts:
68
  self._process_batch(texts, i)
 
76
  query_embeddings = self.embedding_model.encode(query).tolist()
77
  return self.collection.query(query_embeddings=query_embeddings, n_results=n_results)
78
 
79
+ # Create a vector embedding
80
  vector_store = VectorStore("embedding_vector")
81
  vector_store.populate_vectors(dataset=None)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Define the chatbot response function
84
+ conversation_history = []
85
+
86
  def chatbot_response(user_input):
87
  global conversation_history
88
  results = vector_store.search_context(user_input, n_results=1)
 
94
  conversation_history.append(response)
95
  return response
96
 
 
97
  # Gradio interface
98
  def chat(user_input):
99
  response = chatbot_response(user_input)
100
  return response
101
+
102
  css = ".gradio-container {background: url(https://upload.wikimedia.org/wikipedia/commons/f/f5/Spring_Kitchen_Line-Up_%28Unsplash%29.jpg)}"
103
+ iface = gr.Interface(fn=chat, inputs="text", outputs="text", css=css)
104
  iface.launch()