from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn import CrossEntropyLoss

from .templates import make_prompt, check_template_name, print_templates, supported_templates
from .config import model_dict

LOSS_FCT = CrossEntropyLoss(reduction='mean')


class ppluie:
    def __init__(
        self,
        model,
        device="cuda:0",
        template="FS-DIRECT",
        use_chat_template=True,
        half_mode=True,
        n_right_specials_tokens=1
    ):
        self.device = device
        self.use_chat_tmplt = use_chat_template
        # assert model in supported_models, model+" not supported, supported models: "+str(supported_models)
        if model not in model_dict.keys():
            print("You called ParaPLUIE with " + model + ".\nParaPLUIE has been tested with " + str(list(model_dict.keys())) + "\nUsing it with another model could lead to unexpected behaviour.")
            self.model = AutoModelForCausalLM.from_pretrained(model)
            self.tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
            self.n_right_special_tokens = n_right_specials_tokens
        # otherwise, use the per-model configuration from config.model_dict
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code=model_dict[model]["trust_remote_code"]
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code=model_dict[model]["trust_remote_code"],
                padding_side='left'
            )
            self.n_right_special_tokens = model_dict[model]["r_spe_tokens"]
            self.use_chat_tmplt = model_dict[model]["use_chat_tmplt"]

        if half_mode:
            self.model = self.model.half()
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

        self.setTemplate(template)

    def show_templates(self):
        print_templates()

    def show_available_models(self):
        print("LLMs tested with ParaPLUIE:")
        for k in model_dict.keys():
            print(k)

    def setTemplate(self, template: str):
        check_template_name(template)
        self.template = template

    def stringify_prompt(self, prompt):
        stringify_prompt = ""
        for v in prompt[:-1]:
            stringify_prompt += v["content"] + " "
        stringify_prompt += prompt[-1]["content"]
        return stringify_prompt

    def vraisemblance(self, promptY, promptN):
        # optimisation: a single forward pass scores both answers
        # TODO: model check not done yet (guillaume)
        if self.use_chat_tmplt:
            input_model = self.tokenizer.apply_chat_template(promptY, return_tensors="pt", padding=False).to(self.device)[:, :-self.n_right_special_tokens]
            encodedsY = self.tokenizer.apply_chat_template(promptY, return_tensors="pt", padding=False).to(self.device)
            encodedsN = self.tokenizer.apply_chat_template(promptN, return_tensors="pt", padding=False).to(self.device)
            # drop the end-of-sentence special token(s)
            encodedsY = encodedsY[:, 1:-self.n_right_special_tokens]
            encodedsN = encodedsN[:, 1:-self.n_right_special_tokens]
        else:
            input_model = self.tokenizer(self.stringify_prompt(promptY), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            encodedsY = self.tokenizer(self.stringify_prompt(promptY), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            encodedsN = self.tokenizer(self.stringify_prompt(promptN), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            # drop the end-of-sentence special token(s)
            encodedsY = encodedsY[:, 1:]
            encodedsN = encodedsN[:, 1:]

        generate_ids = self.model(
            input_ids=input_model[:, :-1],
            return_dict=True
        )
        n_tokens = len(input_model[0])
        generate_ids = generate_ids["logits"].squeeze().float()

        loss_yes = LOSS_FCT(generate_ids, encodedsY.view(-1)) * n_tokens
        loss_no = LOSS_FCT(generate_ids, encodedsN.view(-1)) * n_tokens
        loss = loss_no.item() - loss_yes.item()
        return loss

    def chech_end_tokens_tmpl(self):
        prompt_yes, prompt_no = make_prompt(
            self.template,
            "this is a test",
            "this is a test",
            self.model,      # if intermediate generation is needed
            self.tokenizer,  # if intermediate generation is needed
            self.device      # if intermediate generation is needed
        )
        if self.use_chat_tmplt:
            enc = self.tokenizer.apply_chat_template(prompt_yes, return_tensors="pt", padding=False)[0][-10:]
        else:
            enc = self.tokenizer(self.stringify_prompt(prompt_yes), return_tensors="pt", padding=False)["input_ids"][0][-10:]
        print("Yes prompt:")
        print(enc)
        for t in enc:
            print(t, " - ", self.tokenizer.decode(t))

        if self.use_chat_tmplt:
            enc = self.tokenizer.apply_chat_template(prompt_no, return_tensors="pt", padding=False)[0][-10:]
        else:
            enc = self.tokenizer(self.stringify_prompt(prompt_no), return_tensors="pt", padding=False)["input_ids"][0][-10:]
        print("No prompt:")
        print(enc)
        for t in enc:
            print(t, " - ", self.tokenizer.decode(t))

    def __call__(self, reference, hypothese, logger=None):
        prompt_yes, prompt_no = make_prompt(
            self.template,
            reference,
            hypothese,
            self.model,      # if intermediate generation is needed
            self.tokenizer,  # if intermediate generation is needed
            self.device      # if intermediate generation is needed
        )
        score = self.vraisemblance(promptY=prompt_yes, promptN=prompt_no)
        # self.chech_end_tokens_tmpl()

        if logger:
            logger.info('Start entry')
            logger.info('Template: ' + self.template)
            logger.info('Yes prompt: ' + str(prompt_yes))
            logger.info('No prompt: ' + str(prompt_no))
            # vraisemblance() returns the loss difference between the No and Yes
            # prompts, so only the final score is logged here.
            logger.info('Score: ' + str(score))
            logger.info('End entry')
        return score

    def get_all_templates(self):
        return supported_templates