matt-bcny committed on
Commit
e8771ff
·
verified ·
1 Parent(s): 1b8d8e5

Create handler.py

Files changed (1)
  1. handler.py +460 -0
handler.py ADDED
@@ -0,0 +1,460 @@
from typing import Dict, List, Any, Optional, Union
import os
import json
import time
import torch
from threading import Thread
import logging
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    StoppingCriteriaList,
    StoppingCriteria,
    BitsAndBytesConfig
)
from peft import PeftModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("lora_inference.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class ImprovedJSONStoppingCriteria(StoppingCriteria):
    """
    Stopping criteria that ensures JSON is complete before stopping.
    Only stops generation when a valid, complete JSON object is detected.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.generated = ""
        self.json_complete = False

    def __call__(self, input_ids, scores, **kwargs):
        # If we already found complete JSON, stop immediately
        if self.json_complete:
            return True

        # Decode current text
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Skip early if no JSON structure detected
        if '{' not in text:
            return False

        # Don't stop if we don't have at least one closing brace
        if '}' not in text:
            return False

        # Check for complete JSON structure
        try:
            # First, try to find a valid JSON object
            start_pos = text.find('{')

            # Progressively validate from the first opening brace
            stack = []
            end_pos = -1

            for i, char in enumerate(text[start_pos:], start_pos):
                if char == '{':
                    stack.append('{')
                elif char == '}':
                    if stack:
                        stack.pop()
                    if not stack:  # We have balanced braces
                        end_pos = i
                        potential_json = text[start_pos:end_pos + 1]

                        # Make sure this is actually valid JSON
                        # and not just balanced braces
                        try:
                            # Parse JSON to validate
                            parsed = json.loads(potential_json)

                            # We need to make sure we have all required fields
                            # For search_web or tool calls, verify arguments are complete
                            if "calls" in parsed:
                                for call in parsed.get("calls", []):
                                    # If we have a call with arguments, make sure they're complete
                                    if "arguments" in call:
                                        args = call.get("arguments", "")

                                        # If arguments is a string, it might be JSON itself
                                        if isinstance(args, str) and args.startswith("{"):
                                            # If the argument string starts with { but doesn't have a
                                            # closing }, it's incomplete
                                            if not args.endswith("}"):
                                                return False

                                            # Try to parse the arguments as JSON
                                            try:
                                                json.loads(args)
                                            except Exception:
                                                # If we can't parse, the JSON is incomplete
                                                return False

                            # All checks passed - we have valid, complete JSON
                            self.json_complete = True
                            return True
                        except Exception:
                            # Not valid JSON, continue looking
                            continue

            # Only stop with excessive braces if we already have a valid structure
            open_count = text.count('{')
            close_count = text.count('}')
            if close_count > open_count:
                # Check if we have a valid JSON by balancing
                fixed_text = text[start_pos:]
                stack = []
                for i, char in enumerate(fixed_text):
                    if char == '{':
                        stack.append('{')
                    elif char == '}':
                        if stack:
                            stack.pop()
                        if not stack:
                            try:
                                potential_json = fixed_text[:i + 1]
                                parsed = json.loads(potential_json)
                                self.json_complete = True
                                return True
                            except Exception:
                                pass
        except Exception:
            # Error in parsing or validation, don't stop
            pass

        return False

class ExcessBraceStoppingCriteria(StoppingCriteria):
    """Stop generation if we're generating excessive closing braces."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Only trigger if we have JSON content
        if '{' in text and '}' in text:
            # Check if we're generating excessive closing braces
            open_count = text.count('{')
            close_count = text.count('}')

            # If we have more closing than opening braces, stop generation
            if close_count > open_count + 3:  # Allow a small buffer
                return True

        return False

def fix_json_output(text):
    """Fix malformed JSON with excessive closing braces."""
    if '{' not in text or '}' not in text:
        return text

    # Count opening and closing braces
    open_count = text.count('{')
    close_count = text.count('}')

    # If balanced or too few closing braces, return as-is
    if open_count >= close_count:
        return text

    # Track JSON depth to find the valid JSON object
    start_pos = text.find('{')
    depth = 0
    for i, char in enumerate(text[start_pos:], start_pos):
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                # Found balanced JSON, return up to this point
                return text[:i + 1]

    # If we can't balance it with depth tracking, simply truncate
    return text[:start_pos + text[start_pos:].find('}') + 1]
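
# Example of the truncation behaviour above (illustrative):
#   fix_json_output('{"a": 1}}}') -> '{"a": 1}'  (extra closing braces are dropped)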

def create_stopping_criteria(tokenizer, stop_tokens):
    """Create stopping criteria from tokens."""
    stop_token_ids = []
    for stop_token in stop_tokens:
        token_ids = tokenizer.encode(stop_token, add_special_tokens=False)
        if len(token_ids) > 0:
            stop_token_ids.append(token_ids[-1])

    return StoppingCriteriaList([StopOnTokens(tokenizer, stop_token_ids)])

class StopOnTokens(StoppingCriteria):
    """Custom stopping criteria for text generation."""
    def __init__(self, tokenizer, stop_token_ids):
        self.tokenizer = tokenizer
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler by loading the model and tokenizer.

        Args:
            path (str): Path to the model directory (uses environment variable if not provided)
        """
        # Get model path from environment or from argument
        model_path = path if path else os.environ.get("MODEL_PATH", "")
        adapter_path = os.environ.get("ADAPTER_PATH", None)
        logger.info(f"Loading model from {model_path}")

        # Determine quantization settings from environment
        use_8bit = os.environ.get("USE_8BIT", "False").lower() == "true"
        use_4bit = os.environ.get("USE_4BIT", "False").lower() == "true"
        device = os.environ.get("DEVICE", "auto")

        # Load tokenizer
        logger.info(f"Loading tokenizer from {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with appropriate configuration
        if use_4bit:
            logger.info("Using 4-bit quantization for inference...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device,
                low_cpu_mem_usage=True
            )
        elif use_8bit:
            logger.info("Using 8-bit quantization for inference...")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_8bit=True,
                device_map=device,
                low_cpu_mem_usage=True
            )
        else:
            logger.info("Loading model in float16 precision...")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map=device,
                low_cpu_mem_usage=True
            )

        # Apply adapter if specified
        if adapter_path:
            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            self.model = base_model
            logger.info("No adapter path provided, using base model only")

        self.model.eval()

        # Try to use torch.compile for additional performance if available
        if torch.__version__ >= "2.0.0" and os.environ.get("USE_COMPILE", "False").lower() == "true":
            try:
                logger.info("Applying torch.compile for additional optimization...")
                self.model = torch.compile(self.model)
                logger.info("Model successfully compiled!")
            except Exception as e:
                logger.warning(f"Could not compile model: {e}")

        logger.info("Model and tokenizer loaded successfully!")
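
    # Environment variables read by __init__ (summary of the code above):
    #   MODEL_PATH   - base model directory, used when no path argument is given
    #   ADAPTER_PATH - optional LoRA adapter directory
    #   USE_4BIT / USE_8BIT - "true" enables bitsandbytes quantization
    #   DEVICE       - value passed to device_map (defaults to "auto")
    #   USE_COMPILE  - "true" attempts torch.compile on PyTorch >= 2.0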

    def format_conversation(self, messages, add_generation_prompt=True):
        """Format a conversation using the tokenizer's chat template."""
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data (Dict[str, Any]): Request data containing inputs and parameters

        Returns:
            List[Dict[str, Any]]: List of response dictionaries
        """
        start_time = time.time()

        # Extract input data and parameters
        inputs = data.get("inputs", [])
        parameters = data.get("parameters", {})

        # Parse generation parameters with defaults
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", temperature > 0.1)
        stream = parameters.get("stream", False)
        json_mode = parameters.get("json_mode", False)
        system_prompt = parameters.get("system_prompt", None)

        # Check if input is in various formats and normalize to messages format
        if isinstance(inputs, str):
            # Create simple chat with user message
            messages = [{"role": "user", "content": inputs}]
        elif isinstance(inputs, dict) and "messages" in inputs:
            # Input is already in chat format
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            # Assume this is a list of message dicts
            messages = inputs
        else:
            # Invalid input format
            return [{"error": "Invalid input format. Please provide a string, a list of messages, or a dict with 'messages' key."}]
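
        # Accepted "inputs" shapes, per the branches above:
        #   "some prompt"                                -> wrapped as a single user message
        #   {"messages": [{"role": "user", ...}]}        -> messages list used as-is
        #   [{"role": "user", "content": "..."}, ...]    -> treated as a message list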
        # Prepare conversation with system prompt if provided
        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend(messages)

        # Format the conversation
        prompt = self.format_conversation(conversation)

        # Tokenize the prompt
        inputs_dict = self.tokenizer(prompt, return_tensors="pt")
        inputs_dict = {k: v.to(self.model.device) for k, v in inputs_dict.items()}

        # Configure generation parameters
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        # Add JSON-specific settings if needed
        if json_mode:
            stop_tokens = ["\n\n", "\n}", "}\n", "}}", "} }", "}\n]", "}\n{"]
            stopping_criteria = create_stopping_criteria(self.tokenizer, stop_tokens)
            generation_config["stopping_criteria"] = stopping_criteria

            # Lower temperature for JSON mode to get more reliable outputs
            # but don't set to 0 as that might cause truncation issues
            temperature = min(temperature, 0.1)
            do_sample = False
            generation_config["do_sample"] = do_sample
            generation_config["temperature"] = temperature

        # Record input length for proper decoding
        input_length = inputs_dict["input_ids"].shape[1]

        generated_text = ""
        if stream:
            # Use streaming for interactive responses
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_config["streamer"] = streamer

            # Start generation in a thread
            thread = Thread(target=self.model.generate, kwargs={**inputs_dict, **generation_config})
            thread.start()

            # Stream the output (for local testing)
            for text in streamer:
                generated_text += text

            # Apply JSON cleaning if needed and json_mode is enabled
            if json_mode and '{' in generated_text and '}' in generated_text:
                if generated_text.count('}') > generated_text.count('{'):
                    fixed_text = fix_json_output(generated_text)
                    if fixed_text != generated_text:
                        logger.info("Fixed malformed JSON in response")
                        generated_text = fixed_text
        else:
            # Non-streaming generation
            with torch.no_grad():
                outputs = self.model.generate(**inputs_dict, **generation_config)

            # Decode the output
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            # Apply JSON cleaning if needed and json_mode is enabled
            if json_mode and '{' in generated_text and '}' in generated_text:
                if generated_text.count('}') > generated_text.count('{'):
                    fixed_text = fix_json_output(generated_text)
                    if fixed_text != generated_text:
                        logger.info("Fixed malformed JSON in response")
                        generated_text = fixed_text

        # Calculate processing time
        end_time = time.time()
        processing_time = end_time - start_time

        # Create response dictionary
        response = {
            "generated_text": generated_text,
            "processing_time": processing_time
        }

        # Include token counts if requested
        if parameters.get("return_token_count", False):
            response["input_token_count"] = input_length
            # Count output tokens with the tokenizer so this is a true token count,
            # not a whitespace word count
            response["output_token_count"] = len(self.tokenizer.encode(generated_text, add_special_tokens=False))

        return [response]

# For local testing
if __name__ == "__main__":
    # Test the handler
    model_path = os.environ.get("MODEL_PATH", "./model")
    handler = EndpointHandler(model_path)

    # Test with a simple query
    test_data = {
        "inputs": "Explain the concept of machine learning in simple terms.",
        "parameters": {
            "max_new_tokens": 100,
            "temperature": 0.7,
            "system_prompt": "You are a helpful AI assistant."
        }
    }

    response = handler(test_data)
    print("\nTest Response:")
    print(json.dumps(response, indent=2))

    # Test with chat format and JSON mode
    test_chat_data = {
        "inputs": {
            "messages": [
                {"role": "user", "content": "Create a JSON object with information about the solar system. Include at least 3 planets with their name, diameter, and distance from the sun."}
            ]
        },
        "parameters": {
            "max_new_tokens": 512,
            "temperature": 0.1,
            "json_mode": True,
            "system_prompt": "You are a helpful AI assistant that responds in JSON format."
        }
    }

    chat_response = handler(test_chat_data)
    print("\nJSON Format Response:")
    print(json.dumps(chat_response, indent=2))
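
For reference, a minimal client-side sketch of how a deployed endpoint backed by this handler could be called. The endpoint URL and token below are placeholders (not values from this repository); the payload shape mirrors what EndpointHandler.__call__ accepts.

import requests

# Placeholder values - substitute your own Inference Endpoint URL and access token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_xxx"

payload = {
    "inputs": {
        "messages": [
            {"role": "user", "content": "Summarize the benefits of LoRA fine-tuning in two sentences."}
        ]
    },
    "parameters": {
        "max_new_tokens": 128,
        "temperature": 0.7,
        "system_prompt": "You are a helpful AI assistant.",
        "return_token_count": True
    }
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    json=payload,
    timeout=300,
)
print(resp.json())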