KarthikaRajagopal committed
Commit e21d6f2 · verified · 1 Parent(s): 3784c98

Upload RAG_using_Llama3.py.py

Files changed (1)
  1. RAG_using_Llama3.py.py +153 -0
RAG_using_Llama3.py.py ADDED
@@ -0,0 +1,153 @@
+ # -*- coding: utf-8 -*-
+ """RAG_using_Llama3.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1b-ZDo3QQ-axgm804UlHu3ohZwnoXz5L1
+
+ # install dependencies
+ """
+
+ !pip install -q datasets sentence-transformers faiss-cpu accelerate
+
+ from huggingface_hub import notebook_login
+ notebook_login()
+
+ """# embed dataset
+
+ this is a slow procedure, so you might consider saving your results
+ """
+
+ from datasets import load_dataset
+
+ dataset = load_dataset("KarthikaRajagopal/wikipedia-2")
+
+ dataset
+
+ from sentence_transformers import SentenceTransformer
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ # embed the dataset
+ def embed(batch):
+     # or you can combine multiple columns here, for example the title and the text
+     information = batch["text"]
+     return {"embeddings": ST.encode(information)}
+ dataset = dataset.map(embed, batched=True, batch_size=16)
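+
+ # Optional (not in the original notebook): embedding is slow, so you might also keep a
+ # local copy of the embedded dataset instead of only pushing it to the Hub. A minimal
+ # sketch using datasets' save_to_disk/load_from_disk; the "wikipedia-2-embedded"
+ # directory name is just an illustrative choice.
+ from datasets import load_from_disk
+ dataset.save_to_disk("wikipedia-2-embedded")      # write the embedded dataset locally
+ dataset = load_from_disk("wikipedia-2-embedded")  # reload it in a later session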
+
+ !pip install datasets
+
+ # Push the embedded dataset to your Hugging Face repository so it can be reloaded later
+ dataset.push_to_hub("KarthikaRajagopal/wikipedia-2", revision="embedded")
+
+ from datasets import load_dataset
+
+ # reload the embedded dataset from the "embedded" revision
+ dataset = load_dataset("KarthikaRajagopal/wikipedia-2", revision="embedded")
+
+ data = dataset["train"]
+ data = data.add_faiss_index("embeddings")  # column name that has the embeddings of the dataset
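+
+ # Optional (not part of the original notebook): building the FAISS index also takes time,
+ # so you can persist it next to the dataset. A minimal sketch using datasets'
+ # save_faiss_index/load_faiss_index; "wikipedia_embeddings.faiss" is just an illustrative
+ # file name.
+ data.save_faiss_index("embeddings", "wikipedia_embeddings.faiss")    # write the index to disk
+ # data.load_faiss_index("embeddings", "wikipedia_embeddings.faiss")  # restore it in a later session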
+
+ def search(query: str, k: int = 3):
+     """a function that embeds a new query and returns the most probable results"""
+     embedded_query = ST.encode(query)  # embed new query
+     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+         "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
+         k=k  # get only top k results
+     )
+     return scores, retrieved_examples
+
+ scores, result = search("anarchy", 4)  # search for the word "anarchy" and get the best 4 matching documents
+
+ # the lower the score, the better the match
+ scores
+
+ result['title']
+
+ print(result["text"][0])
+
+ """# chatbot on top of the retrieved results"""
+
+ !pip install -q datasets sentence-transformers faiss-cpu accelerate bitsandbytes
+
+ # re-run the retrieval setup so this section can be executed on its own
+ from sentence_transformers import SentenceTransformer
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ from datasets import load_dataset
+
+ dataset = load_dataset("KarthikaRajagopal/wikipedia-2", revision="embedded")
+
+ data = dataset["train"]
+ data = data.add_faiss_index("embeddings")  # column name that has the embeddings of the dataset
+
+ def search(query: str, k: int = 3):
+     """a function that embeds a new query and returns the most probable results"""
+     embedded_query = ST.encode(query)  # embed new query
+     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+         "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
+         k=k  # get only top k results
+     )
+     return scores, retrieved_examples
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import torch
+
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+ # load the model with 4-bit NF4 quantization so the 8B model fits in limited GPU memory
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=bnb_config
+ )
+ # stop generation at either the model's EOS token or Llama 3's end-of-turn token
+ terminators = [
+     tokenizer.eos_token_id,
+     tokenizer.convert_tokens_to_ids("<|eot_id|>")
+ ]
+
+ SYS_PROMPT = """You are an assistant for answering questions.
+ You are given the extracted parts of a long document and a question. Provide a conversational answer.
+ If you don't know the answer, just say "I do not know." Don't make up an answer."""
+
+ def format_prompt(prompt, retrieved_documents, k):
+     """using the retrieved documents we will prompt the model to generate our responses"""
+     PROMPT = f"Question:{prompt}\nContext:"
+     for idx in range(k):
+         PROMPT += f"{retrieved_documents['text'][idx]}\n"
+     return PROMPT
+
+ def generate(formatted_prompt):
+     formatted_prompt = formatted_prompt[:2000]  # truncate to the first 2000 characters to avoid GPU OOM
+     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
+     # tell the model to generate
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to(model.device)
+     outputs = model.generate(
+         input_ids,
+         max_new_tokens=1024,
+         eos_token_id=terminators,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+     )
+     response = outputs[0][input_ids.shape[-1]:]
+     return tokenizer.decode(response, skip_special_tokens=True)
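+
+ # Optional (not part of the original notebook): a minimal streaming variant, assuming you
+ # want to watch the answer appear token by token. It reuses the same chat template and
+ # sampling settings and only adds transformers' TextStreamer; generate_streaming is an
+ # illustrative name.
+ from transformers import TextStreamer
+
+ def generate_streaming(formatted_prompt):
+     formatted_prompt = formatted_prompt[:2000]  # same character truncation as above
+     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to(model.device)
+     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)  # prints tokens as they arrive
+     outputs = model.generate(
+         input_ids,
+         streamer=streamer,
+         max_new_tokens=1024,
+         eos_token_id=terminators,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+     )
+     response = outputs[0][input_ids.shape[-1]:]
+     return tokenizer.decode(response, skip_special_tokens=True)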
+
+ def rag_chatbot(prompt: str, k: int = 2):
+     scores, retrieved_documents = search(prompt, k)  # retrieve the top-k supporting documents
+     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+     return generate(formatted_prompt)
+
+ rag_chatbot("what's anarchy?", k=2)
+