import os
import zipfile
from typing import Dict, List, Optional, Union

import gradio as gr
from groq import Groq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Retrieve the Groq API key from the environment variables
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Initialize the raw Groq client (not used directly below; all chat calls go
# through the LangChain ChatGroq wrapper)
client = Groq(api_key=GROQ_API_KEY)

# Initialize the LLM
llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", api_key=GROQ_API_KEY)

# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
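# Note: HuggingFaceEmbeddings runs the model locally via sentence-transformers;
# the weights are downloaded from the Hugging Face Hub on first launch, so the
# initial startup can take a while.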

# General constants for the UI
TITLE = """<h1 align="center">✨ Llama 4 RAG Application</h1>"""
AVATAR_IMAGES = (
    None,
    "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png",
)
# List of supported plain-text extensions (alphabetically sorted).
# Note: .docx is not listed; it is a zip-based binary format and cannot be
# decoded as UTF-8 text.
TEXT_EXTENSIONS = [
    ".bat",
    ".c",
    ".cfg",
    ".conf",
    ".cpp",
    ".cs",
    ".css",
    ".go",
    ".h",
    ".html",
    ".ini",
    ".java",
    ".js",
    ".json",
    ".jsx",
    ".md",
    ".php",
    ".ps1",
    ".py",
    ".rb",
    ".rs",
    ".sh",
    ".toml",
    ".ts",
    ".tsx",
    ".txt",
    ".xml",
    ".yaml",
    ".yml",
]

# Global variables
EXTRACTED_FILES = {}
VECTORSTORE = None
RAG_CHAIN = None

# Initialize the text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100, separators=["\n\n", "\n"]
)
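# The splitter breaks text on blank lines first, then on single newlines,
# producing chunks of at most 1,000 characters with a 100-character overlap so
# that content spanning a boundary appears in both neighboring chunks. Because
# " " and "" are not in the separator list, a single line longer than
# chunk_size is kept intact rather than split mid-line.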

# Define the RAG prompt template
template = """You are an expert assistant tasked with answering questions based on the provided documents.
Use only the given context to generate your answer.
If the answer cannot be found in the context, clearly state that you do not know.
Be detailed and precise in your response, but avoid mentioning or referencing the context itself.

Context:
{context}

Question:
{question}

Answer:"""

# Create the PromptTemplate
rag_prompt = PromptTemplate.from_template(template)
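# PromptTemplate.from_template infers the input variables ({context} and
# {question}) from the braces in the template. For illustration, with made-up
# values:
#     rag_prompt.format(
#         context="Paris is the capital of France.",
#         question="What is the capital of France?",
#     )
# returns the fully rendered prompt string that is passed to the LLM.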

def extract_text_from_zip(zip_file_path: str) -> Dict[str, str]:
    """
    Extract text content from files in a ZIP archive.

    Parameters:
        zip_file_path (str): Path to the ZIP file.

    Returns:
        Dict[str, str]: Dictionary mapping filenames to their text content.
    """
    text_contents = {}

    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        for file_info in zip_ref.infolist():
            # Skip directories
            if file_info.filename.endswith("/"):
                continue

            # Skip binary files and focus on text files
            file_ext = os.path.splitext(file_info.filename)[1].lower()
            if file_ext in TEXT_EXTENSIONS:
                try:
                    with zip_ref.open(file_info) as file:
                        content = file.read().decode("utf-8", errors="replace")
                        text_contents[file_info.filename] = content
                except Exception as e:
                    text_contents[file_info.filename] = (
                        f"Error extracting file: {str(e)}"
                    )

    return text_contents
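# Illustrative example: for an archive containing src/app.py and README.md,
#     extract_text_from_zip("project.zip")
# would return {"src/app.py": "<file text>", "README.md": "<file text>"}.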

def extract_text_from_single_file(file_path: str) -> Dict[str, str]:
    """
    Extract text content from a single file.

    Parameters:
        file_path (str): Path to the file.

    Returns:
        Dict[str, str]: Dictionary mapping the filename to its text content.
    """
    text_contents = {}
    filename = os.path.basename(file_path)
    file_ext = os.path.splitext(filename)[1].lower()

    if file_ext in TEXT_EXTENSIONS:
        try:
            with open(file_path, "r", encoding="utf-8", errors="replace") as file:
                content = file.read()
                text_contents[filename] = content
        except Exception as e:
            text_contents[filename] = f"Error reading file: {str(e)}"

    return text_contents

def upload_files(
    files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]
):
    """
    Process uploaded files (ZIP archives or single text files): extract their
    text content and append a status message to the chat.

    Parameters:
        files (Optional[List[str]]): List of uploaded file paths.
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: Updated conversation history.
    """
    global EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN

    # Guard against an empty upload event (files is typed Optional)
    if not files:
        return chatbot

    # Handle multiple file uploads
    if len(files) > 1:
        total_files_processed = 0
        total_files_extracted = 0
        file_types = set()

        # Process each file
        for file in files:
            filename = os.path.basename(file)
            file_ext = os.path.splitext(filename)[1].lower()

            # Process based on file type
            if file_ext == ".zip":
                extracted_files = extract_text_from_zip(file)
                file_types.add("zip")
            else:
                extracted_files = extract_text_from_single_file(file)
                file_types.add("text")

            if extracted_files:
                total_files_extracted += len(extracted_files)
                # Store the extracted content in the global variable
                EXTRACTED_FILES[filename] = extracted_files

            total_files_processed += 1

        # Create a summary message for multiple files
        file_types_str = (
            "files"
            if len(file_types) > 1
            else ("ZIP files" if "zip" in file_types else "text files")
        )

        # Create a list of the uploaded file names
        file_list = "\n".join([f"- {os.path.basename(file)}" for file in files])
        chatbot.append(
            gr.ChatMessage(
                role="user",
                content=f"<p>📚 Multiple {file_types_str} uploaded ({total_files_processed} files)</p><p>Extracted {total_files_extracted} text file(s) in total</p><p>Uploaded files:</p><pre>{file_list}</pre>",
            )
        )
    # Handle a single file upload
    elif len(files) == 1:
        file = files[0]
        filename = os.path.basename(file)
        file_ext = os.path.splitext(filename)[1].lower()

        # Process based on file type
        if file_ext == ".zip":
            extracted_files = extract_text_from_zip(file)
            file_type_msg = "📦 ZIP file"
        else:
            extracted_files = extract_text_from_single_file(file)
            file_type_msg = "📄 File"

        if not extracted_files:
            chatbot.append(
                gr.ChatMessage(
                    role="user",
                    content=f"<p>{file_type_msg} uploaded: {filename}, but no text content was found or the file format is not supported.</p>",
                )
            )
        else:
            file_list = "\n".join([f"- {name}" for name in extracted_files.keys()])
            chatbot.append(
                gr.ChatMessage(
                    role="user",
                    content=f"<p>{file_type_msg} uploaded: {filename}</p><p>Extracted {len(extracted_files)} text file(s):</p><pre>{file_list}</pre>",
                )
            )

            # Store the extracted content in the global variable
            EXTRACTED_FILES[filename] = extracted_files
    # Process the extracted files and create vector embeddings
    if EXTRACTED_FILES:
        # Prepare documents for processing
        all_texts = []
        for filename, file_dict in EXTRACTED_FILES.items():
            for file_path, content in file_dict.items():
                all_texts.append(
                    {"page_content": content, "metadata": {"source": file_path}}
                )

        # Create Document objects
        documents = [
            Document(page_content=item["page_content"], metadata=item["metadata"])
            for item in all_texts
        ]

        # Split the documents into chunks
        chunks = text_splitter.split_documents(documents)

        # Create the vector store
        VECTORSTORE = InMemoryVectorStore.from_documents(
            documents=chunks,
            embedding=embed_model,
        )

        # Create the retriever
        retriever = VECTORSTORE.as_retriever()
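        # as_retriever() defaults to similarity search: for each question it
        # returns the top-k most similar chunks (LangChain's default k is 4).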

        # Create the RAG chain
        RAG_CHAIN = (
            {"context": retriever, "question": RunnablePassthrough()}
            | rag_prompt
            | llm
            | StrOutputParser()
        )
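        # How the chain flows: the input question string goes both to the
        # retriever (whose list of Documents is stringified into {context})
        # and, via RunnablePassthrough, into {question}; the rendered prompt is
        # sent to the LLM, and StrOutputParser extracts the reply text. A
        # common refinement (not used here) is to join only the page contents
        # instead of stringifying whole Document objects:
        #     def format_docs(docs):
        #         return "\n\n".join(doc.page_content for doc in docs)
        #     {"context": retriever | format_docs, "question": RunnablePassthrough()}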

        # Add a confirmation message
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Documents processed and indexed. You can now ask questions about the content.",
            )
        )

    return chatbot

def user(text_prompt: str, chatbot: List[gr.ChatMessage]):
    """
    Append a new user text message to the chat history.

    Parameters:
        text_prompt (str): The input text provided by the user.
        chatbot (List[gr.ChatMessage]): The existing conversation history.

    Returns:
        Tuple[str, List[gr.ChatMessage]]: A tuple of an empty string (clearing
        the prompt box) and the updated conversation history.
    """
    if text_prompt:
        chatbot.append(gr.ChatMessage(role="user", content=text_prompt))
    return "", chatbot

def get_message_content(msg):
    """
    Retrieve the content of a message that can be either a dictionary or a
    gr.ChatMessage.

    Parameters:
        msg (Union[dict, gr.ChatMessage]): The message object.

    Returns:
        str: The textual content of the message.
    """
    if isinstance(msg, dict):
        return msg.get("content", "")
    return msg.content

def process_query(chatbot: List[Union[dict, gr.ChatMessage]]):
    """
    Process the user's latest query through the RAG pipeline.

    Parameters:
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: The updated conversation history with the response.
    """
    global RAG_CHAIN

    if len(chatbot) == 0:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please enter a question or upload documents to start the conversation.",
            )
        )
        return chatbot

    # Get the last user message as the prompt
    user_messages = [
        msg
        for msg in chatbot
        if (isinstance(msg, dict) and msg.get("role") == "user")
        or (hasattr(msg, "role") and msg.role == "user")
    ]

    if not user_messages:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please enter a question to start the conversation.",
            )
        )
        return chatbot

    last_user_msg = user_messages[-1]
    prompt = get_message_content(last_user_msg)

    # Skip if the last message was a file-upload notification (the multi-file
    # check matches "📚 Multiple files/ZIP files/text files uploaded" alike)
    if (
        "📦 ZIP file uploaded:" in prompt
        or "📄 File uploaded:" in prompt
        or "📚 Multiple" in prompt
    ):
        return chatbot

    # Check that the RAG chain is available
    if RAG_CHAIN is None:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please upload documents first to enable question answering.",
            )
        )
        return chatbot

    # Append a placeholder for the assistant's response
    chatbot.append(gr.ChatMessage(role="assistant", content="Thinking..."))

    try:
        # Run the query through the RAG chain
        response = RAG_CHAIN.invoke(prompt)
        # Replace the placeholder with the actual response
        chatbot[-1].content = response
    except Exception as e:
        # Surface any errors in the chat
        chatbot[-1].content = f"Error processing your query: {str(e)}"

    return chatbot

def reset_app(chatbot):
    """
    Reset the app by clearing the chat context and removing any uploaded files.

    Parameters:
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: A fresh conversation history.
    """
    global EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN

    # Clear the global variables
    EXTRACTED_FILES = {}
    VECTORSTORE = None
    RAG_CHAIN = None

    # Reset the chatbot with a welcome message
    return [
        gr.ChatMessage(
            role="assistant",
            content="App has been reset. You can start a new conversation or upload new documents.",
        )
    ]

# Define the Gradio UI components
chatbot_component = gr.Chatbot(
    label="Llama 4 RAG",
    type="messages",
    bubble_full_width=False,
    avatar_images=AVATAR_IMAGES,
    scale=2,
    height=350,
)
text_prompt_component = gr.Textbox(
    placeholder="Ask a question about your documents...",
    show_label=False,
    autofocus=True,
    scale=28,
)
upload_files_button_component = gr.UploadButton(
    label="Upload",
    file_count="multiple",
    file_types=[".zip"] + TEXT_EXTENSIONS,
    scale=1,
    min_width=80,
)
send_button_component = gr.Button(
    value="Send", variant="primary", scale=1, min_width=80
)
reset_button_component = gr.Button(value="Reset", variant="stop", scale=1, min_width=80)

# Define input lists for chaining the button events
user_inputs = [text_prompt_component, chatbot_component]

with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.HTML(TITLE)
    with gr.Column():
        chatbot_component.render()
        with gr.Row(equal_height=True):
            text_prompt_component.render()
            send_button_component.render()
            upload_files_button_component.render()
            reset_button_component.render()

    # When the Send button is clicked, first append the user text, then run the query
    send_button_component.click(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    ).then(
        fn=process_query,
        inputs=[chatbot_component],
        outputs=[chatbot_component],
        api_name="process_query",
    )

    # Allow submission with the Enter key
    text_prompt_component.submit(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    ).then(
        fn=process_query,
        inputs=[chatbot_component],
        outputs=[chatbot_component],
        api_name="process_query_submit",
    )

    # Handle file uploads
    upload_files_button_component.upload(
        fn=upload_files,
        inputs=[upload_files_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

    # Handle Reset button clicks
    reset_button_component.click(
        fn=reset_app,
        inputs=[chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

# Launch the demo interface
demo.queue().launch()
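# queue() enables request queuing so concurrent users are served in order;
# launch() starts the local Gradio server (passing share=True would also
# create a temporary public link).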