zake7749 commited on
Commit
135440a
·
1 Parent(s): 59d11c0

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +64 -0
README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - zh
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - llama2
7
+ ---
8
+ This repository introduces a 4-bit quantized version of the [yayi-7b-llama2 model](https://huggingface.co/wenge-research/yayi-7b-llama2) proposed by [wenge-research](https://www.wenge.com/). The quantization process was performed using the [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).
9
+
10
+ ## Usage Example
11
+
12
+ ```python
13
+ import torch
14
+ from auto_gptq import AutoGPTQForCausalLM
15
+ from transformers import LlamaTokenizer, GenerationConfig
16
+ from transformers import StoppingCriteria, StoppingCriteriaList
17
+
18
+ pretrained_model_name_or_path = "zake7749/yayi-7b-llama2-4bit-autogptq"
19
+ tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path)
20
+ model = AutoGPTQForCausalLM.from_quantized(pretrained_model_name_or_path)
21
+
22
+ # Define the stopping criteria
23
+ class KeywordsStoppingCriteria(StoppingCriteria):
24
+ def __init__(self, keywords_ids:list):
25
+ self.keywords = keywords_ids
26
+
27
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
28
+ if input_ids[0][-1] in self.keywords:
29
+ return True
30
+ return False
31
+
32
+ stop_words = ["<|End|>", "<|YaYi|>", "<|Human|>", "</s>"]
33
+ stop_ids = [tokenizer.encode(w)[-1] for w in stop_words]
34
+ stop_criteria = KeywordsStoppingCriteria(stop_ids)
35
+
36
+ # inference
37
+ prompt = "你是谁?"
38
+ formatted_prompt = f"""<|System|>:
39
+ You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
40
+
41
+ <|Human|>:
42
+ {prompt}
43
+
44
+ <|YaYi|>:
45
+ """
46
+
47
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
48
+ eos_token_id = tokenizer("<|End|>").input_ids[0]
49
+ generation_config = GenerationConfig(
50
+ eos_token_id=eos_token_id,
51
+ pad_token_id=eos_token_id,
52
+ do_sample=True,
53
+ max_new_tokens=256,
54
+ temperature=0.3,
55
+ repetition_penalty=1.1,
56
+ no_repeat_ngram_size=0
57
+ )
58
+ response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
59
+ response = [response[0][len(inputs.input_ids[0]):]]
60
+ response_str = tokenizer.batch_decode(response, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
61
+ print(response_str)
62
+ ```
63
+
64
+ ## [License](https://github.com/wenge-research/YaYi/blob/main/LICENSE_MODEL)