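"""ChatGLM2-6B wrapper, optionally fused with the "silk-road/Chat-Haruhi-Fusion_B"
LoRA adapter, exposing the BaseLLM interface used by the rest of the package."""
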
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from .BaseLLM import BaseLLM
import torch

tokenizer_GLM = None
model_GLM = None

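# Load the tokenizer, base model, and LoRA adapter once and cache them in the
# module-level globals above; later calls return the cached objects.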
def initialize_GLM2LORA():
    global tokenizer_GLM
    global model_GLM

    if tokenizer_GLM is None and model_GLM is None:
        tokenizer_GLM = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
        model_GLM = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()

        # The LoRA adapter and its saved LoraConfig are loaded directly from
        # the Hub; eval() switches the model to inference mode.
        model_GLM = PeftModel.from_pretrained(model_GLM, "silk-road/Chat-Haruhi-Fusion_B")
        model_GLM = model_GLM.eval()
    return model_GLM, tokenizer_GLM

def GLM_tokenizer(text):
    # Count tokens with the ChatGLM2 tokenizer; initialize_GLM2LORA() must be
    # called first so that tokenizer_GLM is populated.
    return len(tokenizer_GLM.encode(text))

class ChatGLM2GPT(BaseLLM):
    def __init__(self, model="haruhi-fusion"):
        super(ChatGLM2GPT, self).__init__()
        if model == "glm2-6b":
            self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
            self.model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
        elif model == "haruhi-fusion":
            self.model, self.tokenizer = initialize_GLM2LORA()
        else:
            raise Exception("Unknown GLM model")
        self.messages = ""

    def initialize_message(self):
        self.messages = ""

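    # ChatGLM2 has no native role separation: AI, system, and user messages
    # are all flattened into one plain-text prompt string.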
    def ai_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def system_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def user_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def get_response(self):
        # ChatGLM2's chat() takes a single query string, so the accumulated
        # prompt is sent as one query with an empty history on every call.
        with torch.no_grad():
            response, _ = self.model.chat(self.tokenizer, self.messages, history=[])
        return response

    def print_prompt(self):
        print(self.messages)
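
# Usage sketch (illustrative, not from the source): assuming this file is the
# package module ChatGLM2GPT and a CUDA device is available:
#
#     from .ChatGLM2GPT import ChatGLM2GPT
#     llm = ChatGLM2GPT(model="haruhi-fusion")  # downloads ChatGLM2-6B + LoRA
#     llm.initialize_message()
#     llm.user_message("Hello!")
#     print(llm.get_response())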