Nick-Dev committed
Commit 9a74106 · verified · 1 parent: fd74278

Upload handler.py


This is the handler.py file

Files changed (1): handler.py (+72, -0)
handler.py (added):
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class ModelHandler:
    """Minimal chat-completion handler around a causal language model."""

    def __init__(self):
        self.initialized = False

    def initialize(self, model_dir: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Half precision on GPU, full precision on CPU
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.eval()
        if torch.cuda.is_available():
            self.model.to("cuda")
        self.initialized = True

    def predict(self, inputs: dict):
        if not self.initialized:
            raise RuntimeError("Model not initialized")

        messages = inputs.get("messages", [])
        max_tokens = inputs.get("max_tokens", 512)
        temperature = inputs.get("temperature", 0.7)

        # Convert OpenAI-style messages into a single prompt
        prompt = self._build_prompt(messages)

        # Tokenize; keep the attention mask so generate() does not have to
        # infer it when pad_token_id == eos_token_id
        encoded = self.tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            encoded = encoded.to("cuda")

        # Generate
        output_ids = self.model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Return just the newly generated portion. Slice at the token level
        # rather than stripping len(prompt) characters from the decoded
        # string, since detokenization does not always reproduce the prompt
        # text exactly.
        new_tokens = output_ids[0][encoded.input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        return {
            "id": "chatcmpl-fakeid",
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "model": "your-model-id",
        }

    def _build_prompt(self, messages):
        prompt = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                # Fold system messages into the prompt as well
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant:"
        return prompt
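
For reference, a minimal local smoke test of how this handler might be driven. It is a sketch, not part of the commit: the "gpt2" model directory and the request payload are assumptions standing in for whatever the deployment actually uses.

# Hypothetical usage sketch for ModelHandler (not in the commit).
from handler import ModelHandler

handler = ModelHandler()
handler.initialize("gpt2")  # any local or Hub model directory works here

result = handler.predict({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
})
print(result["choices"][0]["message"]["content"])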