File size: 6,460 Bytes
dd5fe55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from transformers import AutoTokenizer, AutoModel
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
import torch, gc
import torch.nn as nn
class EmbeddingModelWrapper:
    """Sentence-embedding helper: mean-pooled transformer embeddings + cosine similarity.

    Loads a Hugging Face ``AutoModel``/``AutoTokenizer`` pair, embeds sentences in
    mini-batches (``get_embeddings``) and compares embeddings with
    ``nn.CosineSimilarity`` (``get_similarities``).
    """

    DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

    def __init__(self, model_path=None, bs=8, device="mps"):
        """Load the embedding model and tokenizer.

        Args:
            model_path: Hugging Face model id or local path; ``DEFAULT_MODEL`` when None.
            bs: Sentences per forward pass in ``get_embeddings``; None embeds all
                sentences in a single batch.
            device: Torch device for the model and its inputs. Defaults to ``"mps"``,
                the device that was previously hard-coded.
        """
        if model_path is None:
            model_path = self.DEFAULT_MODEL
        # Must be set before load_model(), which places the model on it.
        self.device = device
        self.model, self.tokenizer = self.load_model(model_path)
        self.bs = bs
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def load_model(self, model_path):
        """Return ``(model, tokenizer)`` for *model_path*; the model is on ``self.device`` in eval mode."""
        model = AutoModel.from_pretrained(model_path).to(self.device)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def emb_mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions where *attention_mask* is non-zero.

        Args:
            model_output: Model forward output; element 0 is used as the token
                embeddings (assumed ``(batch, seq, hidden)``).
            attention_mask: Mask with one entry per token, 0 marking padding
                (assumed ``(batch, seq)``).

        Returns:
            Mean-pooled embeddings, one row per input sequence.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # clamp guards against division by zero for a row with no real tokens
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def get_embeddings(self, sentences):
        """Embed *sentences*, ``self.bs`` at a time, with mean pooling.

        Returns:
            Tensor with one embedding row per sentence. For an empty input (with a
            batch size set) an empty 1-D tensor is returned, as before.
        """
        if self.bs is None:
            batches = [sentences]
        else:
            batches = [sentences[i:i + self.bs] for i in range(0, len(sentences), self.bs)]
        batch_embeddings = []
        for batch in batches:
            encoded_input = self.tokenizer(
                batch, padding=True, truncation=True, return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                model_output = self.model(**encoded_input)
            batch_embeddings.append(
                self.emb_mean_pooling(model_output, encoded_input['attention_mask'])
            )
        if not batch_embeddings:
            return torch.tensor([], device=self.device)
        # Concatenate once instead of re-copying a growing accumulator per batch.
        return torch.cat(batch_embeddings, dim=0)

    def get_similarities(self, x, y=None):
        """Cosine similarities between embeddings.

        With *y*: row-wise similarity ``cos(x[i], y[i])`` as a list of floats.
        Without *y*: an ``n x n`` nested list comparing *x* with itself, where only
        the lower triangle including the diagonal is computed; entries above the
        diagonal are left as 0.
        """
        if y is not None:
            return self.cos(x, y).tolist()
        num_samples = x.shape[0]
        similarities = [[0] * num_samples for _ in range(num_samples)]
        for row in tqdm(range(num_samples)):
            # Compare x[row] against x[0..row]; expand() broadcasts without copying.
            similarities[row][:row + 1] = self.cos(
                x[row].expand(row + 1, -1), x[:row + 1]
            ).tolist()
        return similarities
class ModelPredictionGenerator:
    """Generate model answers for every user turn of a chat dataset.

    Each user message in ``eval_dataset["messages"]`` that is followed by another
    message becomes one prompt: the conversation up to and including that user
    message, rendered with the tokenizer's chat template, paired with the next
    message's content as the reference answer. ``run()`` generates predictions in
    batches, optionally splitting the prompts across processes with Accelerate.
    """

    def __init__(self, model, tokenizer, eval_dataset, use_accelerate=False, bs=8, generation_config=None):
        """
        Args:
            model: Model exposing ``generate``, ``eval`` and ``device``.
            tokenizer: Tokenizer with an EOS token id and a chat template; its
                ``pad_token_id`` is set to ``eos_token_id`` when missing.
            eval_dataset: Object indexable by ``"messages"`` yielding conversations
                (lists of dicts with ``"role"`` and ``"content"``).
            use_accelerate: Split prompts across processes and gather the results.
            bs: Prompts per generation batch.
            generation_config: Keyword arguments for ``model.generate``; defaults to
                the "llama-precise" sampling preset below.

        Raises:
            ValueError: If the tokenizer has no EOS token id or no chat template.
        """
        # Validate before messages_to_prompts(), which renders the chat template.
        if tokenizer.eos_token_id is None:
            raise ValueError("tokenizer.eos_token_id must be set")
        if tokenizer.chat_template is None:
            raise ValueError("tokenizer.chat_template must be set")
        self.model = model
        self.tokenizer = tokenizer
        self.bs = bs
        self.eval_prompts = self.messages_to_prompts(eval_dataset)
        self.use_accelerate = use_accelerate
        self.accelerator = Accelerator()
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        # llama-precise
        if generation_config is None:
            self.generation_config = {
                "temperature": 0.7,
                "top_p": 0.1,
                "repetition_penalty": 1.18,
                "top_k": 40,
                "do_sample": True,
                "max_new_tokens": 100,
                "pad_token_id": tokenizer.pad_token_id,
            }
        else:
            self.generation_config = generation_config

    def clear_cache(self):
        """Run the garbage collector, then release cached CUDA/MPS allocator memory if available."""
        # Collect first so memory freed by GC can be handed back by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

    def messages_to_prompts(self, ds):
        """Turn each answered user turn in ``ds["messages"]`` into a prompt/reference pair.

        Returns:
            List of dicts with ``prompt`` (chat-template rendering of the
            conversation up to and including the user message, plus a generation
            prompt) and ``answer_ref`` (content of the following message). A user
            message that ends its conversation has no reference answer and is
            skipped instead of raising IndexError.
        """
        prompts = []
        for conversation in ds["messages"]:
            for i, msg in enumerate(conversation):
                if msg["role"] != "user" or i + 1 >= len(conversation):
                    continue
                prompts.append(
                    dict(
                        # format messages up to the current user message and add a generation prompt
                        prompt=self.tokenizer.apply_chat_template(
                            conversation[:i + 1], add_generation_prompt=True, tokenize=False
                        ),
                        answer_ref=conversation[i + 1]["content"],
                    )
                )
        return prompts

    def get_batches(self, dataset, batch_size):
        """Split *dataset* into consecutive slices of at most *batch_size* items."""
        return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

    def tokenize_batch(self, batch):
        """Tokenize the ``prompt`` of each item, left-padded for generation.

        Returns:
            The tokenizer's batch encoding (``input_ids``, ``attention_mask``,
            ``length``) moved to ``self.model.device``.
        """
        pad_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"  # left pad for inference
        try:
            prompts = [item["prompt"] for item in batch]
            prompts_tok = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding='longest',
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_length=True,
                pad_to_multiple_of=8,
                # presumably because the chat template already inserts special tokens
                add_special_tokens=False,
            ).to(self.model.device)
        finally:
            # restore the original padding side even if tokenization fails
            self.tokenizer.padding_side = pad_side
        return prompts_tok

    @staticmethod
    def _completion_ids(outputs_tok, prompt_len, pad_token_id):
        """Return, per output row, the generated token ids with prompt and pad ids removed.

        With left padding, every row's (padded) prompt occupies exactly the first
        *prompt_len* columns, so slicing there isolates the new tokens. The old
        approach (drop all pad ids, then skip ``length[i]`` tokens) lost answer
        tokens whenever ``length`` counted padding or the prompt itself contained
        the pad id (e.g. pad falling back to EOS with EOS between chat turns).
        """
        completions = []
        for row in outputs_tok:
            new_tokens = row[prompt_len:]
            completions.append(new_tokens[new_tokens != pad_token_id])
        return completions

    def generate_batch(self, batch_tok):
        """Generate for a tokenized batch and return the decoded, stripped completions.

        Assumes ``model.generate`` returns the input ids followed by the new
        tokens (as the original prompt-cutting logic also assumed).
        """
        prompt_len = batch_tok["input_ids"].shape[1]
        with torch.no_grad():
            outputs_tok = self.model.generate(
                input_ids=batch_tok["input_ids"],
                attention_mask=batch_tok["attention_mask"],
                **self.generation_config,
            ).to("cpu")
        return [
            self.tokenizer.decode(
                ids,
                spaces_between_special_tokens=False,
                skip_special_tokens=True,
            ).strip()
            for ids in self._completion_ids(outputs_tok, prompt_len, self.tokenizer.pad_token_id)
        ]

    def run(self):
        """Generate predictions for all prompts.

        Each prompt dict gains ``answer_pred`` (decoded completion) and ``GPU``
        (the Accelerate process index). With ``use_accelerate`` the prompts are
        split across processes and the per-process lists are gathered.

        Returns:
            The list of prompt dicts (gathered across processes when accelerated).
        """
        self.model.eval()
        self.clear_cache()
        if self.use_accelerate:
            with self.accelerator.split_between_processes(
                list(range(len(self.eval_prompts)))
            ) as eval_prompts_local_idcs:
                eval_prompts_local = [self.eval_prompts[i] for i in eval_prompts_local_idcs]
        else:
            eval_prompts_local = self.eval_prompts
        for batch in tqdm(self.get_batches(eval_prompts_local, self.bs)):
            batch_tok = self.tokenize_batch(batch)
            answers = self.generate_batch(batch_tok)
            for item, answer in zip(batch, answers):
                item["answer_pred"] = answer
                item["GPU"] = self.accelerator.process_index
        if self.use_accelerate:
            return gather_object(eval_prompts_local)
        return eval_prompts_local