# rag/z_generate.py
# Author: Deepak Sahu
# Commit f0c8373 ("image retrieval"), 3.63 kB
from huggingface_hub import InferenceClient
import os
class ServerlessInference:
    """Answer queries with a Hugging Face serverless chat model, optionally using RAG.

    Retrieval uses two optional vector stores. Each must expose
    ``similarity_search(query=..., k=...)`` returning objects with a
    ``page_content`` string attribute (presumably LangChain-style stores --
    TODO confirm against the caller).
    """

    # User-turn template shared by the text and image RAG prompts.
    _USER_PROMPT_TEMPLATE = (
        "Context:\n"
        "{context}\n"
        "---\n"
        "Now here is the question you need to answer.\n"
        "Question: {question}"
    )

    # System prompt for answering from retrieved text documents. The refusal
    # sentence is kept verbatim in case callers match on it.
    _TEXT_SYSTEM_PROMPT = (
        "Using the information contained in the context,\n"
        "give a comprehensive answer to the question.\n"
        "Respond only to the question asked, response should be concise and relevant to the question.\n"
        "If the answer cannot be deduced from the context, do not give an answer. "
        "Instead say `Theres lack of information in document source.`"
    )

    # System prompt asking the model for the indices of relevant retrieved images.
    _IMAGE_SYSTEM_PROMPT = (
        "Using the information contained in the context about the images stored in the database,\n"
        "give a list of identifiers of the image that best represent the kind of information seeked by the question.\n"
        "Respond only to the question asked. Provide only number(s) of the source images relevant to the question.\n"
        "If the image is relevant to the question then output format should be a list [1, 3, 0]\n"
        "otherwise reply with [] (empty list)"
    )

    def __init__(self, vector_store_text=None, vector_store_images=None, max_tokens: int = 500):
        """Create the inference client and attach the optional vector stores.

        Args:
            vector_store_text: store searched for text context; required by ``perform_rag``.
            vector_store_images: store of image descriptions; if None, ``perform_rag``
                skips the image-selection step.
            max_tokens: completion length limit applied to every model call.
        """
        self.model: str = "HuggingFaceH4/zephyr-7b-beta"
        # The env var name (spelling included) must match the deployed secret.
        self.client = InferenceClient(api_key=os.getenv("HF_SERVELESS_API"))
        self.vs_text = vector_store_text
        self.vs_images = vector_store_images
        self.max_tokens: int = max_tokens

    def _chat(self, messages: list) -> str:
        """Send ``messages`` to ``self.model`` and return the first choice's content."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
        )
        return completion.choices[0].message.content

    @staticmethod
    def _format_context(header: str, contents: list) -> str:
        """Return ``header`` followed by each text labelled ``Document i:::``."""
        # Each entry ends with a newline so one document's text does not run
        # straight into the next document's label.
        return header + "".join(
            f"Document {i}:::\n{content}\n" for i, content in enumerate(contents)
        )

    def _rag_messages(self, system_prompt: str, context: str, query: str) -> list:
        """Build the system+user chat messages for one RAG step."""
        return [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": self._USER_PROMPT_TEMPLATE.format(context=context, question=query),
            },
        ]

    def test(self, query: str) -> str:
        """Send ``query`` straight to the model (no retrieval) and return its reply."""
        return self._chat([{"role": "user", "content": query}])

    def perform_rag(self, query: str, k: int = 5) -> str:
        """Answer ``query`` from retrieved text, then append the model's image picks.

        Args:
            query: the user question.
            k: number of results retrieved from each vector store.

        Returns:
            The model's text answer. If an image store is configured, the answer
            is immediately followed by the model's raw reply listing relevant
            image indices (e.g. ``[1, 3]``). Otherwise the answer is returned alone.

        Raises:
            ValueError: if no text vector store was provided.
        """
        if self.vs_text is None:
            raise ValueError("perform_rag requires a text vector store (vector_store_text)")

        # Text retrieval + augmented generation.
        text_docs = self.vs_text.similarity_search(query=query, k=k)
        text_context = self._format_context(
            "\nExtracted documents:\n", [doc.page_content for doc in text_docs]
        )
        response_text = self._chat(
            self._rag_messages(self._TEXT_SYSTEM_PROMPT, text_context, query)
        )

        if self.vs_images is None:
            return response_text

        # Image retrieval: the model picks relevant entries by their index.
        # NOTE(review): these indices refer to positions in `image_docs`, which
        # the caller never receives -- confirm how callers resolve them.
        image_docs = self.vs_images.similarity_search(query=query, k=k)
        image_context = self._format_context(
            "\nExtracted Images:\n", [doc.page_content for doc in image_docs]
        )
        images_list = self._chat(
            self._rag_messages(self._IMAGE_SYSTEM_PROMPT, image_context, query)
        )
        return response_text + str(images_list)