Spaces:
Runtime error
Runtime error
File size: 2,184 Bytes
2edd118 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model
from peft import PeftModel, PeftConfig
from .BaseLLM import BaseLLM
import torch
# Module-level cache slots for the shared tokenizer and model. Both start out
# empty (None) and are filled in by initialize_GLM2LORA() on first use.
tokenizer_GLM = model_GLM = None
def initialize_GLM2LORA():
    """Load the ChatGLM2-6B tokenizer and LoRA-wrapped model once, then cache them.

    The loaded objects are stored in the module-level ``tokenizer_GLM`` and
    ``model_GLM`` globals, so repeated calls return the cached instances
    instead of loading the checkpoints again.

    Returns:
        A ``(model, tokenizer)`` tuple -- note that the model comes first.
    """
    global tokenizer_GLM
    global model_GLM

    # Each cache slot is checked on its own: if an earlier call managed to
    # load only one of the two (e.g. the model load raised), the missing one
    # is still loaded here instead of being returned as None.
    if tokenizer_GLM is None:
        tokenizer_GLM = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True
        )

    if model_GLM is None:
        # Half-precision weights moved to the GPU.
        base_model = AutoModel.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True
        ).half().cuda()
        # The global is only assigned once the adapter has been applied, so a
        # failure here never leaves an un-adapted base model in the cache.
        # No LoraConfig is built locally: it was never passed to anything, and
        # PeftModel.from_pretrained takes only the adapter checkpoint name.
        model_GLM = PeftModel.from_pretrained(
            base_model, "silk-road/Chat-Haruhi-Fusion_B"
        )

    return model_GLM, tokenizer_GLM
def GLM_tokenizer(text):
    """Return the length of *text* once encoded by the cached GLM tokenizer.

    Relies on the module-level ``tokenizer_GLM``, which is ``None`` until
    ``initialize_GLM2LORA()`` has populated it.
    """
    encoded = tokenizer_GLM.encode(text)
    return len(encoded)
class ChatGLM2GPT(BaseLLM):
    """BaseLLM backend that answers through a ChatGLM2-6B model.

    Every message, whatever its role, is appended to a single prompt string
    (``self.messages``); ``get_response`` sends that string to the model's
    ``chat`` method.
    """

    def __init__(self, model = "haruhi-fusion"):
        """Load the requested model variant and start with an empty prompt.

        Args:
            model: ``"glm2-6b"`` loads the plain THUDM/chatglm2-6b checkpoint
                for this instance; ``"haruhi-fusion"`` (the default) uses the
                model and tokenizer returned by ``initialize_GLM2LORA()``.

        Raises:
            ValueError: If ``model`` is not one of the two supported names.
        """
        super().__init__()
        if model == "glm2-6b":
            self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
            self.model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
        # Must be `elif`: with a plain `if`, the "glm2-6b" case fell through
        # to the `else` below and raised right after loading the model.
        elif model == "haruhi-fusion":
            self.model, self.tokenizer = initialize_GLM2LORA()
        else:
            raise ValueError("Unknown GLM model")
        self.messages = ""

    def initialize_message(self):
        """Reset the accumulated prompt to an empty string."""
        # Resets `messages` (plural), the attribute every other method uses.
        self.messages = ""

    def ai_message(self, payload):
        """Append an AI-role message to the prompt."""
        self.messages = self.messages + "\n " + payload

    def system_message(self, payload):
        """Append a system-role message to the prompt."""
        self.messages = self.messages + "\n " + payload

    def user_message(self, payload):
        """Append a user-role message to the prompt."""
        self.messages = self.messages + "\n " + payload

    def get_response(self):
        """Send the accumulated prompt to the model and return its reply.

        The call is made with an empty ``history`` every time; the history
        value the model returns is discarded.
        """
        with torch.no_grad():
            response, _ = self.model.chat(self.tokenizer, self.messages, history=[])
        return response

    def print_prompt(self):
        """Print the type of the accumulated prompt, then the prompt itself."""
        print(type(self.messages))
        print(self.messages)
|