Natwar committed on
Commit 4275450 · verified · 1 Parent(s): 9d4ad57

Create app.py

Files changed (1)
  1. app.py +891 -0
app.py ADDED
@@ -0,0 +1,891 @@
import os
import subprocess
import sys
import warnings
import logging
from typing import List, Dict, Any, Optional
import tempfile
import re
import time
import gc
import spaces  # Hugging Face Spaces helper (pip package "spaces"); available in the Spaces runtime

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("debug.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Suppress warnings
warnings.filterwarnings("ignore")
def install_package(package: str, version: Optional[str] = None) -> None:
    """Install a Python package if it is not already installed"""
    package_spec = f"{package}=={version}" if version else package
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", package_spec])
        print(f"Successfully installed {package_spec}")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install {package_spec}: {e}")
        raise

# Required packages - install these before importing.
# Gradio is left unpinned: the UI below uses gr.themes and gr.Progress,
# which are not available in the 3.10.1 release that was originally pinned here.
required_packages = {
    "torch": None,
    "gradio": None,
    "transformers": None,
    "peft": None,
    "bitsandbytes": None,
    "PyPDF2": None,
    "python-docx": None,
    "accelerate": None,
    "sentencepiece": None,
}

# pip package name -> module name used for the import check
# (they differ for python-docx, which is imported as "docx")
IMPORT_NAMES = {"python-docx": "docx"}

# Install required packages BEFORE importing them
for package, version in required_packages.items():
    try:
        __import__(IMPORT_NAMES.get(package, package))
        print(f"{package} is already installed.")
    except ImportError:
        print(f"Installing {package}...")
        install_package(package, version)
# Now we can safely import all required modules
import torch
import transformers
import gradio as gr
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, TrainerCallback,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
import PyPDF2
import docx
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset as TorchDataset

# Suppress transformers warnings
transformers.logging.set_verbosity_error()

# Check GPU availability
if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"GPU found: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
else:
    DEVICE = "cpu"
    print("No GPU found, using CPU. Fine-tuning will be much slower.")
    print("For better performance, use Google Colab with GPU runtime (Runtime > Change runtime type > GPU)")
# Constants specific to Phi-2
MODEL_KEY = "microsoft/phi-2"
MAX_SEQ_LEN = 512  # Reduced from 1024 for much lighter memory usage
# FIX: target modules updated to the correct names for Phi-2's attention layers
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "dense"]

# Initialize model and tokenizer
model = None
tokenizer = None
fine_tuned_model = None
document_text = ""  # Store document content for context
def load_base_model() -> str:
    """Load Phi-2 with 8-bit quantization instead of 4-bit for faster training"""
    global model, tokenizer

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    try:
        # Use 8-bit quantization (faster to train than 4-bit)
        if DEVICE == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False
            )
        else:
            bnb_config = None

        # Load tokenizer with Phi-2 specific settings
        print("Loading Phi-2 tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_KEY,
            trust_remote_code=True,
            padding_side="right"
        )

        # Ensure pad token is properly set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with Phi-2 specific configuration
        print("Loading Phi-2 model... (this may take a few minutes)")
        if DEVICE == "cuda":
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            ).to(DEVICE)

        print("Phi-2 (2.7B) model loaded successfully!")
        return "Phi-2 (2.7B) model loaded successfully! Ready to process documents."

    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

def phi2_prompt_template(context: str, question: str) -> str:
    """
    Create a prompt optimized for Phi-2.
    Phi-2 responds well to clear instruction formatting.
    """
    return f"""Instruction: Answer the question accurately based on the context provided.
Context: {context}
Question: {question}
Answer:"""
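
# For illustration only (not executed): with context "Phi-2 is a 2.7B parameter
# language model released by Microsoft." and question "Who released Phi-2?",
# phi2_prompt_template() renders the prompt below, and the model is expected to
# continue the text after "Answer:":
#
#   Instruction: Answer the question accurately based on the context provided.
#   Context: Phi-2 is a 2.7B parameter language model released by Microsoft.
#   Question: Who released Phi-2?
#   Answer: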

def process_pdf(file_path: str) -> str:
    """Extract text from PDF file"""
    text = ""
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            total_pages = len(pdf_reader.pages)
            # Process at most 30 pages to avoid memory issues
            pages_to_process = min(total_pages, 30)
            for i in range(pages_to_process):
                page = pdf_reader.pages[i]
                page_text = page.extract_text() or ""
                text += page_text + "\n"

            if total_pages > pages_to_process:
                text += f"\n[Note: Only the first {pages_to_process} pages were processed due to size limitations.]"
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
    return text

def process_docx(file_path: str) -> str:
    """Extract text from DOCX file"""
    try:
        doc = docx.Document(file_path)
        text = "\n".join([para.text for para in doc.paragraphs])
        return text
    except Exception as e:
        print(f"Error processing DOCX: {str(e)}")
        return ""

def process_txt(file_path: str) -> str:
    """Extract text from TXT file"""
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
            text = file.read()
        return text
    except Exception as e:
        print(f"Error processing TXT: {str(e)}")
        return ""

def preprocess_text(text: str) -> str:
    """Clean and preprocess text"""
    if not text:
        return ""
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove special characters that may cause issues
    text = re.sub(r'[^\w\s.,;:!?\'\"()-]', '', text)
    return text.strip()

def get_semantic_chunks(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks of roughly chunk_size words.
    (The overlap argument is accepted for API compatibility but is not currently used.)"""
    if not text:
        return []

    # Simple sentence splitting for speed
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        words = sentence.split()
        if current_length + len(words) <= chunk_size:
            current_chunk.append(sentence)
            current_length += len(words)
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = len(words)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    # Limit to just 5 chunks for much faster processing
    if len(chunks) > 5:
        indices = np.linspace(0, len(chunks) - 1, 5, dtype=int)
        chunks = [chunks[i] for i in indices]

    return chunks
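
# Illustrative sketch (not executed): a ~1,500-word document yields roughly five
# ~300-word, sentence-aligned chunks. For longer documents, np.linspace keeps only
# five evenly spaced chunks, e.g.
#   np.linspace(0, 19, 5, dtype=int) -> array([ 0,  4,  9, 14, 19])
# so the sampled chunks span the beginning, middle, and end of the text.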

def create_qa_dataset(document_chunks: List[str]) -> List[Dict[str, str]]:
    """Create comprehensive QA pairs from document chunks for better fine-tuning"""
    qa_pairs = []

    # Document-level questions
    full_text = " ".join(document_chunks[:5])  # Use beginning of document for overview
    qa_pairs.append({
        "question": "What is this document about?",
        "context": full_text,
        "answer": "Based on my analysis, this document discusses..."  # Stub answer the model learns to complete
    })

    qa_pairs.append({
        "question": "Summarize the key points of this document.",
        "context": full_text,
        "answer": "The key points of this document are..."
    })

    # Process each chunk for specific QA pairs
    for i, chunk in enumerate(document_chunks):
        if not chunk or len(chunk) < 100:  # Skip very short chunks
            continue

        # Context-specific questions
        chunk_index = i + 1  # 1-indexed for readability

        # Basic factual questions about chunk content
        qa_pairs.append({
            "question": f"What information is contained in section {chunk_index}?",
            "context": chunk,
            "answer": f"Section {chunk_index} contains information about..."
        })

        # Entity-based questions - find names, organizations, technical terms
        entities = set(re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', chunk))
        technical_terms = set(re.findall(r'\b[A-Za-z]+-?[A-Za-z]+\b', chunk))  # collected but not currently used

        # Filter to meaningful entities (longer than 3 chars); limit to 2 entity questions per chunk
        entities = [e for e in entities if len(e) > 3][:2]

        for entity in entities:
            qa_pairs.append({
                "question": f"What does the document say about {entity}?",
                "context": chunk,
                "answer": f"Regarding {entity}, the document states that..."
            })

        # Specific content questions
        sentences = re.split(r'(?<=[.!?])\s+', chunk)
        key_sentences = [s for s in sentences if len(s.split()) > 8][:2]  # Focus on substantive sentences

        for sentence in key_sentences:
            # Create question from sentence by identifying subject
            subject_match = re.search(r'^(The|A|An|This|These|Those|Some|Any|Many|Few|All|Most)?\s*([A-Za-z\s]+?)\s+(is|are|was|were|has|have|had|can|could|will|would|may|might)', sentence, re.IGNORECASE)
            if subject_match:
                subject = subject_match.group(2).strip()
                if len(subject) > 2:
                    qa_pairs.append({
                        "question": f"What information is provided about {subject}?",
                        "context": chunk,
                        "answer": sentence
                    })

        # Add relationship questions between concepts
        if i < len(document_chunks) - 1:
            next_chunk = document_chunks[i + 1]
            qa_pairs.append({
                "question": f"How does the information in section {chunk_index} relate to section {chunk_index + 1}?",
                "context": chunk + " " + next_chunk,
                "answer": f"Section {chunk_index} discusses... while section {chunk_index + 1} covers... The relationship between them is..."
            })

    # Limit to 5 examples max for lighter memory usage
    if len(qa_pairs) > 5:
        import random
        random.shuffle(qa_pairs)
        qa_pairs = qa_pairs[:5]

    return qa_pairs

class QADataset(TorchDataset):
    """PyTorch dataset specialized for Phi-2 QA fine-tuning"""
    def __init__(self, qa_pairs: List[Dict[str, str]], tokenizer, max_length: int = MAX_SEQ_LEN):
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Verify dataset structure
        self.validate_dataset()

    def validate_dataset(self):
        """Verify that the dataset has proper structure"""
        if not self.qa_pairs:
            print("Warning: Empty dataset!")
            return

        required_keys = ["question", "context", "answer"]
        for i, item in enumerate(self.qa_pairs[:5]):  # Check first 5 examples
            missing = [k for k in required_keys if k not in item]
            if missing:
                print(f"Warning: Example {i} missing keys: {missing}")

            # Check for empty values
            empty = [k for k in required_keys if k in item and not item[k]]
            if empty:
                print(f"Warning: Example {i} has empty values for: {empty}")

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        qa_pair = self.qa_pairs[idx]

        # Format prompt using Phi-2 template
        context = qa_pair['context']
        question = qa_pair['question']
        answer = qa_pair['answer']

        # Build Phi-2 specific prompt
        prompt = phi2_prompt_template(context, question)

        # Concatenate prompt and answer
        sequence = f"{prompt} {answer}"

        try:
            # Tokenize with proper handling
            encoded = self.tokenizer(
                sequence,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )

            # Extract tensors
            input_ids = encoded["input_ids"].squeeze(0)
            attention_mask = encoded["attention_mask"].squeeze(0)

            # Create labels
            labels = input_ids.clone()

            # Calculate prompt length accurately
            prompt_encoded = self.tokenizer(prompt, add_special_tokens=False)
            prompt_length = len(prompt_encoded["input_ids"])

            # Ensure prompt_length doesn't exceed labels length
            prompt_length = min(prompt_length, len(labels))

            # Set labels for prompt portion to -100 (ignored in loss calculation)
            labels[:prompt_length] = -100

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # Return a dummy sample as fallback; its labels are all -100 so it does not contribute to the loss
            return {
                "input_ids": torch.zeros(self.max_length, dtype=torch.long),
                "attention_mask": torch.zeros(self.max_length, dtype=torch.long),
                "labels": torch.full((self.max_length,), -100, dtype=torch.long)
            }

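# Illustrative note (not executed): each item from QADataset is a dict of three
# (MAX_SEQ_LEN,)-shaped tensors. Prompt positions in "labels" are set to -100, so the
# language-modeling loss is computed only over the answer tokens that follow "Answer:".
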
def clear_gpu_memory():
    """Clear GPU memory to prevent OOM errors"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

class ProgressCallback(TrainerCallback):
    def __init__(self, progress, status_box=None):
        self.progress = progress
        self.status_box = status_box
        self.current_step = 0
        self.total_steps = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self.total_steps = state.max_steps

    def on_step_end(self, args, state, control, **kwargs):
        self.current_step = state.global_step
        progress_percent = self.current_step / max(self.total_steps, 1)  # guard against division by zero
        self.progress(0.4 + (0.5 * progress_percent),
                      desc=f"Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}")
        if self.status_box:
            self.status_box.update(f"Training in progress: Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}")

def create_deepspeed_config():
    """Create DeepSpeed config for faster training"""
    return {
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "allgather_partitions": True,
            "allgather_bucket_size": 5e8,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "overlap_comm": True,
            "contiguous_gradients": True
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 2e-4,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 0.01
            }
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 2e-4,
                "warmup_num_steps": 50
            }
        },
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 0.5,
        "steps_per_print": 10
    }

def finetune_model(qa_dataset, progress=gr.Progress(), status_box=None):
    """Fine-tune Phi-2 using optimized LoRA parameters"""
    global model, tokenizer, fine_tuned_model

    if model is None:
        return "Please load the base model first."

    if len(qa_dataset) == 0:
        return "No training data created. Please check your document."

    try:
        progress(0.1, desc="Preparing model for fine-tuning...")
        if status_box:
            status_box.update("Preparing model for fine-tuning...")

        # Clear GPU memory
        clear_gpu_memory()

        # Prepare model for 8-bit training if using GPU
        if DEVICE == "cuda":
            training_model = prepare_model_for_kbit_training(model)
        else:
            training_model = model

        # Required so gradients flow into the (frozen, quantized) input embeddings
        training_model.enable_input_require_grads()

        # Configure LoRA for Phi-2
        peft_config = LoraConfig(
            r=2,  # Reduced rank for lighter training
            lora_alpha=4,  # Reduced alpha
            lora_dropout=0.05,  # Small dropout for regularization
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=LORA_TARGET_MODULES  # Phi-2 attention/dense modules
        )

        # Apply LoRA to model
        lora_model = get_peft_model(training_model, peft_config)

        # Print trainable parameters
        trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in lora_model.parameters())
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/all_params:.2%} of {all_params:,} total)")

        # Enable gradient checkpointing for memory efficiency
        if hasattr(lora_model, "gradient_checkpointing_enable"):
            lora_model.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled")

        # Create training arguments optimized for Phi-2
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=2,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            learning_rate=1e-4,  # Reduced from 2e-4 for stability
            lr_scheduler_type="constant",  # Simplified scheduler
            warmup_ratio=0.05,  # Slight increase in warmup
            weight_decay=0.01,
            logging_steps=1,
            max_grad_norm=0.3,  # Reduced from 0.5 for better gradient stability
            save_strategy="no",
            report_to="none",
            remove_unused_columns=False,
            fp16=(DEVICE == "cuda"),
            no_cuda=(DEVICE == "cpu"),
            optim="adamw_torch",  # Standard optimizer instead of fused, for stability
            gradient_checkpointing=True
        )

        # Optionally use DeepSpeed on CUDA. It is not in required_packages,
        # so only enable it when the package is actually available.
        if DEVICE == "cuda":
            try:
                import deepspeed  # noqa: F401
                training_args.deepspeed = create_deepspeed_config()
            except ImportError:
                print("DeepSpeed not installed; training without it.")

        # Create data collator that doesn't move tensors to device yet
        def collate_fn(features):
            batch = {}
            for key in features[0].keys():
                if key in ["input_ids", "attention_mask", "labels"]:
                    batch[key] = torch.stack([f[key] for f in features])
            return batch

        progress(0.3, desc="Setting up trainer...")
        if status_box:
            status_box.update("Setting up trainer...")

        # Create trainer
        trainer = Trainer(
            model=lora_model,
            args=training_args,
            train_dataset=qa_dataset,
            data_collator=collate_fn,
            callbacks=[ProgressCallback(progress, status_box)]  # Progress reporting callback
        )

        # Start training
        progress(0.4, desc="Initializing training...")
        if status_box:
            status_box.update("Initializing training...")
        print("Starting training...")
        trainer.train()

        # Set fine-tuned model
        fine_tuned_model = lora_model

        # Put model in evaluation mode
        fine_tuned_model.eval()

        # Clear memory
        clear_gpu_memory()

        return "Fine-tuning completed successfully! You can now ask questions about your document."

    except Exception as e:
        error_msg = f"Error during fine-tuning: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()

        # Try to clean up memory
        try:
            clear_gpu_memory()
        except Exception:
            pass

        return error_msg

def process_document(file_obj, progress=gr.Progress(), status_box=None):
    """Process uploaded document and prepare dataset for fine-tuning"""
    global model, tokenizer, document_text

    progress(0, desc="Processing document...")
    if status_box:
        status_box.update("Processing document...")

    if not file_obj:
        return "Please upload a document first."

    try:
        # Create temp directory for file
        temp_dir = tempfile.mkdtemp()

        # Get file name
        file_name = getattr(file_obj, 'name', 'uploaded_file')
        if not isinstance(file_name, str):
            file_name = "uploaded_file.txt"  # Default name

        # Ensure file has an extension
        if '.' not in file_name:
            file_name = file_name + '.txt'

        temp_path = os.path.join(temp_dir, file_name)

        # Get file content
        if hasattr(file_obj, 'read'):
            file_content = file_obj.read()
        else:
            file_content = file_obj

        with open(temp_path, 'wb') as f:
            f.write(file_content)

        # Extract text based on file extension
        file_extension = os.path.splitext(file_name)[1].lower()

        if file_extension == '.pdf':
            text = process_pdf(temp_path)
        elif file_extension in ['.docx', '.doc']:
            text = process_docx(temp_path)
        else:  # Default to plain text for .txt and unknown extensions
            text = process_txt(temp_path)

        # Check if text was extracted
        if not text or len(text) < 50:
            return "Could not extract sufficient text from the document. Please check the file."

        # Save document text for context window during inference
        document_text = text

        # Preprocess and chunk the document
        progress(0.3, desc="Preprocessing document...")
        if status_box:
            status_box.update("Preprocessing document...")
        text = preprocess_text(text)
        chunks = get_semantic_chunks(text)

        if not chunks:
            return "Could not extract meaningful text from the document."

        # Create enhanced QA pairs
        progress(0.5, desc="Creating QA dataset...")
        if status_box:
            status_box.update("Creating QA dataset...")
        qa_pairs = create_qa_dataset(chunks)

        print(f"Created {len(qa_pairs)} QA pairs for training")

        # Debug: print a sample QA pair to verify the format
        if qa_pairs:
            print("\nSample QA pair for validation:")
            sample = qa_pairs[0]
            print(f"Question: {sample['question']}")
            print(f"Context length: {len(sample['context'])} chars")
            print(f"Answer: {sample['answer'][:50]}...")

        # Create dataset
        qa_dataset = QADataset(qa_pairs, tokenizer, max_length=MAX_SEQ_LEN)

        # Fine-tune model
        progress(0.7, desc="Starting fine-tuning...")
        if status_box:
            status_box.update("Starting fine-tuning...")
        result = finetune_model(qa_dataset, progress, status_box)

        # Clean up
        try:
            os.remove(temp_path)
            os.rmdir(temp_dir)
        except Exception:
            pass

        return result

    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg

def generate_answer(question, status_box=None):
    """Generate answer using fine-tuned Phi-2 model with improved response quality"""
    global fine_tuned_model, tokenizer, document_text

    if fine_tuned_model is None:
        return "Please process a document first!"

    if not question.strip():
        return "Please enter a question."

    try:
        # Clear memory before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # For better answers, use document context to help the model.
        # Find relevant context in the document (simple keyword matching for efficiency).
        keywords = re.findall(r'\b\w{5,}\b', question.lower())
        context = document_text

        # If the document is very long, try to find the relevant section
        if len(document_text) > 2000 and keywords:
            chunks = get_semantic_chunks(document_text, chunk_size=500, overlap=100)
            relevant_chunks = []

            for chunk in chunks:
                score = sum(1 for keyword in keywords if keyword.lower() in chunk.lower())
                if score > 0:
                    relevant_chunks.append((chunk, score))

            relevant_chunks.sort(key=lambda x: x[1], reverse=True)

            if relevant_chunks:
                # Use top 2 most relevant chunks
                context = " ".join([chunk for chunk, _ in relevant_chunks[:2]])

        # Limit context length to fit in the model's context window
        context = context[:1500]  # Limit to 1500 chars for prompt space

        # Create Phi-2 optimized prompt
        prompt = phi2_prompt_template(context, question)

        # Ensure model is in evaluation mode
        fine_tuned_model.eval()

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(fine_tuned_model.device)

        # Configure generation parameters optimized for Phi-2
        with torch.no_grad():
            outputs = fine_tuned_model.generate(
                **inputs,
                max_new_tokens=75,  # Reduced from 150
                do_sample=True,
                temperature=0.7,
                top_k=40,
                top_p=0.85,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the generated answer part
        if "Answer:" in response:
            answer = response.split("Answer:")[-1].strip()
        else:
            answer = response

        # If the answer is too short or generic, try again with a higher temperature
        if len(answer.split()) < 10 or "I don't have enough information" in answer:
            with torch.no_grad():
                outputs = fine_tuned_model.generate(
                    **inputs,
                    max_new_tokens=75,  # Reduced from 150
                    do_sample=True,
                    temperature=0.9,  # Higher temperature
                    top_k=40,
                    top_p=0.92,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode second attempt
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract answer
            if "Answer:" in response:
                answer = response.split("Answer:")[-1].strip()
            else:
                answer = response

        return answer

    except Exception as e:
        error_msg = f"Error generating answer: {str(e)}"
        print(error_msg)
        return error_msg

# Create Gradio interface
with gr.Blocks(title="Phi-2 Document QA", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Phi-2 Document Q&A System")
    gr.Markdown("Specialized system for fine-tuning Microsoft's Phi-2 model on your documents")

    with gr.Tab("Document Processing"):
        file_input = gr.File(
            label="Upload Document (PDF, DOCX, or TXT)",
            file_types=[".pdf", ".docx", ".txt"],
            type="binary"
        )

        with gr.Row():
            load_model_btn = gr.Button("1. Load Phi-2 Model", variant="secondary")
            process_btn = gr.Button("2. Process & Fine-tune Document", variant="primary")

        status = gr.Textbox(
            label="Status",
            placeholder="First load the model, then upload a document and click 'Process & Fine-tune'",
            lines=3
        )

        gr.Markdown("""
        ### Tips for Best Results
        - PDF, DOCX and TXT files are supported
        - Keep documents under 10 pages for best results
        - Processing time depends on document length and GPU availability
        - For GPU usage in Colab: Runtime > Change runtime type > GPU
        """)

    with gr.Tab("Ask Questions"):
        question_input = gr.Textbox(
            label="Your Question",
            placeholder="Ask about your document...",
            lines=2
        )

        ask_btn = gr.Button("Get Answer", variant="primary")

        answer_output = gr.Textbox(
            label="Phi-2's Response",
            placeholder="The answer will appear here after you ask a question",
            lines=8
        )

        gr.Markdown("""
        ### Example Questions
        - "What is this document about?"
        - "Summarize the key points in this document"
        - "What does the document say about [specific topic]?"
        - "Explain the relationship between [concept A] and [concept B]"
        """)

    # Set up events
    load_model_btn.click(
        fn=load_base_model,
        outputs=[status]
    )

    process_btn.click(
        fn=process_document,
        inputs=[file_input],
        outputs=[status]
    )

    ask_btn.click(
        fn=generate_answer,
        inputs=[question_input],
        outputs=[answer_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)
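
# Illustrative end-to-end usage (a sketch, not part of the app logic): run
# `python app.py`, open the printed URL, click "1. Load Phi-2 Model", upload a short
# PDF/DOCX/TXT and click "2. Process & Fine-tune Document", then ask questions on the
# "Ask Questions" tab. A CUDA GPU is assumed for reasonable fine-tuning speed.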