Spaces:
Sleeping
Sleeping
Update chatbot/core.py
Browse files — chatbot/core.py: +16 −15
chatbot/core.py
CHANGED
|
@@ -3,7 +3,9 @@ from chatbot.memory import memory
|
|
| 3 |
from chatbot.prompts import chat_prompt
|
| 4 |
from langchain.retrievers import WikipediaRetriever
|
| 5 |
from langchain.chains import ConversationalRetrievalChain
|
| 6 |
-
from
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def translate_to_english(text: str) -> str:
|
| 9 |
"""Use Gemini LLM to translate text to English."""
|
|
@@ -12,26 +14,25 @@ def translate_to_english(text: str) -> str:
|
|
| 12 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
| 13 |
|
| 14 |
class WikipediaTranslationRetriever(BaseRetriever):
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
self.retriever = retriever
|
| 18 |
-
self.translator = translator
|
| 19 |
|
| 20 |
-
def get_relevant_documents(self, query):
|
| 21 |
-
translated_query = self.translator(query)
|
| 22 |
print(f"🔄 Translated Query: {translated_query}")
|
| 23 |
return self.retriever.get_relevant_documents(translated_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
async def aget_relevant_documents(self, query):
|
| 26 |
-
# If your environment doesn't need async support, you can simply raise an error.
|
| 27 |
-
raise NotImplementedError("Asynchronous retrieval is not implemented.")
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
# ✅ Use WikipediaRetriever
|
| 31 |
-
wiki_retriever = WikipediaRetriever()
|
| 32 |
-
|
| 33 |
-
# ✅ Wrap with translation
|
| 34 |
-
retriever = WikipediaTranslationRetriever(wiki_retriever, translate_to_english)
|
| 35 |
|
| 36 |
# ✅ Use ConversationalRetrievalChain
|
| 37 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
|
|
|
| 3 |
from chatbot.prompts import chat_prompt
|
| 4 |
from langchain.retrievers import WikipediaRetriever
|
| 5 |
from langchain.chains import ConversationalRetrievalChain
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
from typing import List, Callable
|
| 8 |
+
from langchain.schema import BaseRetriever, Document
|
| 9 |
|
| 10 |
def translate_to_english(text: str) -> str:
|
| 11 |
"""Use Gemini LLM to translate text to English."""
|
|
|
|
| 14 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
| 15 |
|
| 16 |
class WikipediaTranslationRetriever(BaseRetriever):
    """Wrap a Wikipedia retriever so every query is translated to English first.

    The wrapped ``retriever`` performs the actual Wikipedia lookup; the
    ``translator`` callable maps the user's query (any language) to English
    before the lookup happens.
    """

    # Underlying retriever that executes the Wikipedia search.
    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    # Pre-processing hook: str -> str translation of the incoming query.
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Translate ``query`` to English, then delegate to the wrapped retriever."""
        translated_query = self.translator(query)
        # Surface the translated query for debugging the translation step.
        print(f"🔄 Translated Query: {translated_query}")
        return self.retriever.get_relevant_documents(translated_query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is deliberately unsupported.

        Raises:
            NotImplementedError: always — only the synchronous path is provided.
        """
        # For simplicity, we are not implementing the async version.
        raise NotImplementedError("Async retrieval is not implemented.")
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
# Module-level retriever used by the qa_chain below: a plain WikipediaRetriever
# wrapped so that incoming queries are translated to English first.
retriever = WikipediaTranslationRetriever(
    retriever=WikipediaRetriever(),
    translator=translate_to_english,
)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# ✅ Use ConversationalRetrievalChain
|
| 38 |
qa_chain = ConversationalRetrievalChain.from_llm(
|