Spaces:
Sleeping
Sleeping
Update chatbot/core.py
Browse files- chatbot/core.py +30 -9
chatbot/core.py
CHANGED
|
@@ -1,15 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from langchain.retrievers import WikipediaRetriever
|
| 2 |
-
from chatbot.llm import gemini_llm
|
| 3 |
-
from chatbot.memory import memory
|
| 4 |
-
from chatbot.prompts import chat_prompt
|
| 5 |
from langchain.chains import ConversationalRetrievalChain
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 11 |
llm=gemini_llm,
|
| 12 |
-
retriever=retriever,
|
| 13 |
memory=memory,
|
| 14 |
return_source_documents=False,
|
| 15 |
combine_docs_chain_kwargs={"prompt": chat_prompt},
|
|
@@ -17,10 +38,10 @@ qa_chain = ConversationalRetrievalChain.from_llm(
|
|
| 17 |
)
|
| 18 |
|
| 19 |
def get_chat_response(user_input: str) -> str:
|
| 20 |
-
response
|
|
|
|
| 21 |
|
| 22 |
-
#
|
| 23 |
memory.save_context({"input": user_input}, {"output": response["result"]})
|
| 24 |
|
| 25 |
return response["result"]
|
| 26 |
-
|
|
|
|
| 1 |
+
from chatbot.llm import gemini_llm # Import Gemini LLM
|
| 2 |
+
from chatbot.memory import memory
|
| 3 |
+
from chatbot.prompts import chat_prompt
|
| 4 |
from langchain.retrievers import WikipediaRetriever
|
|
|
|
|
|
|
|
|
|
| 5 |
from langchain.chains import ConversationalRetrievalChain
|
| 6 |
|
| 7 |
+
def translate_to_english(text: str) -> str:
|
| 8 |
+
"""Use Gemini LLM to translate text to English."""
|
| 9 |
+
prompt = f"Translate the following text to English:\n\n{text}"
|
| 10 |
+
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
| 11 |
+
return response # Assuming `gemini_llm.invoke()` returns plain text
|
| 12 |
|
| 13 |
+
class WikipediaTranslationRetriever:
|
| 14 |
+
"""Custom Retriever that translates queries before searching Wikipedia."""
|
| 15 |
+
def __init__(self, retriever, translator):
|
| 16 |
+
self.retriever = retriever
|
| 17 |
+
self.translator = translator
|
| 18 |
+
|
| 19 |
+
def get_relevant_documents(self, query):
|
| 20 |
+
translated_query = self.translator(query) # Translate query to English
|
| 21 |
+
print(f"🔄 Translated Query: {translated_query}")
|
| 22 |
+
return self.retriever.get_relevant_documents(translated_query)
|
| 23 |
+
|
| 24 |
+
# ✅ Use WikipediaRetriever
|
| 25 |
+
wiki_retriever = WikipediaRetriever()
|
| 26 |
+
|
| 27 |
+
# ✅ Wrap with translation
|
| 28 |
+
retriever = WikipediaTranslationRetriever(wiki_retriever, translate_to_english)
|
| 29 |
+
|
| 30 |
+
# ✅ Use ConversationalRetrievalChain
|
| 31 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 32 |
llm=gemini_llm,
|
| 33 |
+
retriever=retriever,
|
| 34 |
memory=memory,
|
| 35 |
return_source_documents=False,
|
| 36 |
combine_docs_chain_kwargs={"prompt": chat_prompt},
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
def get_chat_response(user_input: str) -> str:
|
| 41 |
+
"""Process user input and return chat response using Wikipedia retrieval."""
|
| 42 |
+
response = qa_chain(user_input) # Pass query to retrieval-based QA chain
|
| 43 |
|
| 44 |
+
# Save conversation context
|
| 45 |
memory.save_context({"input": user_input}, {"output": response["result"]})
|
| 46 |
|
| 47 |
return response["result"]
|
|
|