Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import StoppingCriteria, StoppingCriteriaList | |
| from enums import PromptType, t5_type | |
| class StoppingCriteriaSub(StoppingCriteria): | |
| def __init__(self, stops=[], stop_words=[], encounters=[], device="cuda", model_max_length=None, tokenizer=None): | |
| super().__init__() | |
| assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match" | |
| self.encounters = encounters | |
| self.stops = [stop.to(device) for stop in stops] | |
| self.stop_words = stop_words | |
| self.num_stops = [0] * len(stops) | |
| self.model_max_length = model_max_length | |
| self.tokenizer = tokenizer | |
| def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
| #if self.tokenizer: | |
| # print('stop: %s' % self.tokenizer.decode(input_ids[0]), flush=True) | |
| for stopi, (stop, stop_word) in enumerate(zip(self.stops, self.stop_words)): | |
| current_block = input_ids[0][-len(stop):] | |
| stop_text = self.tokenizer.decode(current_block) | |
| len_new_tokens = current_block.shape[0] | |
| #if len(stop) <= len_new_tokens and torch.all((stop == input_ids[0][-len(stop):])).item(): | |
| if len(stop) <= len_new_tokens and stop_word in stop_text: | |
| self.num_stops[stopi] += 1 | |
| if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]: | |
| # print("Stopped", flush=True) | |
| return True | |
| if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length: | |
| # critical limit | |
| return True | |
| # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True) | |
| # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True) | |
| return False | |
| def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model, | |
| human='<human>:', bot="<bot>:", model_max_length=None, | |
| prompter=None, | |
| stop=None): | |
| stop_words = [] | |
| encounters = [] | |
| # FIXME: prompt_dict unused currently | |
| user_human_assistant_types = [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value), | |
| PromptType.instruct_vicuna.name] + \ | |
| [PromptType.guanaco.value, str(PromptType.guanaco.value), | |
| PromptType.guanaco.name] + \ | |
| [PromptType.one_shot.value, str(PromptType.one_shot.value), | |
| PromptType.one_shot.name] + \ | |
| [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value), | |
| PromptType.instruct_vicuna2.name] + \ | |
| [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value), | |
| PromptType.instruct_vicuna3.name] + \ | |
| [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value), | |
| PromptType.instruct_with_end.name] | |
| human_bot_types = [PromptType.human_bot.value, str(PromptType.human_bot.value), | |
| PromptType.human_bot.name] + \ | |
| [PromptType.human_bot_orig.value, str(PromptType.human_bot_orig.value), | |
| PromptType.human_bot_orig.name] | |
| all_types = user_human_assistant_types + human_bot_types | |
| if prompt_type in all_types: | |
| if prompt_type in human_bot_types: | |
| # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1] | |
| # stopping only starts once output is beyond prompt | |
| # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added | |
| stop_words = [human, bot, '\n' + human, '\n' + bot] | |
| encounters = [1, 2] | |
| elif prompt_type in user_human_assistant_types: | |
| # even below is not enough, generic strings and many ways to encode | |
| stop_words = [ | |
| '### Human:', | |
| """ | |
| ### Human:""", | |
| """ | |
| ### Human: | |
| """, | |
| """### Human: """, | |
| """### Human:""", | |
| '### Assistant:', | |
| """ | |
| ### Assistant:""", | |
| """ | |
| ### Assistant: | |
| """, | |
| """### Assistant: """, | |
| """### Assistant:""" | |
| ] | |
| if prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value), | |
| PromptType.instruct_vicuna2.name]: | |
| stop_words = [x.upper() for x in stop_words] | |
| if prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value), | |
| PromptType.instruct_vicuna3.name]: | |
| stop_words = [x.replace('Human', 'User') for x in stop_words] | |
| encounters = [1, 2] | |
| else: | |
| # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise | |
| stop_words = ['### End'] | |
| encounters = [1] | |
| elif prompter and prompter.terminate_response: | |
| stop_words = prompter.terminate_response | |
| encounters = [1] * len(stop_words) | |
| handle_newlines = [True] * len(stop_words) | |
| # add other stop words too if passed, e.g. for LangChain agents | |
| if stop: | |
| stop_words += stop | |
| encounters += [1] * len(stop) | |
| handle_newlines += [False] * len(stop) | |
| # get stop tokens | |
| stop_words_ids = [ | |
| tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words] | |
| # handle single token case | |
| stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids] | |
| stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0] | |
| # avoid padding in front of tokens | |
| if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug | |
| stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| if tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug | |
| stop_words_ids = [x[1:] if x[0] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| stop_words_ids = [x[:-1] if x[-1] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| if tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug | |
| stop_words_ids = [x[:-1] if x[-1] == tokenizer.eos_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| if tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug | |
| stop_words_ids = [x[1:] if x[0] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| stop_words_ids = [x[:-1] if x[-1] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids] | |
| if base_model and t5_type(base_model): | |
| # T5 encoder converts internal double space to space+new line, so fix | |
| for stopi, stop_word_id in enumerate(stop_words_ids): | |
| start = stop_word_id[0:1] | |
| mlist = stop_word_id[1:-1] | |
| end = stop_word_id[-1:] | |
| mlist = [tokenizer.vocab[' '] if x == tokenizer.vocab['\n'] else x for x in mlist] | |
| stop_words_ids[stopi] = torch.tensor(list(start) + list(mlist) + list(end), device=stop_word_id.device) | |
| # handle fake \n added | |
| stop_words_ids = [x[1:] if y[0] == '\n' and handle_newline else x for x, y, handle_newline in | |
| zip(stop_words_ids, stop_words, handle_newlines)] | |
| if stop_words_ids: | |
| # build stopper | |
| stopping_criteria = StoppingCriteriaList( | |
| [StoppingCriteriaSub(stops=stop_words_ids, | |
| stop_words=stop_words, | |
| encounters=encounters, device=device, | |
| model_max_length=model_max_length, tokenizer=tokenizer)]) | |
| else: | |
| # nothing to stop on | |
| stopping_criteria = StoppingCriteriaList() | |
| return stopping_criteria | |