Borislav18 committed on
Commit c3dfa5f · 1 Parent(s): 6ccc1d7

Update space

Files changed (2)
  1. app.py +249 -21
  2. requirements.txt +11 -1
app.py CHANGED
@@ -1,11 +1,88 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

  """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


  def respond(
      message,
@@ -17,6 +94,7 @@ def respond(
  ):
      messages = [{"role": "system", "content": system_message}]

      for val in history:
          if val[0]:
              messages.append({"role": "user", "content": val[0]})
@@ -27,6 +105,7 @@ def respond(

      response = ""

      for message in client.chat_completion(
          messages,
          max_tokens=max_tokens,
@@ -39,26 +118,175 @@ def respond(
          response += token
          yield response


- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


  if __name__ == "__main__":
      demo.launch()
 
  import gradio as gr
+ import os
+ import json
+ import time
+ import subprocess
+ import threading
+ import uuid
+ from pathlib import Path
+ from huggingface_hub import InferenceClient, HfFolder

  """
+ Shedify app - Using fine-tuned Llama 3.3 49B for document assistance
  """

+ # Model settings
+ DEFAULT_MODEL = "Borislav18/Shedify"  # Your Hugging Face username/model name
+ LOCAL_MODEL = os.environ.get("LOCAL_MODEL", None)  # Set this if testing locally
+
+ # Get Hugging Face token
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # App title and description
+ title = "Shedify - Document Assistant powered by Llama 3.3"
+ description = """
+ This app uses a fine-tuned version of the Llama 3.3 49B model trained on your documents.
+ Ask questions about the documents, generate insights, or request summaries!
+ """
+
+ # Initialize inference client with your model
+ client = InferenceClient(
+     DEFAULT_MODEL,
+     token=HF_TOKEN,
+ )
+
+ # Training status tracking
+ class TrainingState:
+     def __init__(self):
+         self.status = "idle"  # idle, running, success, failed
+         self.progress = 0.0  # 0.0 to 1.0
+         self.message = ""
+         self.id = str(uuid.uuid4())[:8]  # Generate a unique ID for this session
+
+         # Check if state file exists and load it
+         self.state_file = Path("training_state.json")
+         self.load_state()
+
+     def load_state(self):
+         """Load state from file if it exists"""
+         if self.state_file.exists():
+             try:
+                 with open(self.state_file, "r") as f:
+                     state = json.load(f)
+                     self.status = state.get("status", "idle")
+                     self.progress = state.get("progress", 0.0)
+                     self.message = state.get("message", "")
+                     self.id = state.get("id", self.id)
+             except Exception as e:
+                 print(f"Error loading state: {e}")
+
+     def save_state(self):
+         """Save current state to file"""
+         try:
+             with open(self.state_file, "w") as f:
+                 json.dump({
+                     "status": self.status,
+                     "progress": self.progress,
+                     "message": self.message,
+                     "id": self.id
+                 }, f)
+         except Exception as e:
+             print(f"Error saving state: {e}")
+
+     def update(self, status=None, progress=None, message=None):
+         """Update state and save it"""
+         if status is not None:
+             self.status = status
+         if progress is not None:
+             self.progress = progress
+         if message is not None:
+             self.message = message
+         self.save_state()
+         return self.status, self.progress, self.message
+
+ # Initialize the training state
+ training_state = TrainingState()

  def respond(
      message,

  ):
      messages = [{"role": "system", "content": system_message}]

+     # Format history to match chat completion format
      for val in history:
          if val[0]:
              messages.append({"role": "user", "content": val[0]})


      response = ""

+     # Use streaming to get real-time responses
      for message in client.chat_completion(
          messages,
          max_tokens=max_tokens,

          response += token
          yield response

+ def run_training_process(pdf_dir, output_name, progress_callback):
+     """Run the PDF processing and fine-tuning process"""
+     try:
+         # Create processed_data directory if it doesn't exist
+         os.makedirs("processed_data", exist_ok=True)
+
+         # Update state
+         progress_callback("running", 0.05, "Processing PDFs...")
+
+         # Process PDFs
+         pdf_process = subprocess.run(
+             ["python", "pdf_processor.py", "--pdf_dir", pdf_dir, "--output_dir", "processed_data"],
+             capture_output=True,
+             text=True
+         )
+
+         if pdf_process.returncode != 0:
+             progress_callback("failed", 0.0, f"PDF processing failed: {pdf_process.stderr}")
+             return False
+
+         # Update state
+         progress_callback("running", 0.3, "PDFs processed. Starting fine-tuning...")
+
+         # Get Hugging Face token
+         hf_token = HF_TOKEN or HfFolder.get_token()
+         if not hf_token:
+             progress_callback("failed", 0.0, "No Hugging Face token found. Please set the HF_TOKEN environment variable.")
+             return False
+
+         # Run fine-tuning
+         finetune_process = subprocess.run(
+             [
+                 "python", "finetune_llama3.py",
+                 "--dataset_path", "processed_data/training_data",
+                 "--hub_model_id", f"Borislav18/{output_name}",
+                 "--epochs", "1",  # Starting with 1 epoch for quicker feedback
+                 "--gradient_accumulation_steps", "4"
+             ],
+             env={**os.environ, "HF_TOKEN": hf_token},
+             capture_output=True,
+             text=True
+         )
+
+         if finetune_process.returncode != 0:
+             progress_callback("failed", 0.0, f"Fine-tuning failed: {finetune_process.stderr}")
+             return False
+
+         # Update state
+         progress_callback("success", 1.0, f"Training complete! Model pushed to Hugging Face as Borislav18/{output_name}")
+         return True
+
+     except Exception as e:
+         progress_callback("failed", 0.0, f"Training process failed with error: {str(e)}")
+         return False
+
+ def training_thread(pdf_dir, output_name):
+     """Background thread for running training"""
+     def progress_callback(status, progress, message):
+         training_state.update(status, progress, message)
+
+     # Report initial progress so the UI shows immediate feedback
+     progress_callback("running", 0.01, "Starting training process...")
+
+     # Run the actual training process
+     run_training_process(pdf_dir, output_name, progress_callback)
+
+ def start_training(pdf_dir, output_name):
+     """Start the training process in a background thread"""
+     if not pdf_dir or not output_name:
+         return "Please provide both a PDF directory and an output model name", 0.0, "idle"
+
+     # Check if already running
+     if training_state.status == "running":
+         return f"Training already in progress: {training_state.message}", training_state.progress, training_state.status
+
+     # Start background thread
+     thread = threading.Thread(
+         target=training_thread,
+         args=(pdf_dir, output_name),
+         daemon=True
+     )
+     thread.start()
+
+     return "Training started...", 0.0, "running"
+
+ def get_training_status():
+     """Get the current training status for UI updates"""
+     return training_state.message, training_state.progress, training_state.status


+ # Create the main application
+ with gr.Blocks(title="Shedify - Document Assistant") as demo:
+     with gr.Row():
+         with gr.Column(scale=2):
+             gr.Markdown(f"# {title}")
+             gr.Markdown(description)
+
+         with gr.Column(scale=1):
+             # Training controls
+             with gr.Group(visible=True):
+                 gr.Markdown("## Train New Model")
+                 pdf_dir = gr.Textbox(label="PDF Directory", placeholder="Path to directory containing PDFs")
+                 output_name = gr.Textbox(label="Model Name", placeholder="Name for your fine-tuned model", value="Shedify-v1")
+                 train_btn = gr.Button("Start Training")
+
+                 training_message = gr.Textbox(label="Training Status", interactive=False)
+                 training_progress = gr.Slider(
+                     minimum=0, maximum=1, value=0,
+                     label="Progress", interactive=False
+                 )
+                 training_status = gr.Textbox(visible=False)
+
+     # Chat interface
+     chatbot = gr.ChatInterface(
+         fn=respond,
+         additional_inputs=[
+             gr.Textbox(
+                 value="You are an AI assistant trained on specific documents. Answer questions based only on information from these documents. If you don't know the answer from the documents, say so clearly.",
+                 label="System message"
+             ),
+             gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
+             gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.05,
+                 label="Top-p (nucleus sampling)",
+             ),
+         ],
+         examples=[
+             ["Summarize the key points from all documents you were trained on."],
+             ["What are the main themes discussed in the documents?"],
+             ["Extract the most important concepts mentioned in the documents."],
+             ["Explain the relationship between the different topics in the documents."],
+             ["What recommendations or conclusions can be drawn from the documents?"],
+         ]
+     )
+
+     # Set up event handlers
+     train_btn.click(
+         fn=start_training,
+         inputs=[pdf_dir, output_name],
+         outputs=[training_message, training_progress, training_status]
+     )
+
+     # Populate the status fields once on page load
+     demo.load(get_training_status, outputs=[training_message, training_progress, training_status])
+
+     def update_ui(message, progress, status):
+         is_running = status == "running"
+         color = {
+             "idle": "gray",
+             "running": "blue",
+             "success": "green",
+             "failed": "red"
+         }.get(status, "gray")
+
+         message_with_color = f"<span style='color: {color}'>{message}</span>"
+         # gr.update() replaces the per-component .update() method removed in Gradio 4+
+         return message_with_color, progress, gr.update(interactive=not is_running)
+
+     training_status.change(
+         fn=update_ui,
+         inputs=[training_message, training_progress, training_status],
+         outputs=[training_message, training_progress, train_btn]
+     )
+
+     # Poll the training status every few seconds; gr.Timer is the supported
+     # mechanism for periodic updates (Blocks.add_event_handler is not a public API)
+     status_timer = gr.Timer(5.0)
+     status_timer.tick(get_training_status, outputs=[training_message, training_progress, training_status])

  if __name__ == "__main__":
      demo.launch()
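
Note: the body of respond() between the hunks above streams tokens from client.chat_completion. A minimal standalone sketch of that call, handy for testing the fine-tuned model outside the Space; the sampling values here are illustrative, since the exact parameters inside respond() are elided by the diff context:

    from huggingface_hub import InferenceClient

    # Model id taken from DEFAULT_MODEL above; pass token=... for a private repo
    client = InferenceClient("Borislav18/Shedify")

    messages = [
        {"role": "system", "content": "You are an AI assistant trained on specific documents."},
        {"role": "user", "content": "Summarize the key points from the documents."},
    ]

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=1024,
        stream=True,
        temperature=0.7,
        top_p=0.9,
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
    print(response)
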
requirements.txt CHANGED
@@ -1 +1,11 @@
- huggingface_hub==0.25.2
+ huggingface_hub>=0.25.2
+ gradio>=5.0.1
+ transformers>=4.36.0
+ peft>=0.7.0
+ datasets>=2.14.0
+ accelerate>=0.25.0
+ trl>=0.7.1
+ bitsandbytes>=0.40.0
+ torch>=2.0.0
+ PyPDF2>=3.0.0
+ tqdm>=4.65.0
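
Note: TrainingState persists its fields to training_state.json, so training progress survives app restarts and can be polled from outside the Gradio process. A minimal sketch of such a poller, assuming it runs from the same working directory that app.py writes to:

    import json
    import time
    from pathlib import Path

    STATE_FILE = Path("training_state.json")  # written by TrainingState.save_state()

    def poll_training_state(interval=5.0):
        """Print status updates until training succeeds or fails."""
        while True:
            if STATE_FILE.exists():
                state = json.loads(STATE_FILE.read_text())
                print(f"[{state.get('id', '?')}] {state.get('status', 'idle')} "
                      f"({state.get('progress', 0.0):.0%}): {state.get('message', '')}")
                if state.get("status") in ("success", "failed"):
                    return state
            time.sleep(interval)

    if __name__ == "__main__":
        poll_training_state()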