Spaces:
Running
Running
import os | |
import zipfile | |
from typing import Dict, List, Optional, Union | |
import gradio as gr | |
from groq import Groq | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_groq import ChatGroq | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_core.vectorstores import InMemoryVectorStore | |
# The Groq key is read from the environment once and shared by both clients.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Groq SDK client
client = Groq(api_key=GROQ_API_KEY)

# LangChain chat model; it is the generation step of the RAG chain
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="meta-llama/llama-4-scout-17b-16e-instruct",
)

# Embedding model handed to the vector store when documents are indexed
embed_model = HuggingFaceEmbeddings(
    model_name="mixedbread-ai/mxbai-embed-large-v1",
)

# Static UI pieces: the page heading and the avatar pair given to the chatbot
TITLE = """<h1 align="center">β¨ Llama 4 RAG Application</h1>"""
AVATAR_IMAGES = (
    None,
    "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png",
)
# File extensions (lower-case, with leading dot) whose content is extracted
# and indexed. Kept in alphabetical order. This stays a list because the
# upload button concatenates it with another list.
TEXT_EXTENSIONS = [
    ".bat", ".c", ".cfg", ".conf", ".cpp", ".cs", ".css", ".docx",
    ".go", ".h", ".html", ".ini", ".java", ".js", ".json", ".jsx",
    ".md", ".php", ".ps1", ".py", ".rb", ".rs", ".sh", ".toml",
    ".ts", ".tsx", ".txt", ".xml", ".yaml", ".yml",
]
# Module-level state shared by the upload, query and reset handlers:
#   EXTRACTED_FILES maps an uploaded file name to {inner path: text content},
#   VECTORSTORE / RAG_CHAIN stay None until documents have been indexed.
EXTRACTED_FILES = {}
VECTORSTORE = None
RAG_CHAIN = None

# Splitter used to cut documents into chunks before embedding.
# Paragraph and line breaks are tried first. The trailing " " and ""
# separators are only reached when a piece is still longer than chunk_size
# (e.g. minified JS/JSON or one very long line); without them such a piece
# would be emitted as a single oversized chunk.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n", " ", ""],
)

# Prompt for the RAG chain; {context} and {question} are filled in by the chain.
template = """You are an expert assistant tasked with answering questions based on the provided documents.
Use only the given context to generate your answer.
If the answer cannot be found in the context, clearly state that you do not know.
Be detailed and precise in your response, but avoid mentioning or referencing the context itself.
Context:
{context}
Question:
{question}
Answer:"""

# Compile the template into a PromptTemplate
rag_prompt = PromptTemplate.from_template(template)
def _docx_bytes_to_text(data: bytes) -> str:
    """Return the paragraph text of a .docx document given its raw bytes."""
    import io
    import xml.etree.ElementTree as ElementTree

    # A .docx file is itself a ZIP archive; the body text lives in
    # word/document.xml as <w:p> paragraphs made of <w:t> text runs.
    # NOTE(security): the XML comes from an upload and ElementTree is not
    # hardened against maliciously constructed documents.
    word_ns = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
    with zipfile.ZipFile(io.BytesIO(data)) as docx:
        body = ElementTree.fromstring(docx.read("word/document.xml"))
    return "\n".join(
        "".join(run.text or "" for run in paragraph.iter(f"{word_ns}t"))
        for paragraph in body.iter(f"{word_ns}p")
    )


def extract_text_from_zip(zip_file_path: str) -> Dict[str, str]:
    """
    Extract text content from files in a ZIP archive.

    Only members whose extension is listed in TEXT_EXTENSIONS are read.
    ``.docx`` members are parsed as Word documents; every other member is
    decoded as UTF-8 with undecodable bytes replaced.

    Parameters:
        zip_file_path (str): Path to the ZIP file.
    Returns:
        Dict[str, str]: Dictionary mapping member names to their text content.
            A member that cannot be read maps to an "Error extracting file: ..."
            message. An archive that is not a valid ZIP yields an empty dict.
    """
    text_contents: Dict[str, str] = {}
    try:
        archive = zipfile.ZipFile(zip_file_path, "r")
    except zipfile.BadZipFile:
        # Corrupt or mislabelled upload: report "nothing extracted" instead
        # of raising into the UI handler.
        return text_contents

    with archive as zip_ref:
        for file_info in zip_ref.infolist():
            # Skip directories
            if file_info.is_dir():
                continue
            # Skip binary files and focus on text files
            file_ext = os.path.splitext(file_info.filename)[1].lower()
            if file_ext not in TEXT_EXTENSIONS:
                continue
            try:
                raw = zip_ref.read(file_info)
                if file_ext == ".docx":
                    content = _docx_bytes_to_text(raw)
                else:
                    content = raw.decode("utf-8", errors="replace")
            except Exception as e:
                # Best effort: keep going and record why this member failed.
                content = f"Error extracting file: {str(e)}"
            text_contents[file_info.filename] = content
    return text_contents
def _docx_file_to_text(docx_path: str) -> str:
    """Return the paragraph text of the .docx document stored at *docx_path*."""
    import xml.etree.ElementTree as ElementTree

    # A .docx file is a ZIP archive; the body text lives in word/document.xml
    # as <w:p> paragraphs made of <w:t> text runs.
    # NOTE(security): the XML comes from an upload and ElementTree is not
    # hardened against maliciously constructed documents.
    word_ns = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
    with zipfile.ZipFile(docx_path) as docx:
        body = ElementTree.fromstring(docx.read("word/document.xml"))
    return "\n".join(
        "".join(run.text or "" for run in paragraph.iter(f"{word_ns}t"))
        for paragraph in body.iter(f"{word_ns}p")
    )


def extract_text_from_single_file(file_path: str) -> Dict[str, str]:
    """
    Extract text content from a single file.

    Files whose extension is not in TEXT_EXTENSIONS are ignored. ``.docx``
    files are parsed as Word documents instead of being decoded as UTF-8
    (which would index binary garbage); every other supported file is read
    as UTF-8 with undecodable bytes replaced.

    Parameters:
        file_path (str): Path to the file.
    Returns:
        Dict[str, str]: ``{basename: content}`` for a supported file, where
            content is an "Error reading file: ..." message if reading
            failed; an empty dict for an unsupported extension.
    """
    filename = os.path.basename(file_path)
    file_ext = os.path.splitext(filename)[1].lower()
    if file_ext not in TEXT_EXTENSIONS:
        return {}

    try:
        if file_ext == ".docx":
            content = _docx_file_to_text(file_path)
        else:
            with open(file_path, "r", encoding="utf-8", errors="replace") as file:
                content = file.read()
    except Exception as e:
        # Best effort: surface the failure as the file's content.
        content = f"Error reading file: {str(e)}"
    return {filename: content}
def _extract_uploaded_file(path: str):
    """Extract text from one uploaded path and return ``(is_zip, extracted)``."""
    extension = os.path.splitext(os.path.basename(path))[1].lower()
    if extension == ".zip":
        return True, extract_text_from_zip(path)
    return False, extract_text_from_single_file(path)


def _rebuild_rag_index() -> None:
    """Re-index everything in EXTRACTED_FILES into VECTORSTORE and RAG_CHAIN."""
    global VECTORSTORE, RAG_CHAIN
    from langchain_core.documents import Document

    # One Document per extracted file, tagged with its path as the source.
    documents = [
        Document(page_content=content, metadata={"source": inner_path})
        for extracted in EXTRACTED_FILES.values()
        for inner_path, content in extracted.items()
    ]
    # Split the documents into chunks
    chunks = text_splitter.split_documents(documents)
    # Create the vector store
    VECTORSTORE = InMemoryVectorStore.from_documents(
        documents=chunks,
        embedding=embed_model,
    )
    # Create the retriever and the RAG chain built on top of it
    retriever = VECTORSTORE.as_retriever()
    RAG_CHAIN = (
        {"context": retriever, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )


def upload_files(
    files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]
):
    """
    Process uploaded files (ZIP or single text files): extract text content and append a message to the chat.

    Extracted content is stored in EXTRACTED_FILES under the uploaded file's
    base name. When this upload contributed any content, the vector store
    and RAG chain are rebuilt over everything uploaded so far and a
    confirmation message is appended.

    Parameters:
        files (Optional[List[str]]): List of file paths. ``None`` or an empty
            list leaves the chat and the index untouched.
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.
    Returns:
        List[Union[dict, gr.ChatMessage]]: Updated conversation history.
    """
    # The signature allows None; nothing to do without at least one path.
    if not files:
        return chatbot

    # File names come from the upload (and from inside archives), so they are
    # escaped before being embedded in the HTML chat messages.
    import html

    added_content = False

    # Handle multiple file uploads
    if len(files) > 1:
        total_files_processed = 0
        total_files_extracted = 0
        file_types = set()
        for path in files:
            is_zip, extracted_files = _extract_uploaded_file(path)
            file_types.add("zip" if is_zip else "text")
            if extracted_files:
                # Store the extracted content in the global variable
                EXTRACTED_FILES[os.path.basename(path)] = extracted_files
                total_files_extracted += len(extracted_files)
                total_files_processed += 1
                added_content = True
        # Create a summary message for multiple files
        if len(file_types) > 1:
            file_types_str = "files"
        elif "zip" in file_types:
            file_types_str = "ZIP files"
        else:
            file_types_str = "text files"
        # Create a list of uploaded file names
        file_list = "\n".join(
            f"- {html.escape(os.path.basename(path))}" for path in files
        )
        chatbot.append(
            gr.ChatMessage(
                role="user",
                content=f"<p>π Multiple {file_types_str} uploaded ({total_files_processed} files)</p><p>Extracted {total_files_extracted} text file(s) in total</p><p>Uploaded files:</p><pre>{file_list}</pre>",
            )
        )
    # Handle single file upload
    else:
        path = files[0]
        filename = os.path.basename(path)
        shown_name = html.escape(filename)
        is_zip, extracted_files = _extract_uploaded_file(path)
        file_type_msg = "π¦ ZIP file" if is_zip else "π File"
        if not extracted_files:
            chatbot.append(
                gr.ChatMessage(
                    role="user",
                    content=f"<p>{file_type_msg} uploaded: {shown_name}, but no text content was found or the file format is not supported.</p>",
                )
            )
        else:
            file_list = "\n".join(f"- {html.escape(name)}" for name in extracted_files)
            chatbot.append(
                gr.ChatMessage(
                    role="user",
                    content=f"<p>{file_type_msg} uploaded: {shown_name}</p><p>Extracted {len(extracted_files)} text file(s):</p><pre>{file_list}</pre>",
                )
            )
            # Store the extracted content in the global variable
            EXTRACTED_FILES[filename] = extracted_files
            added_content = True

    # Re-index only when this upload actually contributed content; otherwise
    # the "nothing found" message would be followed by a misleading
    # "Documents processed" confirmation and a pointless re-embedding.
    if added_content:
        _rebuild_rag_index()
        # Add a confirmation message
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Documents processed and indexed. You can now ask questions about the content.",
            )
        )
    return chatbot
def user(text_prompt: str, chatbot: List[gr.ChatMessage]):
    """
    Record the user's typed text as a new chat message.

    Parameters:
        text_prompt (str): Text from the prompt box; an empty value adds nothing.
        chatbot (List[gr.ChatMessage]): Conversation history, extended in place.
    Returns:
        Tuple[str, List[gr.ChatMessage]]: An empty string (which clears the
        prompt box) and the conversation history.
    """
    if not text_prompt:
        return "", chatbot
    new_message = gr.ChatMessage(role="user", content=text_prompt)
    chatbot.append(new_message)
    return "", chatbot
def get_message_content(msg):
    """
    Return the content of a chat message in either supported representation.

    Parameters:
        msg (Union[dict, gr.ChatMessage]): The message object.
    Returns:
        str: For a dict, the value under "content" ("" when the key is
        missing); for any other object, its ``content`` attribute.
    """
    return msg.get("content", "") if isinstance(msg, dict) else msg.content
def _is_upload_notice(text) -> bool:
    """Return True when *text* is a status message posted by upload_files()."""
    if not isinstance(text, str):
        return False
    # Upload notices are HTML whose first paragraph announces the upload.
    # Match on the ASCII wording only, so detection does not depend on the
    # emoji prefix and covers every "Multiple ... uploaded (N files)" variant
    # ("files", "ZIP files" and "text files").
    headline = text.split("</p>", 1)[0].lstrip()
    if not headline.startswith("<p>"):
        return False
    return (
        "ZIP file uploaded:" in headline
        or "File uploaded:" in headline
        or ("Multiple " in headline and " uploaded (" in headline)
    )


def _message_role(msg):
    """Return the role of a chat message given as a dict or an object."""
    if isinstance(msg, dict):
        return msg.get("role")
    return getattr(msg, "role", None)


def process_query(chatbot: List[Union[dict, gr.ChatMessage]]):
    """
    Process the user's query using the RAG pipeline.

    The most recent user message is taken as the question. If that message
    is an upload notice written by upload_files(), nothing is answered.
    Otherwise a placeholder assistant message is appended and then replaced
    by the chain's answer, or by an error description if the chain raised.

    Parameters:
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.
    Returns:
        List[Union[dict, gr.ChatMessage]]: The updated conversation history with the response.
    """
    if chatbot is None:
        chatbot = []
    if not chatbot:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please enter a question or upload documents to start the conversation.",
            )
        )
        return chatbot

    # Get the last user message as the prompt
    last_user_msg = next(
        (msg for msg in reversed(chatbot) if _message_role(msg) == "user"), None
    )
    if last_user_msg is None:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please enter a question to start the conversation.",
            )
        )
        return chatbot
    prompt = get_message_content(last_user_msg)

    # Skip if the last message was about uploading a file
    if _is_upload_notice(prompt):
        return chatbot

    # Check if RAG chain is available
    if RAG_CHAIN is None:
        chatbot.append(
            gr.ChatMessage(
                role="assistant",
                content="Please upload documents first to enable question answering.",
            )
        )
        return chatbot

    # Append a placeholder for the assistant's response
    chatbot.append(gr.ChatMessage(role="assistant", content="Thinking..."))
    try:
        # Process the query through the RAG chain
        response = RAG_CHAIN.invoke(prompt)
    except Exception as e:
        # Show the failure in the chat instead of raising into the UI
        chatbot[-1].content = f"Error processing your query: {str(e)}"
    else:
        # Update the placeholder with the actual response
        chatbot[-1].content = response
    return chatbot
def reset_app(chatbot):
    """
    Drop all uploaded content and start the conversation over.

    Parameters:
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history
            (not used; the history is replaced rather than edited).
    Returns:
        List[Union[dict, gr.ChatMessage]]: A new history holding a single
        assistant message announcing the reset.
    """
    global EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN

    # Forget extracted files, the vector store and the chain in one step
    EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN = {}, None, None

    reset_notice = gr.ChatMessage(
        role="assistant",
        content="App has been reset. You can start a new conversation or upload new documents.",
    )
    return [reset_notice]
# UI components are created here and rendered later inside the Blocks layout.

# Conversation view
chatbot_component = gr.Chatbot(
    type="messages",
    label="Llama 4 RAG",
    avatar_images=AVATAR_IMAGES,
    bubble_full_width=False,
    height=350,
    scale=2,
)

# Question input
text_prompt_component = gr.Textbox(
    show_label=False,
    placeholder="Ask a question about your documents...",
    autofocus=True,
    scale=28,
)

# Upload button: accepts archives, Word documents and every text extension
upload_files_button_component = gr.UploadButton(
    label="Upload",
    file_types=[".zip", ".docx", *TEXT_EXTENSIONS],
    file_count="multiple",
    min_width=80,
    scale=1,
)

# Action buttons
send_button_component = gr.Button(
    value="Send",
    variant="primary",
    min_width=80,
    scale=1,
)
reset_button_component = gr.Button(
    value="Reset",
    variant="stop",
    min_width=80,
    scale=1,
)

# Inputs shared by the Send click and the Enter-key submit handlers
user_inputs = [text_prompt_component, chatbot_component]
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Layout: heading, then the chat with a single row of controls below it
    gr.HTML(TITLE)
    with gr.Column():
        chatbot_component.render()
        with gr.Row(equal_height=True):
            text_prompt_component.render()
            send_button_component.render()
            upload_files_button_component.render()
            reset_button_component.render()

    def _wire_question_flow(trigger, api_name):
        """Attach "store the user text, then answer it" to an event trigger."""
        stored = trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        )
        stored.then(
            fn=process_query,
            inputs=[chatbot_component],
            outputs=[chatbot_component],
            api_name=api_name,
        )

    # The Send button and the Enter key run the same two-step flow
    _wire_question_flow(send_button_component.click, "process_query")
    _wire_question_flow(text_prompt_component.submit, "process_query_submit")

    # Uploads extract and index the selected files
    upload_files_button_component.upload(
        fn=upload_files,
        inputs=[upload_files_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

    # Reset clears the stored documents and the conversation
    reset_button_component.click(
        fn=reset_app,
        inputs=[chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

# Enable queuing and start the interface
demo.queue().launch()