ericnunes1 commited on
Commit
0d6bfde
·
verified ·
1 Parent(s): c0ed359

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +79 -0
handler.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
+ from peft import PeftModel
3
+ import torch
4
+ import os
5
+ import traceback
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, path=""):
9
+ base_model_id = "microsoft/Phi-4-mini-instruct"
10
+ adapter_path = path # Diretório local no container onde o repo foi baixado
11
+
12
+ try:
13
+ print(f"Iniciando Handler: Carregando modelo base {base_model_id}")
14
+ # Carregar em bfloat16 ou float16 se disponível
15
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
16
+ self.base_model = AutoModelForCausalLM.from_pretrained(
17
+ base_model_id,
18
+ torch_dtype=dtype,
19
+ trust_remote_code=True
20
+ # device_map é gerenciado pelo endpoint
21
+ )
22
+
23
+ print(f"Carregando tokenizer de {base_model_id}")
24
+ self.tokenizer = AutoTokenizer.from_pretrained(
25
+ base_model_id,
26
+ trust_remote_code=True
27
+ )
28
+ if self.tokenizer.pad_token is None or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
29
+ self.tokenizer.pad_token = self.tokenizer.unk_token
30
+ print("Definido tokenizer.pad_token = tokenizer.unk_token")
31
+
32
+ print(f"Carregando adaptador LoRA de {adapter_path}")
33
+ self.model = PeftModel.from_pretrained(self.base_model, adapter_path)
34
+ self.model.eval()
35
+ print("Adaptador LoRA carregado.")
36
+
37
+ self.pipeline = pipeline(
38
+ "text-generation",
39
+ model=self.model,
40
+ tokenizer=self.tokenizer,
41
+ # device=0 # Geralmente não necessário, endpoint gerencia
42
+ )
43
+ print("Pipeline de text-generation criado. Handler pronto.")
44
+
45
+ except Exception as e:
46
+ print(f"ERRO FATAL durante __init__ do Handler: {e}")
47
+ print(traceback.format_exc())
48
+ raise e # Levanta o erro para falhar a inicialização do endpoint
49
+
50
+
51
+ def __call__(self, data):
52
+ try:
53
+ inputs = data.pop("inputs", data)
54
+ parameters = data.pop("parameters", None) or {}
55
+
56
+ print(f"Handler __call__ recebeu inputs: {inputs}")
57
+ print(f"Handler __call__ recebeu parâmetros: {parameters}")
58
+
59
+ # Preparar o prompt - Detecta se input é lista de dicts (chat) ou string
60
+ prompt_text = inputs
61
+ if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict) and 'role' in inputs[0]:
62
+ print("Aplicando chat template...")
63
+ # Cuidado: add_generation_prompt=True é para gerar RESPOSTA do assistant
64
+ prompt_text = self.tokenizer.apply_chat_template(inputs, tokenize=False, add_generation_prompt=True)
65
+
66
+ print(f"Texto do prompt para o pipeline: {prompt_text}")
67
+
68
+ # Gerar texto usando o pipeline
69
+ outputs = self.pipeline(prompt_text, **parameters)
70
+
71
+ print(f"Handler __call__ gerou outputs: {outputs}")
72
+ # Retorna a saída (geralmente uma lista de dicionários)
73
+ return outputs
74
+
75
+ except Exception as e:
76
+ print(f"ERRO durante __call__ do Handler: {e}")
77
+ print(traceback.format_exc())
78
+ # Retornar erro de forma estruturada ajuda na depuração
79
+ return [{"error": str(e), "traceback": traceback.format_exc()}]