Create app.py
app.py
ADDED
@@ -0,0 +1,891 @@
import os
import subprocess
import sys
import warnings
import logging
from typing import List, Dict, Any, Optional
import tempfile
import re
import time
import gc
import spaces

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("debug.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Suppress warnings
warnings.filterwarnings("ignore")

def install_package(package: str, version: Optional[str] = None) -> None:
    """Install a Python package if not already installed"""
    package_spec = f"{package}=={version}" if version else package
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", package_spec])
        print(f"Successfully installed {package_spec}")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install {package_spec}: {e}")
        raise

# Required packages - install these before importing
required_packages = {
    "torch": None,
    "gradio": None,  # left unpinned: gr.themes and gr.Progress used below require a recent Gradio release
    "transformers": None,
    "peft": None,
    "bitsandbytes": None,
    "PyPDF2": None,
    "python-docx": None,
    "accelerate": None,
    "sentencepiece": None,
}

# Map pip package names to import names where they differ
IMPORT_NAMES = {"python-docx": "docx"}

# Install required packages BEFORE importing them
for package, version in required_packages.items():
    try:
        __import__(IMPORT_NAMES.get(package, package))
        print(f"{package} is already installed.")
    except ImportError:
        print(f"Installing {package}...")
        install_package(package, version)

# Now we can safely import all required modules
import torch
import transformers
import gradio as gr
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, TrainerCallback,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
import PyPDF2
import docx
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset as TorchDataset

# Suppress transformers warnings
transformers.logging.set_verbosity_error()

# Check GPU availability
if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"GPU found: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
else:
    DEVICE = "cpu"
    print("No GPU found, using CPU. Fine-tuning will be much slower.")
    print("For better performance, use Google Colab with GPU runtime (Runtime > Change runtime type > GPU)")

# Constants specific to Phi-2
MODEL_KEY = "microsoft/phi-2"
MAX_SEQ_LEN = 512  # Reduced from 1024 for much lighter memory usage
# FIX: Updated target modules for Phi-2
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "dense"]  # Correct modules for Phi-2

# Initialize model and tokenizer
model = None
tokenizer = None
fine_tuned_model = None
document_text = ""  # Store document content for context

def load_base_model() -> str:
    """Load Phi-2 with 8-bit quantization instead of 4-bit for faster training"""
    global model, tokenizer

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    try:
        # Use 8-bit quantization (faster to train than 4-bit)
        if DEVICE == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False
            )
        else:
            bnb_config = None

        # Load tokenizer with Phi-2 specific settings
        print("Loading Phi-2 tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_KEY,
            trust_remote_code=True,
            padding_side="right"
        )

        # Ensure pad token is properly set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with Phi-2 specific configuration
        print("Loading Phi-2 model... (this may take a few minutes)")
        if DEVICE == "cuda":
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            ).to(DEVICE)

        print("Phi-2 (2.7B) model loaded successfully!")
        return "Phi-2 (2.7B) model loaded successfully! Ready to process documents."

    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

def phi2_prompt_template(context: str, question: str) -> str:
    """
    Create a prompt optimized for Phi-2.
    Phi-2 responds well to clear instruction formatting.
    """
    return f"""Instruction: Answer the question accurately based on the context provided.
Context: {context}
Question: {question}
Answer:"""

def process_pdf(file_path: str) -> str:
    """Extract text from PDF file"""
    text = ""
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            total_pages = len(pdf_reader.pages)
            # Process at most 30 pages to avoid memory issues
            pages_to_process = min(total_pages, 30)
            for i in range(pages_to_process):
                page = pdf_reader.pages[i]
                page_text = page.extract_text() or ""
                text += page_text + "\n"

            if total_pages > pages_to_process:
                text += f"\n[Note: Only the first {pages_to_process} pages were processed due to size limitations.]"
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
    return text

def process_docx(file_path: str) -> str:
    """Extract text from DOCX file"""
    try:
        doc = docx.Document(file_path)
        text = "\n".join([para.text for para in doc.paragraphs])
        return text
    except Exception as e:
        print(f"Error processing DOCX: {str(e)}")
        return ""

def process_txt(file_path: str) -> str:
    """Extract text from TXT file"""
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
            text = file.read()
        return text
    except Exception as e:
        print(f"Error processing TXT: {str(e)}")
        return ""

def preprocess_text(text: str) -> str:
    """Clean and preprocess text"""
    if not text:
        return ""
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove special characters that may cause issues
    text = re.sub(r'[^\w\s.,;:!?\'\"()-]', '', text)
    return text.strip()

def get_semantic_chunks(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """More efficient semantic chunking"""
    if not text:
        return []

    # Simple sentence splitting for speed
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        words = sentence.split()
        if current_length + len(words) <= chunk_size:
            current_chunk.append(sentence)
            current_length += len(words)
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = len(words)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    # Limit to just 5 chunks for much faster processing
    if len(chunks) > 5:
        indices = np.linspace(0, len(chunks) - 1, 5, dtype=int)
        chunks = [chunks[i] for i in indices]

    return chunks

def create_qa_dataset(document_chunks: List[str]) -> List[Dict[str, str]]:
    """Create comprehensive QA pairs from document chunks for better fine-tuning"""
    qa_pairs = []

    # Document-level questions
    full_text = " ".join(document_chunks[:5])  # Use beginning of document for overview
    qa_pairs.append({
        "question": "What is this document about?",
        "context": full_text,
        "answer": "Based on my analysis, this document discusses..."  # Template answer for the model to learn the format
    })

    qa_pairs.append({
        "question": "Summarize the key points of this document.",
        "context": full_text,
        "answer": "The key points of this document are..."
    })

    # Process each chunk for specific QA pairs
    for i, chunk in enumerate(document_chunks):
        if not chunk or len(chunk) < 100:  # Skip very short chunks
            continue

        # Context-specific questions
        chunk_index = i + 1  # 1-indexed for readability

        # Basic factual questions about chunk content
        qa_pairs.append({
            "question": f"What information is contained in section {chunk_index}?",
            "context": chunk,
            "answer": f"Section {chunk_index} contains information about..."
        })

        # Entity-based questions - find names, organizations, technical terms
        entities = set(re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', chunk))
        technical_terms = set(re.findall(r'\b[A-Za-z]+-?[A-Za-z]+\b', chunk))

        # Filter to meaningful entities (longer than 3 chars)
        entities = [e for e in entities if len(e) > 3][:2]  # Limit to 2 entity questions per chunk

        for entity in entities:
            qa_pairs.append({
                "question": f"What does the document say about {entity}?",
                "context": chunk,
                "answer": f"Regarding {entity}, the document states that..."
            })

        # Specific content questions
        sentences = re.split(r'(?<=[.!?])\s+', chunk)
        key_sentences = [s for s in sentences if len(s.split()) > 8][:2]  # Focus on substantive sentences

        for sentence in key_sentences:
            # Create question from sentence by identifying subject
            subject_match = re.search(r'^(The|A|An|This|These|Those|Some|Any|Many|Few|All|Most)?\s*([A-Za-z\s]+?)\s+(is|are|was|were|has|have|had|can|could|will|would|may|might)', sentence, re.IGNORECASE)
            if subject_match:
                subject = subject_match.group(2).strip()
                if len(subject) > 2:
                    qa_pairs.append({
                        "question": f"What information is provided about {subject}?",
                        "context": chunk,
                        "answer": sentence
                    })

        # Add relationship questions between concepts
        if i < len(document_chunks) - 1:
            next_chunk = document_chunks[i + 1]
            qa_pairs.append({
                "question": f"How does the information in section {chunk_index} relate to section {chunk_index + 1}?",
                "context": chunk + " " + next_chunk,
                "answer": f"Section {chunk_index} discusses... while section {chunk_index + 1} covers... The relationship between them is..."
            })

    # Limit to 5 examples max for lighter memory usage
    if len(qa_pairs) > 5:
        import random
        random.shuffle(qa_pairs)
        qa_pairs = qa_pairs[:5]

    return qa_pairs

class QADataset(TorchDataset):
    """PyTorch dataset specialized for Phi-2 QA fine-tuning"""
    def __init__(self, qa_pairs: List[Dict[str, str]], tokenizer, max_length: int = MAX_SEQ_LEN):
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Verify dataset structure
        self.validate_dataset()

    def validate_dataset(self):
        """Verify that the dataset has proper structure"""
        if not self.qa_pairs:
            print("Warning: Empty dataset!")
            return

        required_keys = ["question", "context", "answer"]
        for i, item in enumerate(self.qa_pairs[:5]):  # Check first 5 examples
            missing = [k for k in required_keys if k not in item]
            if missing:
                print(f"Warning: Example {i} missing keys: {missing}")

            # Check for empty values
            empty = [k for k in required_keys if k in item and not item[k]]
            if empty:
                print(f"Warning: Example {i} has empty values for: {empty}")

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        qa_pair = self.qa_pairs[idx]

        # Format prompt using Phi-2 template
        context = qa_pair['context']
        question = qa_pair['question']
        answer = qa_pair['answer']

        # Build Phi-2 specific prompt
        prompt = phi2_prompt_template(context, question)

        # Concatenate prompt and answer
        sequence = f"{prompt} {answer}"

        try:
            # Tokenize with proper handling
            encoded = self.tokenizer(
                sequence,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )

            # Extract tensors
            input_ids = encoded["input_ids"].squeeze(0)
            attention_mask = encoded["attention_mask"].squeeze(0)

            # Create labels
            labels = input_ids.clone()

            # Calculate prompt length accurately
            prompt_encoded = self.tokenizer(prompt, add_special_tokens=False)
            prompt_length = len(prompt_encoded["input_ids"])

            # Ensure prompt_length doesn't exceed labels length
            prompt_length = min(prompt_length, len(labels))

            # Set labels for prompt portion to -100 (ignored in loss calculation)
            labels[:prompt_length] = -100

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # Return dummy sample as fallback
            return {
                "input_ids": torch.zeros(self.max_length, dtype=torch.long),
                "attention_mask": torch.zeros(self.max_length, dtype=torch.long),
                "labels": torch.zeros(self.max_length, dtype=torch.long)
            }

def clear_gpu_memory():
    """Clear GPU memory to prevent OOM errors"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

class ProgressCallback(TrainerCallback):
    def __init__(self, progress, status_box=None):
        self.progress = progress
        self.status_box = status_box
        self.current_step = 0
        self.total_steps = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self.total_steps = state.max_steps

    def on_step_end(self, args, state, control, **kwargs):
        self.current_step = state.global_step
        progress_percent = self.current_step / max(self.total_steps, 1)  # Guard against division by zero
        self.progress(0.4 + (0.5 * progress_percent),
                      desc=f"Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}")
        if self.status_box:
            self.status_box.update(f"Training in progress: Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}")

def create_deepspeed_config():
    """Create DeepSpeed config for faster training"""
    return {
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "allgather_partitions": True,
            "allgather_bucket_size": 5e8,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "overlap_comm": True,
            "contiguous_gradients": True
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 2e-4,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 0.01
            }
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 2e-4,
                "warmup_num_steps": 50
            }
        },
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 0.5,
        "steps_per_print": 10
    }

def finetune_model(qa_dataset, progress=gr.Progress(), status_box=None):
    """Fine-tune Phi-2 using optimized LoRA parameters"""
    global model, tokenizer, fine_tuned_model

    if model is None:
        return "Please load the base model first."

    if len(qa_dataset) == 0:
        return "No training data created. Please check your document."

    try:
        progress(0.1, desc="Preparing model for fine-tuning...")
        if status_box:
            status_box.update("Preparing model for fine-tuning...")

        # Clear GPU memory
        clear_gpu_memory()

        # Prepare model for 8-bit training if using GPU
        if DEVICE == "cuda":
            training_model = prepare_model_for_kbit_training(model)
        else:
            training_model = model

        # Make embedding outputs require grads so gradient checkpointing works with the frozen base weights
        training_model.enable_input_require_grads()

        # Configure LoRA for Phi-2
        peft_config = LoraConfig(
            r=2,  # Reduced rank for lighter training
            lora_alpha=4,  # Reduced alpha
            lora_dropout=0.05,  # Added small dropout for regularization
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=LORA_TARGET_MODULES  # Fixed Phi-2 modules
        )

        # Apply LoRA to model
        lora_model = get_peft_model(training_model, peft_config)

        # Print trainable parameters
        trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in lora_model.parameters())
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/all_params:.2%} of {all_params:,} total)")

        # Enable gradient checkpointing for memory efficiency
        if hasattr(lora_model, "gradient_checkpointing_enable"):
            lora_model.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled")

        # Create training arguments optimized for Phi-2
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=2,  # Set to 2 as requested
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            learning_rate=1e-4,  # Reduced from 2e-4 for stability
            lr_scheduler_type="constant",  # Simplified scheduler
            warmup_ratio=0.05,  # Slight increase in warmup
            weight_decay=0.01,
            logging_steps=1,
            max_grad_norm=0.3,  # Reduced from 0.5 for better gradient stability
            save_strategy="no",
            report_to="none",
            remove_unused_columns=False,
            fp16=(DEVICE == "cuda"),
            no_cuda=(DEVICE == "cpu"),
            optim="adamw_torch",  # Use standard optimizer instead of fused for stability
            gradient_checkpointing=True
        )

        # Add DeepSpeed if on CUDA
        if DEVICE == "cuda":
            training_args.deepspeed = create_deepspeed_config()

        # Create data collator that doesn't move tensors to device yet
        def collate_fn(features):
            batch = {}
            for key in features[0].keys():
                if key in ["input_ids", "attention_mask", "labels"]:
                    batch[key] = torch.stack([f[key] for f in features])
            return batch

        progress(0.3, desc="Setting up trainer...")
        if status_box:
            status_box.update("Setting up trainer...")

        # Create trainer
        trainer = Trainer(
            model=lora_model,
            args=training_args,
            train_dataset=qa_dataset,
            data_collator=collate_fn,
            callbacks=[ProgressCallback(progress, status_box)]  # Report training progress to the UI
        )

        # Start training
        progress(0.4, desc="Initializing training...")
        if status_box:
            status_box.update("Initializing training...")
        print("Starting training...")
        trainer.train()

        # Set fine-tuned model
        fine_tuned_model = lora_model

        # Put model in evaluation mode
        fine_tuned_model.eval()

        # Clear memory
        clear_gpu_memory()

        return "Fine-tuning completed successfully! You can now ask questions about your document."

    except Exception as e:
        error_msg = f"Error during fine-tuning: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()

        # Try to clean up memory
        try:
            clear_gpu_memory()
        except Exception:
            pass

        return error_msg

def process_document(file_obj, progress=gr.Progress(), status_box=None):
    """Process uploaded document and prepare dataset for fine-tuning"""
    global model, tokenizer, document_text

    progress(0, desc="Processing document...")
    if status_box:
        status_box.update("Processing document...")

    if not file_obj:
        return "Please upload a document first."

    try:
        # Create temp directory for file
        temp_dir = tempfile.mkdtemp()

        # Get file name
        file_name = getattr(file_obj, 'name', 'uploaded_file')
        if not isinstance(file_name, str):
            file_name = "uploaded_file.txt"  # Default name

        # Ensure file has extension
        if '.' not in file_name:
            file_name = file_name + '.txt'

        temp_path = os.path.join(temp_dir, file_name)

        # Get file content
        if hasattr(file_obj, 'read'):
            file_content = file_obj.read()
        else:
            file_content = file_obj

        with open(temp_path, 'wb') as f:
            f.write(file_content)

        # Extract text based on file extension
        file_extension = os.path.splitext(file_name)[1].lower()

        if file_extension == '.pdf':
            text = process_pdf(temp_path)
        elif file_extension in ['.docx', '.doc']:
            text = process_docx(temp_path)
        else:  # Default to txt for .txt and unknown extensions
            text = process_txt(temp_path)

        # Check if text was extracted
        if not text or len(text) < 50:
            return "Could not extract sufficient text from the document. Please check the file."

        # Save document text for context window during inference
        document_text = text

        # Preprocess and chunk the document
        progress(0.3, desc="Preprocessing document...")
        if status_box:
            status_box.update("Preprocessing document...")
        text = preprocess_text(text)
        chunks = get_semantic_chunks(text)

        if not chunks:
            return "Could not extract meaningful text from the document."

        # Create enhanced QA pairs
        progress(0.5, desc="Creating QA dataset...")
        if status_box:
            status_box.update("Creating QA dataset...")
        qa_pairs = create_qa_dataset(chunks)

        print(f"Created {len(qa_pairs)} QA pairs for training")

        # Debug: Print a sample of QA pairs to verify format
        if qa_pairs:
            print("\nSample QA pair for validation:")
            sample = qa_pairs[0]
            print(f"Question: {sample['question']}")
            print(f"Context length: {len(sample['context'])} chars")
            print(f"Answer: {sample['answer'][:50]}...")

        # Create dataset
        qa_dataset = QADataset(qa_pairs, tokenizer, max_length=MAX_SEQ_LEN)

        # Fine-tune model
        progress(0.7, desc="Starting fine-tuning...")
        if status_box:
            status_box.update("Starting fine-tuning...")
        result = finetune_model(qa_dataset, progress, status_box)

        # Clean up
        try:
            os.remove(temp_path)
            os.rmdir(temp_dir)
        except OSError:
            pass

        return result

    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg

def generate_answer(question, status_box=None):
    """Generate answer using fine-tuned Phi-2 model with improved response quality"""
    global fine_tuned_model, tokenizer, document_text

    if fine_tuned_model is None:
        return "Please process a document first!"

    if not question.strip():
        return "Please enter a question."

    try:
        # Clear memory before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # For better answers, use document context to help the model
        # Find relevant context from document (simple keyword matching for efficiency)
        keywords = re.findall(r'\b\w{5,}\b', question.lower())
        context = document_text

        # If document is very long, try to find relevant section
        if len(document_text) > 2000 and keywords:
            chunks = get_semantic_chunks(document_text, chunk_size=500, overlap=100)
            relevant_chunks = []

            for chunk in chunks:
                score = sum(1 for keyword in keywords if keyword.lower() in chunk.lower())
                if score > 0:
                    relevant_chunks.append((chunk, score))

            relevant_chunks.sort(key=lambda x: x[1], reverse=True)

            if relevant_chunks:
                # Use top 2 most relevant chunks
                context = " ".join([chunk for chunk, _ in relevant_chunks[:2]])

        # Limit context length to fit in model's context window
        context = context[:1500]  # Limit to 1500 chars to leave room for the rest of the prompt

        # Create Phi-2 optimized prompt
        prompt = phi2_prompt_template(context, question)

        # Ensure model is in evaluation mode
        fine_tuned_model.eval()

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(fine_tuned_model.device)

        # Configure generation parameters optimized for Phi-2
        with torch.no_grad():
            outputs = fine_tuned_model.generate(
                **inputs,
                max_new_tokens=75,  # Reduced from 150
                do_sample=True,
                temperature=0.7,
                top_k=40,
                top_p=0.85,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the generated answer part
        if "Answer:" in response:
            answer = response.split("Answer:")[-1].strip()
        else:
            answer = response

        # If answer is too short or generic, try again with higher temperature
        if len(answer.split()) < 10 or "I don't have enough information" in answer:
            with torch.no_grad():
                outputs = fine_tuned_model.generate(
                    **inputs,
                    max_new_tokens=75,  # Reduced from 150
                    do_sample=True,
                    temperature=0.9,  # Higher temperature
                    top_k=40,
                    top_p=0.92,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode second attempt
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract answer
            if "Answer:" in response:
                answer = response.split("Answer:")[-1].strip()
            else:
                answer = response

        return answer

    except Exception as e:
        error_msg = f"Error generating answer: {str(e)}"
        print(error_msg)
        return error_msg

# Create Gradio interface
with gr.Blocks(title="Phi-2 Document QA", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Phi-2 Document Q&A System")
    gr.Markdown("Specialized system for fine-tuning Microsoft's Phi-2 model on your documents")

    with gr.Tab("Document Processing"):
        file_input = gr.File(
            label="Upload Document (PDF, DOCX, or TXT)",
            file_types=[".pdf", ".docx", ".txt"],
            type="binary"
        )

        with gr.Row():
            load_model_btn = gr.Button("1. Load Phi-2 Model", variant="secondary")
            process_btn = gr.Button("2. Process & Fine-tune Document", variant="primary")

        status = gr.Textbox(
            label="Status",
            placeholder="First load the model, then upload a document and click 'Process & Fine-tune'",
            lines=3
        )

        gr.Markdown("""
        ### Tips for Best Results
        - PDF, DOCX and TXT files are supported
        - Keep documents under 10 pages for best results
        - Processing time depends on document length and GPU availability
        - For GPU usage in Colab: Runtime > Change runtime type > GPU
        """)

    with gr.Tab("Ask Questions"):
        question_input = gr.Textbox(
            label="Your Question",
            placeholder="Ask about your document...",
            lines=2
        )

        ask_btn = gr.Button("Get Answer", variant="primary")

        answer_output = gr.Textbox(
            label="Phi-2's Response",
            placeholder="The answer will appear here after you ask a question",
            lines=8
        )

        gr.Markdown("""
        ### Example Questions
        - "What is this document about?"
        - "Summarize the key points in this document"
        - "What does the document say about [specific topic]?"
        - "Explain the relationship between [concept A] and [concept B]"
        """)

    # Set up events
    load_model_btn.click(
        fn=load_base_model,
        outputs=[status]
    )

    process_btn.click(
        fn=process_document,
        inputs=[file_input],
        outputs=[status]
    )

    ask_btn.click(
        fn=generate_answer,
        inputs=[question_input],
        outputs=[answer_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)