jon-tow committed on
Commit
ab740b3
1 Parent(s): 3708810

fix: load 8-bit model

Browse files
Files changed (2) hide show
  1. app.py +4 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -13,16 +13,14 @@ tokenizer = AutoTokenizer.from_pretrained(
13
  use_auth_token=auth_token if auth_token else True,
14
  )
15
  model = AutoModelForCausalLM.from_pretrained(
16
- "CarperAI/vicuna-13b-fine-tuned-rlhf-fp16",
17
- torch_dtype=torch.float16,
18
- device_map="auto",
19
- offload_folder="./offload",
20
- low_cpu_mem_usage=True, # Not required for demo but leave for now
21
  use_auth_token=auth_token if auth_token else True,
22
  )
23
  model.cuda()
 
 
24
  max_context_length = model.config.max_position_embeddings
25
- max_new_tokens = 500
26
 
27
 
28
  prompt_template = Template("""\
 
13
  use_auth_token=auth_token if auth_token else True,
14
  )
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ "CarperAI/vicuna-13b-fine-tuned-rlhf-8bit",
 
 
 
 
17
  use_auth_token=auth_token if auth_token else True,
18
  )
19
  model.cuda()
20
+
21
+
22
  max_context_length = model.config.max_position_embeddings
23
+ max_new_tokens = 512
24
 
25
 
26
  prompt_template = Template("""\
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  accelerate
2
  torch
 
3
  transformers>=4.28.0,<4.29.0
 
1
  accelerate
2
  torch
3
+ bitsandbytes
4
  transformers>=4.28.0,<4.29.0