pseudotensor committed on
Commit
32e765f
·
1 Parent(s): 6a1fd9e

Delete stopping.py

Browse files
Files changed (1) hide show
  1. stopping.py +0 -72
stopping.py DELETED
@@ -1,72 +0,0 @@
1
- import torch
2
- from transformers import StoppingCriteria, StoppingCriteriaList
3
-
4
- from prompter import PromptType
5
-
6
-
7
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop-token sequence has ended the output often enough.

    Each entry of ``stops`` is a 1-D tensor of token ids.  Every time the
    tail of the first sequence in ``input_ids`` equals a stop sequence, that
    stop's counter is incremented; generation is stopped as soon as a counter
    reaches the threshold assigned to that stop.

    Thresholds are looked up as ``encounters[i % len(encounters)]`` for the
    stop at index ``i``, so a short ``encounters`` list is cycled over a
    longer ``stops`` list.

    Counters are kept on the instance and are never reset by this class.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        """Store the stop sequences (moved to ``device``) and their thresholds.

        Args:
            stops: list of 1-D token-id tensors; ``None`` means no stops.
            encounters: list of match-count thresholds, cycled over ``stops``;
                ``None`` means an empty list.
            device: device the stop tensors are moved to.

        Raises:
            ValueError: if there are stops but ``len(stops)`` is not a
                multiple of ``len(encounters)`` (including an empty
                ``encounters``).
        """
        super().__init__()
        # None defaults instead of shared mutable [] defaults.
        stops = [] if stops is None else stops
        encounters = [] if encounters is None else encounters
        # Only validate when there is something to match: with no stops the
        # thresholds are never consulted, and an empty ``encounters`` must not
        # trigger a modulo-by-zero.
        if len(stops) > 0:
            if len(encounters) == 0 or len(stops) % len(encounters) != 0:
                raise ValueError("Number of stops and encounters must match")
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when some stop sequence has reached its threshold.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted but not used.
        """
        sequence = input_ids[0]
        sequence_len = sequence.shape[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # An empty stop would slice the whole sequence ([-0:]), and a
            # sequence shorter than the stop cannot end with it; comparing
            # anyway would raise on mismatched shapes or broadcast into a
            # spurious match.
            if stop_len == 0 or sequence_len < stop_len:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False
26
-
27
-
28
def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
    """Build the stopping criteria used during generation for ``prompt_type``.

    For the human_bot, instruct_vicuna and instruct_with_end prompt types a
    list of stop strings is tokenized and wrapped in a single
    ``StoppingCriteriaSub``.  Any other prompt type gets an empty
    ``StoppingCriteriaList``.

    Args:
        prompt_type: prompt type name, compared against ``PromptType.<x>.name``.
        tokenizer: callable as ``tokenizer(text, return_tensors='pt')``
            returning a mapping with ``'input_ids'``; also read for
            ``_pad_token`` and ``pad_token_id``.
        device: device passed on to ``StoppingCriteriaSub``.
        human: human-turn marker used for the human_bot prompt type.
        bot: bot-turn marker used for the human_bot prompt type.

    Returns:
        A ``StoppingCriteriaList``; empty when the prompt type has no stop
        words or when no stop word yields any usable tokens.
    """
    if prompt_type == PromptType.human_bot.name:
        # stopping only starts once output is beyond prompt
        # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
        stop_words = [human, bot, '\n' + human, '\n' + bot]
        encounters = [1, 2]
    elif prompt_type == PromptType.instruct_vicuna.name:
        # even below is not enough, generic strings and many ways to encode
        # NOTE(review): thresholds are assigned by index parity (see below), so
        # within this list they alternate 1, 2, 1, 1, 2, 1 rather than being
        # grouped by Human vs Assistant -- confirm this is intended.
        stop_words = [
            '### Human:',
            '\n### Human:',
            '\n### Human:\n',
            '### Assistant:',
            '\n### Assistant:',
            '\n### Assistant:\n',
        ]
        encounters = [1, 2]
    elif prompt_type == PromptType.instruct_with_end.name:
        # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
        stop_words = ['### End']
        encounters = [1]
    else:
        return StoppingCriteriaList()

    # use hidden variable to avoid annoying properly logger bug
    pad_token_id = tokenizer.pad_token_id if tokenizer._pad_token else None

    stops = []
    stop_encounters = []
    for index, stop_word in enumerate(stop_words):
        # Flatten to 1-D so the single-token case needs no special handling.
        ids = tokenizer(stop_word, return_tensors='pt')['input_ids'].reshape(-1)
        if ids.shape[0] == 0:
            continue
        # avoid padding in front of tokens
        if pad_token_id is not None and len(ids) > 1 and ids[0] == pad_token_id:
            ids = ids[1:]
        # handle fake \n added: drop the first remaining token of stop words
        # that begin with a newline (presumably the newline's own token --
        # TODO confirm for the tokenizer in use).
        if stop_word.startswith('\n'):
            ids = ids[1:]
        # Dropping tokens above can leave nothing to match on; skip such
        # stops instead of handing an empty tensor to the stopper.
        if ids.shape[0] == 0:
            continue
        stops.append(ids)
        # Resolve the threshold from the stop word's ORIGINAL position so that
        # skipped entries do not shift the threshold of later ones.
        stop_encounters.append(encounters[index % len(encounters)])

    if not stops:
        return StoppingCriteriaList()
    # build stopper
    return StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stops, encounters=stop_encounters, device=device)])