"""
Shedify app - Using fine-tuned Llama 3.3 49B for document assistance
"""
import gradio as gr
import os
import json
import time
import subprocess
import threading
import uuid
from pathlib import Path

from huggingface_hub import InferenceClient, HfFolder

# Model settings
DEFAULT_MODEL = "Borislav18/Shedify"  # Your Hugging Face username/model name
LOCAL_MODEL = os.environ.get("LOCAL_MODEL", None)  # Set this if testing locally

# Get Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# App title and description
title = "Shedify - Document Assistant powered by Llama 3.3"
description = """
This app uses a fine-tuned version of Llama 3.3 49B model trained on your documents.
Ask questions about the documents, generate insights, or request summaries!
"""

# Initialize inference client with your model
client = InferenceClient(
    DEFAULT_MODEL,
    token=HF_TOKEN,
)


# Training status tracking
class TrainingState:
    """In-memory record of the fine-tuning pipeline's progress, mirrored to
    ``training_state.json`` so a restarted app can resume the last known state.
    """

    def __init__(self):
        self.status = "idle"   # one of: idle, running, success, failed
        self.progress = 0.0    # 0.0 to 1.0
        self.message = ""
        self.id = str(uuid.uuid4())[:8]  # Generate a unique ID for this session
        # Check if state file exists and load it
        self.state_file = Path("training_state.json")
        self.load_state()

    def load_state(self):
        """Load state from file if it exists."""
        if self.state_file.exists():
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                self.status = state.get("status", "idle")
                self.progress = state.get("progress", 0.0)
                self.message = state.get("message", "")
                self.id = state.get("id", self.id)
            except Exception as e:
                # Best-effort: a corrupt state file must not prevent startup.
                print(f"Error loading state: {e}")

    def save_state(self):
        """Save current state to file."""
        try:
            with open(self.state_file, "w") as f:
                json.dump({
                    "status": self.status,
                    "progress": self.progress,
                    "message": self.message,
                    "id": self.id
                }, f)
        except Exception as e:
            # Best-effort persistence; the in-memory state is still valid.
            print(f"Error saving state: {e}")

    def update(self, status=None, progress=None, message=None):
        """Update any provided fields, persist them, and return the new state."""
        if status is not None:
            self.status = status
        if progress is not None:
            self.progress = progress
        if message is not None:
            self.message = message
        self.save_state()
        return self.status, self.progress, self.message


# Initialize the training state
training_state = TrainingState()


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* given the conversation *history*.

    Yields the accumulated response text after each streamed token so the
    Gradio ChatInterface can render it incrementally.
    """
    messages = [{"role": "system", "content": system_message}]

    # Format history to match chat completion format
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": message})

    response = ""

    # Use streaming to get real-time responses.
    # BUGFIX: the loop variable no longer shadows the `message` parameter, and
    # None deltas (role/finish chunks) are skipped instead of raising TypeError.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
        yield response


def run_training_process(pdf_dir, output_name, progress_callback):
    """Run the PDF processing and fine-tuning process.

    Args:
        pdf_dir: Directory containing the source PDF files.
        output_name: Name for the fine-tuned model on the Hugging Face Hub.
        progress_callback: Callable ``(status, progress, message)`` used to
            report progress back to the UI / state file.

    Returns:
        True on success, False on any failure (details are sent to the callback).
    """
    try:
        # Create processed_data directory if it doesn't exist
        os.makedirs("processed_data", exist_ok=True)

        # Update state
        progress_callback("running", 0.05, "Processing PDFs...")

        # Process PDFs (list-form subprocess call: no shell injection risk)
        pdf_process = subprocess.run(
            ["python", "pdf_processor.py", "--pdf_dir", pdf_dir, "--output_dir", "processed_data"],
            capture_output=True,
            text=True
        )

        if pdf_process.returncode != 0:
            progress_callback("failed", 0.0, f"PDF processing failed: {pdf_process.stderr}")
            return False

        # Update state
        progress_callback("running", 0.3, "PDFs processed. Starting fine-tuning...")

        # Get Hugging Face token
        hf_token = HF_TOKEN or HfFolder.get_token()
        if not hf_token:
            progress_callback("failed", 0.0, "No Hugging Face token found. Please set the HF_TOKEN environment variable.")
            return False

        # Run fine-tuning
        finetune_process = subprocess.run(
            [
                "python", "finetune_llama3.py",
                "--dataset_path", "processed_data/training_data",
                "--hub_model_id", f"Borislav18/{output_name}",
                "--epochs", "1",  # Starting with 1 epoch for quicker feedback
                "--gradient_accumulation_steps", "4"
            ],
            env={**os.environ, "HF_TOKEN": hf_token},
            capture_output=True,
            text=True
        )

        if finetune_process.returncode != 0:
            progress_callback("failed", 0.0, f"Fine-tuning failed: {finetune_process.stderr}")
            return False

        # Update state
        progress_callback("success", 1.0, f"Training complete! Model pushed to Hugging Face as Borislav18/{output_name}")
        return True

    except Exception as e:
        progress_callback("failed", 0.0, f"Training process failed with error: {str(e)}")
        return False


def training_thread(pdf_dir, output_name):
    """Background thread body: run training and forward progress to the state."""
    def progress_callback(status, progress, message):
        training_state.update(status, progress, message)

    # Immediate feedback so the UI shows activity before the first real step.
    progress_callback("running", 0.01, "Starting training process...")

    # Run the actual training process
    run_training_process(pdf_dir, output_name, progress_callback)


def start_training(pdf_dir, output_name):
    """Start the training process in a background thread.

    Returns a ``(message, progress, status)`` tuple for the UI outputs.
    """
    if not pdf_dir or not output_name:
        return "Please provide both a PDF directory and output model name", 0.0, "idle"

    # Check if already running
    if training_state.status == "running":
        return f"Training already in progress: {training_state.message}", training_state.progress, training_state.status

    # Start background thread (daemon so it never blocks app shutdown)
    thread = threading.Thread(
        target=training_thread,
        args=(pdf_dir, output_name),
        daemon=True
    )
    thread.start()

    return "Training started...", 0.0, "running"


def get_training_status():
    """Get the current training status for UI updates."""
    return training_state.message, training_state.progress, training_state.status


# Create the main application
with gr.Blocks(title="Shedify - Document Assistant") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
        with gr.Column(scale=1):
            # Training controls
            with gr.Group(visible=True):
                gr.Markdown("## Train New Model")
                pdf_dir = gr.Textbox(label="PDF Directory", placeholder="Path to directory containing PDFs")
                output_name = gr.Textbox(label="Model Name", placeholder="Name for your fine-tuned model", value="Shedify-v1")
                train_btn = gr.Button("Start Training")
                training_message = gr.Textbox(label="Training Status", interactive=False)
                training_progress = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0,
                    label="Progress",
                    interactive=False
                )
                # Hidden component used only to trigger UI updates on change.
                training_status = gr.Textbox(visible=False)

    # Chat interface
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value="You are an AI assistant trained on specific documents. Answer questions based only on information from these documents. If you don't know the answer from the documents, say so clearly.",
                label="System message"
            ),
            gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        examples=[
            ["Summarize the key points from all documents you were trained on."],
            ["What are the main themes discussed in the documents?"],
            ["Extract the most important concepts mentioned in the documents."],
            ["Explain the relationship between the different topics in the documents."],
            ["What recommendations or conclusions can be drawn from the documents?"],
        ]
    )

    # Set up event handlers
    train_btn.click(
        fn=start_training,
        inputs=[pdf_dir, output_name],
        outputs=[training_message, training_progress, training_status]
    )

    # Poll training status every 5 seconds.
    # BUGFIX: replaces the invalid `demo.add_event_handler(...)` call (not part
    # of the Gradio API) and the duplicate plain `demo.load`; the `every`
    # parameter re-runs the function periodically while the page is open.
    demo.load(
        get_training_status,
        outputs=[training_message, training_progress, training_status],
        every=5.0,
    )

    def update_ui(message, progress, status):
        """Mirror status into the UI and disable the train button while running."""
        is_running = status == "running"
        # BUGFIX: component `.update()` methods were removed in Gradio 4;
        # gr.update() is the supported way to change properties from a callback.
        # (The original also built an unused status->color mapping; removed.)
        return message, progress, gr.update(interactive=not is_running)

    training_status.change(
        fn=update_ui,
        inputs=[training_message, training_progress, training_status],
        outputs=[training_message, training_progress, train_btn]
    )


if __name__ == "__main__":
    demo.launch()