import logging
from typing import Any, Callable, List

from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import WikipediaRetriever
from langchain.schema import BaseRetriever, Document
from pydantic import Field

from chatbot.llm import gemini_llm  # Gemini LLM used for translation and answering
from chatbot.memory import memory
from chatbot.prompts import chat_prompt

logger = logging.getLogger(__name__)

# Ask for the bare translation: the result is used verbatim as a Wikipedia
# search query, so any extra commentary from the model would pollute it.
_TRANSLATION_PROMPT = (
    "Translate the following text to English. "
    "Reply with only the translation, without any commentary.\n\n{text}"
)


def translate_to_english(text: str) -> str:
    """Translate *text* to English using the Gemini LLM.

    Args:
        text: Text in any language (here, a user's search query).

    Returns:
        The English translation as a plain string with surrounding whitespace
        stripped. Empty or whitespace-only input is returned unchanged without
        calling the LLM.
    """
    if not text or not text.strip():
        return text
    response = gemini_llm.invoke(_TRANSLATION_PROMPT.format(text=text))
    # Chat models return a message object whose text is in ``.content``, while
    # plain LLMs return a str. Normalize both so the Wikipedia search always
    # receives a plain string query.
    content = getattr(response, "content", response)
    return str(content).strip()


class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that translates the query to English before searching Wikipedia."""

    # Wrapped Wikipedia retriever that performs the actual search.
    retriever: WikipediaRetriever = Field(
        ..., description="The underlying Wikipedia retriever"
    )
    # Function mapping a query in any language to its English form.
    translator: Callable[[str], str] = Field(
        ..., description="Function to translate queries to English"
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """Translate *query* to English and return matching Wikipedia documents.

        This implements the ``BaseRetriever`` hook that LangChain's public
        retrieval methods (``invoke`` / ``get_relevant_documents``) dispatch
        to, instead of overriding the deprecated public method directly.

        No async method is overridden. The base class's default async
        implementation is used instead of the old ``NotImplementedError``
        stub. NOTE(review): recent langchain-core versions run this sync
        method in an executor; confirm for the installed version.

        Args:
            query: The user's query, in any language.
            run_manager: Callback manager supplied by ``BaseRetriever``; may be
                None when this method is called directly.

        Returns:
            Documents returned by the wrapped Wikipedia retriever for the
            translated query.
        """
        translated_query = self.translator(query)
        logger.info("Translated query: %s", translated_query)
        # Forward callbacks so the inner search is traced as a child run.
        config = (
            {"callbacks": run_manager.get_child()} if run_manager is not None else None
        )
        return self.retriever.invoke(translated_query, config=config)


# Retriever instance used by the QA chain below.
retriever = WikipediaTranslationRetriever(
    retriever=WikipediaRetriever(),
    translator=translate_to_english,
)

# Conversational QA over the translated Wikipedia results. ``memory`` is
# attached to the chain, so the chain itself loads the chat history before
# each call and saves the new exchange afterwards.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    return_source_documents=False,
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    output_key="result",
)


def get_chat_response(user_input: str) -> str:
    """Answer *user_input* using Wikipedia retrieval and conversation history.

    Args:
        user_input: The user's message, in any language.

    Returns:
        The chain's answer text (the ``"result"`` output key configured above).
    """
    # Do not call memory.save_context here. The chain was built with
    # memory=memory and already persists each exchange, so saving manually
    # stored every turn twice.
    response = qa_chain.invoke({"question": user_input})
    return response["result"]