Spaces:
Runtime error
Runtime error
| # from retriever.vectordb import search_documents | |
| from retriever.vectordb_rerank import search_documents | |
| from retriever.vectordb_rerank_law import search_documents as search_law | |
| from retriever.vectordb_rerank_exam import search_documents as search_exam | |
| from generator.prompt_builder import build_prompt | |
| from generator.prompt_builder_all import build_prompt as build_prompt_all | |
| from generator.llm_inference import generate_answer as generate_answer | |
| from generator.llm_inference_all import generate_answer as generate_answer_all | |
# 2. Cache management: module-level, unbounded, in-process store that
# rag_pipeline consults before doing any work and fills with its answers.
search_cache: dict = dict()
def rag_pipeline(query: str, top_k: int = 5) -> str:
    """Answer ``query`` with retrieval-augmented generation.

    Steps:
      1. Retrieve documents related to the user's question from the exam
         index (``search_exam``) and the law index (``search_law``).
      2. Build a prompt from the question and both retrieved document sets
         (``build_prompt_all``).
      3. Generate the answer from that prompt (``generate_answer``).

    Retrieved documents and the assembled prompt are printed to stdout for
    debugging. Answers are memoized in the module-level ``search_cache``,
    keyed on ``(query, top_k)``, so a repeated call with the same arguments
    skips retrieval and generation.

    Args:
        query: The user's question.
        top_k: Number of documents to retrieve from each index.

    Returns:
        The value produced by ``generate_answer`` for the built prompt.
    """
    # Include top_k in the key: the answer depends on how many documents were
    # retrieved, so an answer cached for a different top_k must not be reused.
    cache_key = (query, top_k)
    if cache_key in search_cache:
        # NOTE(review): this log text appears mis-encoded in this copy
        # (presumably "⚡ cache hit" in Korean); kept unchanged on purpose.
        print(f"โก ์บ์ ์ฌ์ฉ: '{query}'")
        return search_cache[cache_key]

    # 1. Retrieval from both indexes.
    context_exam_docs = search_exam(query, top_k=top_k)
    print("context_exam_docs: ", context_exam_docs)
    print("==============================================\n\n")
    context_law_docs = search_law(query, top_k=top_k)
    print("context_law_docs: ", context_law_docs)
    print("==============================================\n\n")

    # 2. Prompt assembly (law docs passed before exam docs, as originally).
    prompt = build_prompt_all(query, context_law_docs, context_exam_docs)
    print("prompt: ", prompt)
    print("==============================================\n\n")

    # 3. Model inference.
    # NOTE(review): the combined prompt builder is paired with generate_answer,
    # not generate_answer_all (which is also imported) — confirm intended.
    output = generate_answer(prompt)

    search_cache[cache_key] = output
    return output
# Example query: run this module directly for a quick end-to-end smoke test.
if __name__ == "__main__":
    # NOTE(review): the query literal appears mis-encoded in this copy (UTF-8
    # Korean rendered as Thai code points); presumably it asks what steps are
    # required when a broker moves their office — confirm against the source.
    example_query = "์ค๊ฐ์ ์๊ฐ ์ฌ๋ฌด์๋ฅผ ์ฎ๊ฒผ์ ๋ ํ์ํ ์กฐ์น"
    example_top_k = 5
    print(rag_pipeline(example_query, example_top_k))