import math
import os
import re
from pathlib import Path
from statistics import median

import pandas as pd
import streamlit as st
from bs4 import BeautifulSoup
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel
st.set_page_config(layout="wide")

# Supply the OpenAI key through the OPENAI_API_KEY environment variable or
# Streamlit secrets before launching the app; never hard-code credentials in source.

LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")

deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
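

# Build a ColBERT index over the chunk texts with RAGatouille, wrap it as a
# LangChain retriever (top-5 results), and layer a MultiQueryRetriever on top
# so each question is expanded into several query variants before retrieval.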
def embeddings_on_local_vectordb(texts):
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever
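

# Answer a query with a ConversationalRetrievalChain over the retriever,
# tracking token usage and cost via the OpenAI callback, and append the
# exchange to the chat history kept in Streamlit session state.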
def query_llm(retriever, query):
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(query)
    with get_openai_callback() as cb:
        result = qa_chain(
            {"question": query, "chat_history": st.session_state.messages}
        )
        stats = cb
    result = result["answer"]
    st.session_state.messages.append((query, result))
    return relevant_docs, result, stats
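

# Sidebar input: collect the source document URLs, one per line.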
def input_fields():
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_area(
            "Source Document URLs\n(New line separated)", height=50
        ).split("\n")
        if url.strip()
    ]
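

# Route each URL to the PDF or web loader, index the resulting snippets, and
# expose the snippet headers in the sidebar as a reference list.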
def process_documents():
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = pd.Series(
            [snip.metadata["header"] for snip in snippets], name="references"
        )
    except Exception as e:
        st.error(f"An error occurred: {e}")
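

# Load a PDF as HTML with PDFMiner, group the text by font size, drop noisy
# fragments, split the result into header/content sections around the median
# font size, and wrap each section in a LangChain Document.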
def process_pdf(url):
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
            metadata={
                "header": " ".join(snip[1]["header_text"].split()[:10]),
                "source_url": url,
                "source_type": "pdf",
                "chunk_id": i,
            },
        )
        for i, snip in enumerate(semantic_snippets)
    ]
    return document_snippets
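

# Walk the PDFMiner <div> elements and merge consecutive runs of text that
# share the same font size into (text, font_size) snippets.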
def get_pdf_snippets(content):
    current_font_size = None
    current_text = ""
    snippets = []
    for div in content:
        span = div.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])
        if not current_font_size:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += div.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = div.text
    if current_text:
        snippets.append((current_text, current_font_size))
    return snippets
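

# Discard snippets dominated by newlines relative to their length (likely
# layout noise), keeping only those below the newline-to-character threshold.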
def filter_pdf_snippets(content_list, new_line_threshold_ratio):
    filtered_list = []
    for content, font_size in content_list:
        total_chars = len(content)
        if total_chars == 0:
            continue
        newline_count = content.count("\n")
        ratio = newline_count / total_chars
        if ratio <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list
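

# Treat snippets whose font size exceeds the document's median as headers and
# accumulate the snippets that follow as that header's content, emitting
# (content, metadata) pairs.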
def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []
    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []
            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content
    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets
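

# Load a web page with WebBaseLoader and wrap the full page text in a single
# LangChain Document, using the page title as its header.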
def process_web(url):
    data = WebBaseLoader(url).load()[0]
    document_snippets = [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                "header": data.metadata.get("title", url),
                "source_url": url,
                "source_type": "web",
                # chunk_id is set here as well so the reference rendering in
                # boot() does not fail on web-sourced documents.
                "chunk_id": 0,
            },
        )
    ]
    return document_snippets
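

# Streamlit entry point: wire up the sidebar inputs, replay the chat history,
# answer new queries through the retriever, and keep a running token/cost
# table in the sidebar.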
def boot():
    st.title("Xi Chatbot")
    st.sidebar.title("Input Documents")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        st.sidebar.write("### References")
        st.sidebar.write(st.session_state.headers)
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(message[1])
    if query := st.chat_input():
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            return
        st.chat_message("human").write(query)
        references, response, stats = query_llm(st.session_state.retriever, query)
        sorted_references = sorted(ref.metadata["chunk_id"] for ref in references)
        references_str = " ".join(f"[{ref}]" for ref in sorted_references)
        st.chat_message("ai").write(response + "\n\n---\nReferences: " + references_str)
        st.session_state.costing.append(
            {
                "prompt tokens": stats.prompt_tokens,
                "completion tokens": stats.completion_tokens,
                "total cost": stats.total_cost,
            }
        )
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        st.sidebar.write(stats_df)


if __name__ == "__main__":
    boot()