vkamra committed
Commit faeb8d3 · verified · 1 Parent(s): d00c4d2

Update handler.py

Files changed (1)
  1. handler.py +32 -19
handler.py CHANGED
@@ -1,37 +1,50 @@
-from typing import Any, Dict, List
-
+from typing import Any, Dict
 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
-dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
-
-
+
+# Set dtype based on device capability
+dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 8 else torch.float16
+
 class EndpointHandler:
     def __init__(self, path="vkamra/llama_finetune_clockit"):
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            path,
-            return_dict=True,
-            device_map="auto",
-            load_in_8bit=True,
-            torch_dtype=dtype,
-            trust_remote_code=True,
-        )
-
+        tokenizer.padding_side = "left"  # For proper padding alignment
+
+        # Load model with fallback for non-8bit environments
+        if torch.cuda.is_available():
+            model = AutoModelForCausalLM.from_pretrained(
+                path,
+                return_dict=True,
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=dtype,
+                trust_remote_code=True,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                path,
+                return_dict=True,
+                torch_dtype=torch.float32,  # Full precision for CPU
+                trust_remote_code=True,
+            )
+
+        # Configure generation settings
         generation_config = model.generation_config
         generation_config.max_new_tokens = 60
-        generation_config.temperature = 0
+        generation_config.temperature = 0.7
         generation_config.num_return_sequences = 1
         generation_config.pad_token_id = tokenizer.eos_token_id
         generation_config.eos_token_id = tokenizer.eos_token_id
         self.generation_config = generation_config
-
+
+        # Initialize pipeline
         self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
         )
-
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         prompt = data.pop("inputs", data)
         result = self.pipeline(prompt, generation_config=self.generation_config)
-        return result
+        return result
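Note on the model-loading branch: load_in_8bit=True relies on the bitsandbytes package and a CUDA GPU, and recent transformers releases prefer passing the same request through a BitsAndBytesConfig. The following sketch is not part of this commit; it only shows what an equivalent 8-bit load could look like under that newer API:

# Equivalent 8-bit load via BitsAndBytesConfig (requires bitsandbytes and CUDA); illustrative only.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "vkamra/llama_finetune_clockit",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    trust_remote_code=True,
)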
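One caveat about the new generation settings: in transformers, temperature only applies when sampling is enabled, so with the default do_sample=False the value 0.7 is effectively ignored (newer versions emit a warning about it). If sampled output is the intent, the config would also need the flag below; this is an observation on the committed code, not part of the change itself, and top_p is shown purely for illustration:

# Enable sampling so that temperature=0.7 actually takes effect (not in this commit).
generation_config.do_sample = True
generation_config.temperature = 0.7
generation_config.top_p = 0.9  # hypothetical nucleus-sampling cap, for illustration only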
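For reference, EndpointHandler follows the Hugging Face Inference Endpoints custom-handler convention: the JSON request body arrives as a dict and the prompt is read from its "inputs" key. A minimal local smoke test, assuming the file is importable as handler.py and using an illustrative prompt, might look like this:

# Minimal local smoke test for the custom handler (prompt string is only an example).
from handler import EndpointHandler

handler = EndpointHandler(path="vkamra/llama_finetune_clockit")

# Inference Endpoints deliver the JSON body as a dict; "inputs" carries the prompt text.
payload = {"inputs": "Schedule a 30 minute meeting with the design team tomorrow at 10am."}
result = handler(payload)
print(result)  # text-generation pipeline output, typically a list of {"generated_text": ...} dicts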