# SQuAD_Agent_Experiment / semscore.py
# Uploaded by vonliechti ("Upload folder using huggingface_hub"), commit dd5fe55.
from transformers import AutoTokenizer, AutoModel
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
import torch, gc
import torch.nn as nn
class EmbeddingModelWrapper():
    """Embed sentences with a Hugging Face encoder and compare them by cosine similarity.

    A sentence embedding is the attention-masked mean of the model's first
    output (its token embeddings); see ``emb_mean_pooling``.
    """

    DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

    def __init__(self, model_path=None, bs=8, device=None):
        """Load the embedding model and tokenizer.

        Args:
            model_path: Hub id or local path given to ``from_pretrained``;
                ``None`` uses ``DEFAULT_MODEL``.
            bs: Sentences per forward pass in ``get_embeddings``; ``None``
                embeds all sentences in a single batch.
            device: Torch device for the model and its inputs. ``None`` picks
                ``"mps"`` if available, else ``"cuda"``, else ``"cpu"``.
        """
        if model_path is None:
            model_path = self.DEFAULT_MODEL
        # Resolve the device first: load_model moves the model onto it.
        self.device = self._resolve_device(device)
        self.model, self.tokenizer = self.load_model(model_path)
        self.bs = bs
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    @staticmethod
    def _resolve_device(device=None):
        """Return ``device`` unchanged, or the best available backend when it is ``None``."""
        if device is not None:
            return device
        # MPS is tried first so Apple-Silicon machines keep their previous default.
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def load_model(self, model_path):
        """Load model + tokenizer from ``model_path``; the model is put on ``self.device`` in eval mode."""
        model = AutoModel.from_pretrained(
            model_path,
        ).to(self.device)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
        )
        return model, tokenizer

    def emb_mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is non-zero.

        ``model_output[0]`` is taken as the token embeddings — presumably
        (batch, seq, hidden), matching ``attention_mask`` of (batch, seq).
        The denominator is clamped so fully-masked rows do not divide by zero.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def get_embeddings(self, sentences):
        """Embed ``sentences`` in batches of ``self.bs`` and return one row per sentence.

        Returns an empty 1-D tensor on ``self.device`` when ``sentences`` is empty.
        """
        if len(sentences) == 0:
            return torch.tensor([], device=self.device)
        if self.bs is None:
            batches = [sentences]
        else:
            batches = [sentences[i:i + self.bs] for i in range(0, len(sentences), self.bs)]
        chunks = []
        for batch in batches:
            encoded_input = self.tokenizer(
                batch, padding=True, truncation=True, return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                model_output = self.model(**encoded_input)
            chunks.append(self.emb_mean_pooling(model_output, encoded_input['attention_mask']))
        # Concatenate once instead of growing a tensor inside the loop (quadratic copying).
        return torch.cat(chunks, dim=0)

    def get_similarities(self, x, y=None):
        """Cosine similarities between embeddings.

        With ``y``: the row-wise similarity ``cos(x[i], y[i])`` as a list
        (x and y presumably share shape (n, d) — confirm at the call site).
        Without ``y``: an n x n nested list whose entry ``[r][c]`` for
        ``c <= r`` is ``cos(x[r], x[c])``; entries above the diagonal stay 0.
        """
        if y is not None:
            return self.cos(x, y).tolist()
        num_samples = x.shape[0]
        similarities = [[0] * num_samples for _ in range(num_samples)]
        for row in tqdm(range(num_samples)):
            # Lower triangle including the diagonal: compare x[row] with rows 0..row.
            similarities[row][0:row + 1] = self.cos(x[row].repeat(row + 1, 1), x[0:row + 1]).tolist()
        return similarities
class ModelPredictionGenerator:
    """Generate a model answer for every user turn of a chat-formatted eval set.

    Each user turn that is followed by another message becomes one prompt
    (the chat history up to that turn plus a generation prompt). The message
    that follows it is stored as ``answer_ref``. ``run()`` adds the model's
    answer as ``answer_pred`` and the accelerate process index as ``GPU``.
    """

    def __init__(self, model, tokenizer, eval_dataset, use_accelerate=False, bs=8, generation_config=None):
        """
        Args:
            model: Model exposing ``generate``, ``eval`` and ``device``.
            tokenizer: Tokenizer with an eos token and a chat template; its
                ``pad_token_id`` is set to ``eos_token_id`` when missing.
            eval_dataset: Indexable by ``"messages"``, giving conversations as
                lists of ``{"role", "content"}`` dicts.
            use_accelerate: Split prompts across accelerate processes and
                gather the results.
            bs: Prompts per generation batch.
            generation_config: kwargs for ``model.generate``; ``None`` uses a
                "llama-precise" sampling preset.
        """
        # Validate before messages_to_prompts, which relies on the chat template.
        assert tokenizer.eos_token_id is not None
        assert tokenizer.chat_template is not None
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        self.model = model
        self.tokenizer = tokenizer
        self.bs = bs
        self.eval_prompts = self.messages_to_prompts(eval_dataset)
        self.use_accelerate = use_accelerate
        self.accelerator = Accelerator()
        # llama-precise
        if generation_config is None:
            self.generation_config = {
                "temperature": 0.7,
                "top_p": 0.1,
                "repetition_penalty": 1.18,
                "top_k": 40,
                "do_sample": True,
                "max_new_tokens": 100,
                "pad_token_id": tokenizer.pad_token_id
            }
        else:
            self.generation_config = generation_config

    def clear_cache(self):
        """Run the garbage collector, then release cached allocator memory on the active backend."""
        # Collect first so tensors kept alive only by garbage cycles are freed
        # before the allocator cache is flushed.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

    def messages_to_prompts(self, ds):
        """Build ``{"prompt", "answer_ref"}`` dicts, one per answerable user turn.

        A user turn that ends its conversation has no reference answer and is skipped.
        """
        prompts = []
        for conversation in ds["messages"]:
            for i, msg in enumerate(conversation):
                if msg["role"] != "user":
                    continue
                if i + 1 >= len(conversation):
                    # Nothing to compare a prediction against.
                    continue
                prompts.append(
                    dict(
                        # History up to and including this user turn, plus a generation prompt.
                        prompt=self.tokenizer.apply_chat_template(
                            conversation[:i + 1], add_generation_prompt=True, tokenize=False
                        ),
                        answer_ref=conversation[i + 1]["content"],
                    )
                )
        return prompts

    def get_batches(self, dataset, batch_size):
        """Split ``dataset`` into consecutive slices of at most ``batch_size`` items."""
        return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

    def tokenize_batch(self, batch):
        """Left-pad and tokenize the batch's prompts; the tensors are moved to ``self.model.device``.

        The tokenizer's original ``padding_side`` is restored even if tokenization fails.
        """
        original_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"  # left pad for inference
        try:
            prompts_tok = self.tokenizer(
                [item["prompt"] for item in batch],
                return_tensors="pt",
                padding='longest',
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_length=True,
                pad_to_multiple_of=8,
                add_special_tokens=False
            ).to(self.model.device)
        finally:
            self.tokenizer.padding_side = original_side
        return prompts_tok

    def generate_batch(self, batch_tok):
        """Generate for a tokenized batch and return the decoded new text of each row.

        Assumes a decoder-only model, whose ``generate`` output starts with
        the left-padded prompt. The first ``input_ids.shape[1]`` tokens of
        every row are therefore the prompt. Slicing there (rather than at
        ``batch_tok["length"]`` after removing pad ids) keeps the start of the
        answer intact. That old approach cut answer tokens when ``length`` is
        the padded length, or when pad == eos and eos appears in the prompt.
        """
        input_ids = batch_tok["input_ids"]
        with torch.no_grad():
            outputs_tok = self.model.generate(
                input_ids=input_ids,
                attention_mask=batch_tok["attention_mask"],
                **self.generation_config
            ).to("cpu")
        prompt_len = input_ids.shape[1]
        pad_id = self.tokenizer.pad_token_id
        outputs = []
        for sequence in outputs_tok:
            generated = sequence[prompt_len:]
            # Drop trailing pads of rows that finished before the longest one.
            generated = generated[generated != pad_id]
            outputs.append(
                self.tokenizer.decode(
                    generated,
                    spaces_between_special_tokens=False,
                    skip_special_tokens=True
                ).strip()
            )
        return outputs

    def run(self):
        """Generate answers for this process's share of prompts.

        The prompt dicts are updated in place with ``answer_pred`` and
        ``GPU``. With accelerate, returns ``gather_object`` of the local
        results; otherwise returns ``self.eval_prompts``.
        """
        self.model.eval()
        self.clear_cache()
        if self.use_accelerate:
            all_idcs = list(range(len(self.eval_prompts)))
            with self.accelerator.split_between_processes(all_idcs) as eval_prompts_local_idcs:
                eval_prompts_local = [self.eval_prompts[i] for i in eval_prompts_local_idcs]
        else:
            eval_prompts_local = self.eval_prompts
        process_index = self.accelerator.process_index
        for batch in tqdm(self.get_batches(eval_prompts_local, self.bs)):
            batch_tok = self.tokenize_batch(batch)
            answers = self.generate_batch(batch_tok)
            for item, answer in zip(batch, answers):
                item["answer_pred"] = answer
                item["GPU"] = process_index
        if self.use_accelerate:
            return gather_object(eval_prompts_local)
        return eval_prompts_local