# NOTE(review): removed Hugging Face Spaces web-page residue ("Spaces:" /
# "Runtime error" banner text) that was captured along with this file and
# is not part of the program source.
"""
Shedify app - Using fine-tuned Llama 3.3 49B for document assistance
"""
import gradio as gr
import os
import json
import time
import subprocess
import threading
import uuid
from pathlib import Path

from huggingface_hub import InferenceClient, HfFolder

# Model settings
DEFAULT_MODEL = "Borislav18/Shedify"  # Hugging Face username/model name
LOCAL_MODEL = os.environ.get("LOCAL_MODEL", None)  # Set this if testing locally

# Hugging Face token (needed for private/gated models and for pushing to the Hub)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# App title and description
title = "Shedify - Document Assistant powered by Llama 3.3"
description = """
This app uses a fine-tuned version of Llama 3.3 49B model trained on your documents.
Ask questions about the documents, generate insights, or request summaries!
"""

# Inference client bound to the fine-tuned model
client = InferenceClient(
    DEFAULT_MODEL,
    token=HF_TOKEN,
)
# Training status tracking | |
class TrainingState:
    """Tracks the status of a fine-tuning job and persists it across restarts."""

    def __init__(self):
        self.status = "idle"        # one of: idle, running, success, failed
        self.progress = 0.0         # fraction complete, 0.0 to 1.0
        self.message = ""           # human-readable status text
        self.id = str(uuid.uuid4())[:8]  # short unique ID for this session
        # Restore any previously persisted state from disk.
        self.state_file = Path("training_state.json")
        self.load_state()

    def load_state(self):
        """Load state from file if it exists"""
        if not self.state_file.exists():
            return
        try:
            with open(self.state_file, "r") as f:
                saved = json.load(f)
            self.status = saved.get("status", "idle")
            self.progress = saved.get("progress", 0.0)
            self.message = saved.get("message", "")
            self.id = saved.get("id", self.id)
        except Exception as e:
            # Best-effort: a corrupt/unreadable file just leaves defaults in place.
            print(f"Error loading state: {e}")

    def save_state(self):
        """Save current state to file"""
        payload = {
            "status": self.status,
            "progress": self.progress,
            "message": self.message,
            "id": self.id,
        }
        try:
            with open(self.state_file, "w") as f:
                json.dump(payload, f)
        except Exception as e:
            # Persistence failures are logged but never crash the app.
            print(f"Error saving state: {e}")

    def update(self, status=None, progress=None, message=None):
        """Apply any provided fields, persist, and return (status, progress, message)."""
        for attr, value in (("status", status), ("progress", progress), ("message", message)):
            if value is not None:
                setattr(self, attr, value)
        self.save_state()
        return self.status, self.progress, self.message
# Shared, module-level training state (persisted to training_state.json)
training_state = TrainingState()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion from the fine-tuned model.

    Args:
        message: The latest user message.
        history: Prior (user, assistant) turn pairs from the chat UI.
        system_message: System prompt prepended to the conversation.
        max_tokens / temperature / top_p: Sampling parameters.

    Yields:
        The progressively accumulated response text (for live streaming).
    """
    messages = [{"role": "system", "content": system_message}]

    # Format history to match the chat completion message schema.
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""
    # Use streaming to get real-time responses. NOTE: the loop variable must not
    # shadow the `message` parameter (the original code reused the name).
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # Terminal/keep-alive chunks can carry delta.content == None; skip them
        # instead of crashing on `response += None`.
        if token:
            response += token
            yield response
def run_training_process(pdf_dir, output_name, progress_callback):
    """Run the PDF processing and fine-tuning process"""
    try:
        # Make sure the output directory for processed PDFs exists.
        os.makedirs("processed_data", exist_ok=True)
        progress_callback("running", 0.05, "Processing PDFs...")

        # Step 1: extract training data from the PDFs.
        pdf_cmd = [
            "python", "pdf_processor.py",
            "--pdf_dir", pdf_dir,
            "--output_dir", "processed_data",
        ]
        pdf_process = subprocess.run(pdf_cmd, capture_output=True, text=True)
        if pdf_process.returncode != 0:
            progress_callback("failed", 0.0, f"PDF processing failed: {pdf_process.stderr}")
            return False

        progress_callback("running", 0.3, "PDFs processed. Starting fine-tuning...")

        # A Hub token is required to push the fine-tuned model.
        hf_token = HF_TOKEN or HfFolder.get_token()
        if not hf_token:
            progress_callback("failed", 0.0, "No Hugging Face token found. Please set the HF_TOKEN environment variable.")
            return False

        # Step 2: fine-tune and push the result to the Hub.
        finetune_cmd = [
            "python", "finetune_llama3.py",
            "--dataset_path", "processed_data/training_data",
            "--hub_model_id", f"Borislav18/{output_name}",
            "--epochs", "1",  # Starting with 1 epoch for quicker feedback
            "--gradient_accumulation_steps", "4",
        ]
        finetune_process = subprocess.run(
            finetune_cmd,
            env={**os.environ, "HF_TOKEN": hf_token},
            capture_output=True,
            text=True,
        )
        if finetune_process.returncode != 0:
            progress_callback("failed", 0.0, f"Fine-tuning failed: {finetune_process.stderr}")
            return False

        progress_callback("success", 1.0, f"Training complete! Model pushed to Hugging Face as Borislav18/{output_name}")
        return True
    except Exception as e:
        # Any unexpected failure is reported through the callback, never raised.
        progress_callback("failed", 0.0, f"Training process failed with error: {str(e)}")
        return False
def training_thread(pdf_dir, output_name):
    """Background thread for running training"""

    def progress_callback(status, progress, message):
        # Forward every progress event into the shared, persisted state.
        training_state.update(status, progress, message)

    # Give the UI immediate feedback before the heavy work begins.
    progress_callback("running", 0.01, "Starting training process...")
    run_training_process(pdf_dir, output_name, progress_callback)
def start_training(pdf_dir, output_name):
    """Start the training process in a background thread"""
    # Both inputs are required.
    if not (pdf_dir and output_name):
        return "Please provide both a PDF directory and output model name", 0.0, "idle"

    # Refuse to launch a second concurrent training run.
    if training_state.status == "running":
        return (
            f"Training already in progress: {training_state.message}",
            training_state.progress,
            training_state.status,
        )

    # Daemon thread so a stuck job never blocks app shutdown.
    worker = threading.Thread(
        target=training_thread,
        args=(pdf_dir, output_name),
        daemon=True,
    )
    worker.start()
    return "Training started...", 0.0, "running"
def get_training_status():
    """Return (message, progress, status) from the shared training state for UI polling."""
    state = training_state
    return state.message, state.progress, state.status
# Create the main application
with gr.Blocks(title="Shedify - Document Assistant") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
        with gr.Column(scale=1):
            # Training controls
            with gr.Group(visible=True):
                gr.Markdown("## Train New Model")
                pdf_dir = gr.Textbox(label="PDF Directory", placeholder="Path to directory containing PDFs")
                output_name = gr.Textbox(label="Model Name", placeholder="Name for your fine-tuned model", value="Shedify-v1")
                train_btn = gr.Button("Start Training")
                training_message = gr.Textbox(label="Training Status", interactive=False)
                training_progress = gr.Slider(
                    minimum=0, maximum=1, value=0,
                    label="Progress", interactive=False
                )
                # Hidden field mirroring the raw status string; its changes drive update_ui.
                training_status = gr.Textbox(visible=False)

    # Chat interface backed by the fine-tuned model
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value="You are an AI assistant trained on specific documents. Answer questions based only on information from these documents. If you don't know the answer from the documents, say so clearly.",
                label="System message"
            ),
            gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        examples=[
            ["Summarize the key points from all documents you were trained on."],
            ["What are the main themes discussed in the documents?"],
            ["Extract the most important concepts mentioned in the documents."],
            ["Explain the relationship between the different topics in the documents."],
            ["What recommendations or conclusions can be drawn from the documents?"],
        ]
    )

    # Set up event handlers
    train_btn.click(
        fn=start_training,
        inputs=[pdf_dir, output_name],
        outputs=[training_message, training_progress, training_status]
    )

    def update_ui(message, progress, status):
        """Colorize the status message and disable the train button while a job runs."""
        is_running = status == "running"
        color = {
            "idle": "gray",
            "running": "blue",
            "success": "green",
            "failed": "red"
        }.get(status, "gray")
        message_with_color = f"<span style='color: {color}'>{message}</span>"
        # BUGFIX: Component.update() (train_btn.update) was removed in Gradio 4;
        # gr.update(...) is the supported way to patch component properties.
        return message_with_color, progress, gr.update(interactive=not is_running)

    training_status.change(
        fn=update_ui,
        inputs=[training_message, training_progress, training_status],
        outputs=[training_message, training_progress, train_btn]
    )

    # Poll the training state every 5 seconds via the `every` parameter.
    # BUGFIX: gr.Blocks has no `add_event_handler` API — the original call
    # raised at startup (the Space's "Runtime error").
    demo.load(
        get_training_status,
        outputs=[training_message, training_progress, training_status],
        every=5.0,
    )

if __name__ == "__main__":
    demo.launch()