Spaces:
Runtime error
Runtime error
| import torch | |
def sample_top_p(probs, p):
    """Draw one index from ``probs`` using top-p (nucleus) filtering.

    Candidates are ranked by probability, highest first. A candidate is
    discarded when the probability mass of the candidates ranked strictly
    above it already exceeds ``p``. For ``p >= 0`` the top-ranked candidate
    therefore always survives. The surviving probabilities are renormalised
    and one position is drawn with ``torch.multinomial``. That position is
    then mapped back to the original (unsorted) index space.

    The caller's ``probs`` tensor is not modified.

    Args:
        probs: Tensor of probabilities along its last dimension. Presumably
            a softmax output of shape (vocab,) or (batch, vocab).
            ``torch.multinomial`` only accepts 1-D or 2-D input. TODO confirm
            with callers.
        p: Cumulative-probability threshold used for filtering.

    Returns:
        Tensor of indices into the last dimension of ``probs``, with that
        dimension reduced to size 1.
    """
    ranked, order = probs.sort(dim=-1, descending=True)
    # Exclusive prefix sum: the mass of everything ranked above each candidate.
    mass_above = ranked.cumsum(dim=-1) - ranked
    kept = ranked.masked_fill(mass_above > p, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    # The sample is a position in the sorted order; translate it back.
    picked = torch.multinomial(kept, num_samples=1)
    return order.gather(-1, picked)
def format_prompt(instruction):
    """Embed ``instruction`` in a fixed instruction/response prompt template.

    The template holds a short preamble, an ``### Instruction:`` section
    containing ``instruction``, and a trailing ``### Response:`` header with
    nothing after it.

    Args:
        instruction: Text placed under the ``### Instruction:`` header. It is
            substituted as a value, so braces inside it are kept literally.

    Returns:
        The fully formatted prompt string.
    """
    template = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )
    return template.format(instruction=instruction)