Spaces:
Sleeping
Sleeping
| from chatbot.llm import gemini_llm # Import Gemini LLM | |
| from chatbot.memory import memory | |
| from chatbot.prompts import chat_prompt | |
| from langchain.retrievers import WikipediaRetriever | |
| from langchain.chains import ConversationalRetrievalChain | |
| from pydantic import Field | |
| from typing import List, Callable | |
| from langchain.schema import BaseRetriever, Document | |
def translate_to_english(text: str, *, llm=None) -> str:
    """Reduce a query in any language to a single English Wikipedia search keyword.

    The LLM is asked to extract the most relevant keyword from ``text`` (e.g. a
    person's name, a city, or a key term) in its English/international form.

    Args:
        text: The user's query, in any language.
        llm: Object exposing ``invoke(prompt)``. Defaults to the module-level
            ``gemini_llm``; mainly exposed so the function can be driven by a stub.

    Returns:
        The extracted keyword with surrounding whitespace removed. If ``text`` is
        empty/blank, it is returned unchanged without calling the LLM. If the model
        replies with nothing usable, the original ``text`` is returned so the
        Wikipedia search still has a query.
    """
    # Nothing to extract from a blank query; avoid a pointless LLM round-trip.
    if not text or not text.strip():
        return text
    if llm is None:
        llm = gemini_llm
    prompt = (
        "You are an assistant for Wikipedia searches.\n"
        "The query may be in any language.\n"
        "Extract and return only the most relevant keyword "
        "(e.g. a person's name, city, or key term) in English/international form.\n"
        "Return only the keyword—no explanations.\n"
        "Query:\n"
        f"{text}\n"
    )
    response = llm.invoke(prompt)
    # LangChain chat models return a message object whose text lives in
    # ``.content``; plain-text LLM wrappers return a ``str`` directly.
    content = getattr(response, "content", response)
    if isinstance(content, list):
        # Message content may be a list of parts: plain strings or {"text": ...} dicts.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                pieces.append(str(part.get("text", "")))
        content = "".join(pieces)
    keyword = str(content).strip()
    # An empty reply would produce an empty search; fall back to the raw query.
    return keyword or text
class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that rewrites each query with ``translator`` before searching Wikipedia.

    The translated query is handed to the wrapped ``retriever`` and its documents
    are returned unchanged.
    """

    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Translate ``query`` and return the wrapped retriever's documents for it.

        This is the hook LangChain's ``BaseRetriever`` calls from its public
        ``get_relevant_documents``/``invoke`` entry points. No async override is
        defined, so the base class's default async path, which runs this sync
        hook in an executor, is used instead of failing.

        Args:
            query: Raw user query, in any language.
            run_manager: Callback manager supplied by ``BaseRetriever``; unused here.

        Returns:
            Documents found by the wrapped retriever for the translated query.
        """
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        return self.retriever.get_relevant_documents(translated_query)
# Module-level wiring: wrap the stock WikipediaRetriever so that every query is
# passed through translate_to_english before the Wikipedia lookup.
retriever = WikipediaTranslationRetriever(
    translator=translate_to_english,
    retriever=WikipediaRetriever(),
)

# Conversational retrieval chain over the translating retriever.
# - Conversation state lives in the shared `memory` object.
# - Answers are generated with `chat_prompt`.
# - Source documents are not returned.
# - The answer is exposed under the "result" output key.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    output_key="result",
    return_source_documents=False,
    combine_docs_chain_kwargs=dict(prompt=chat_prompt),
)
def get_chat_response(user_input: str) -> str:
    """Answer ``user_input`` using the Wikipedia-backed conversational chain.

    ``qa_chain`` is constructed with ``memory=memory``, so the chain itself saves
    each question/answer pair to ``memory`` after the call. The turn is therefore
    NOT saved again here; doing so would record every exchange twice in the chat
    history fed back into later questions.

    Args:
        user_input: The user's message.

    Returns:
        The chain's answer (its ``"result"`` output).
    """
    # Supply the question under its explicit input key. The chat-history input is
    # expected to be loaded from the chain's memory (assumes the memory's key
    # matches the chain's "chat_history" input — confirm in chatbot.memory).
    response = qa_chain({"question": user_input})
    return response["result"]