"""Perplexity Metric."""

import datasets
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

import evaluate
from evaluate import logging


_CITATION = """\

"""

_DESCRIPTION = """
Perplexity (PPL) can be used to evaluate how similar a dataset is to the distribution of text that a given model was trained on.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.

For more information, see https://huggingface.co/docs/transformers/perplexity
"""
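
# Concretely, for a tokenized sequence x_1 ... x_t the definition above reads
#
#     PPL(X) = exp( -(1/t) * sum_{i=1..t} log p(x_i | x_{<i}) )
#
# which is what `_compute` below evaluates per input text, averaging the
# token-level negative log-likelihood only over non-padded positions.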

_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
        NOTE: Perplexity can only be calculated for causal language models.
        This includes models such as gpt2, causal variations of bert,
        causal versions of t5, and more (the full list can be found
        in the AutoModelForCausalLM documentation here:
        https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )

    data (list of str): input data, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    device (str): device to run on, defaults to 'cuda' when available
    max_length (int): the maximum length to truncate input texts to. Should be set to the maximum length the model supports. Defaults to None.
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If one of the input texts is
        longer than the max input length of the model, then it is truncated to the
        max length for the perplexity computation.
Examples:
    Example 1:
        >>> perplexity = evaluate.load("perplexity", module_type="measurement")
        >>> data = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              data=data) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 0))
        647.0
        >>> print(round(results["perplexities"][0], 0))
        32.0

    Example 2:
        >>> from datasets import load_dataset
        >>> perplexity = evaluate.load("perplexity", module_type="measurement")
        >>> data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
        >>> data = [s for s in data if s!='']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              data=data)
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
        576.76
        >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
        889.28
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Perplexity(evaluate.Measurement):
    def _info(self):
        return evaluate.MeasurementInfo(
            module_type="measurement",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "data": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    def _compute(
        self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
    ):

        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be either gpu, cpu or cuda."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
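
        # If batching is requested but the tokenizer has no pad token, reuse one of
        # the model's existing special tokens for padding; padded positions are
        # excluded from the loss via the attention mask, so this does not affect the
        # reported perplexities.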
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())

            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."

            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

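        # When a start token is prepended, reserve one position for it so that the
        # truncated text plus the BOS token still fits within max_length.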
        if add_start_token and max_length:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = max_length - 1
        else:
            max_tokenized_len = max_length

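        # Tokenize all texts at once: special tokens are not added here (a BOS token
        # is prepended manually per batch below), sequences are padded to a common
        # length, and truncation applies only when a max_length was provided.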
        encodings = tokenizer(
            data,
            add_special_tokens=False,
            padding=True,
            truncation=True if max_tokenized_len else False,
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]
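
        # Sanity-check sequence lengths: at least one next-token prediction is needed
        # per text, so each input must be at least one token long when a BOS token
        # will be prepended, and at least two tokens long otherwise.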
        if add_start_token:
            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        ppls = []
        loss_fct = CrossEntropyLoss(reduction="none")

        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]

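            # Prepend the BOS token to every sequence in the batch (and extend the
            # attention mask accordingly) so the first real token is also scored.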
            if add_start_token:
                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                attn_mask = torch.cat(
                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
                )

            labels = encoded_batch

            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits

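            # Align logits with targets for next-token prediction: the logit at
            # position i is scored against the token at position i + 1, so the last
            # logit and the first label are dropped; the attention mask is shifted
            # the same way so that padded targets are ignored.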
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

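            # Per-sequence perplexity: exponentiate the mean token-level cross-entropy
            # (natural log), averaged only over non-padded positions:
            #     PPL = exp(sum_i(nll_i * mask_i) / sum_i(mask_i))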
            perplexity_batch = torch.exp(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )

            ppls += perplexity_batch.tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}