Update ts_generation_mixin.py
ts_generation_mixin.py  +34 −17
ts_generation_mixin.py
CHANGED
@@ -6,8 +6,38 @@ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
 from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
 from transformers.utils import ModelOutput
 
+
 class TSGenerationMixin(GenerationMixin):
 
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        if len(inputs.shape) == 2:
+            batch_size, cur_len = inputs.shape
+            if cur_len < self.config.input_token_len:
+                raise ValueError(
+                    f"Input length must be at least {self.config.input_token_len}")
+            elif cur_len % self.config.input_token_len != 0:
+                new_len = (cur_len // self.config.input_token_len) * \
+                    self.config.input_token_len
+                inputs = inputs[:, -new_len:]
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+        return super().generate(inputs=inputs, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, assistant_model=assistant_model, streamer=streamer, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, **kwargs)
+
+
     def _greedy_search(
         self,
         input_ids: torch.Tensor,
@@ -26,19 +56,7 @@ class TSGenerationMixin(GenerationMixin):
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
         input_ids = input_ids.to(self.device)
-
-        if len(input_ids.shape) == 2:
-            batch_size, cur_len = input_ids.shape
-            if cur_len < self.config.input_token_len:
-                raise ValueError(
-                    f"Input length must be at least {self.config.input_token_len}")
-            elif cur_len % self.config.input_token_len != 0:
-                new_len = (cur_len // self.config.input_token_len) * \
-                    self.config.input_token_len
-                input_ids = input_ids[:, -new_len:]
-        else:
-            raise ValueError('Input shape must be: [batch_size, seq_len]')
-
+        batch_size, cur_len = input_ids.shape
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -106,9 +124,8 @@
             batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(
             cur_len, device=input_ids.device)
-        true_seq_len =
+        true_seq_len = cur_len // self.config.input_token_len
         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
-
         max_length = stopping_criteria.max_length
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
@@ -129,7 +146,7 @@
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
 
-            next_token_logits = outputs.logits
+            next_token_logits = outputs.logits
 
             # pre-process distribution
             next_tokens_scores = logits_processor(input_ids, next_token_logits)
@@ -212,7 +229,7 @@
                 past_key_values=model_kwargs.get("past_key_values"),
             )
         else:
-            return input_ids[:, -(max_length -
+            return input_ids[:, -(max_length - cur_len):]
 
     def _update_model_kwargs_for_generation(
         self,
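
For context, the net effect of this change: the length check and truncation that previously ran inside _greedy_search now run once in an overridden generate, which then delegates to GenerationMixin.generate, so _greedy_search only needs to read batch_size and cur_len from the already-trimmed input. Below is a minimal sketch (not part of the commit) of the trimming arithmetic; the concrete patch length 96 is a hypothetical placeholder for self.config.input_token_len.

import torch

input_token_len = 96                 # hypothetical stand-in for config.input_token_len
seq = torch.randn(1, 700)            # [batch_size, seq_len]; 700 is not a multiple of 96
batch_size, cur_len = seq.shape

if cur_len < input_token_len:
    raise ValueError(f"Input length must be at least {input_token_len}")
elif cur_len % input_token_len != 0:
    # Round down to a whole number of patches, keeping the most recent points.
    new_len = (cur_len // input_token_len) * input_token_len
    seq = seq[:, -new_len:]

print(seq.shape)                                 # torch.Size([1, 672]): 7 patches of 96 points
true_seq_len = seq.shape[1] // input_token_len   # 7 patch tokens

The trimming of model_kwargs["attention_mask"] to the last true_seq_len entries suggests the mask is indexed per patch token rather than per time point, so the new true_seq_len = cur_len // self.config.input_token_len keeps it aligned with the trimmed input.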