Update generate.py
generate.py  CHANGED  +5 -5
@@ -86,10 +86,10 @@ class LmGeneration:
         total_len = args.seq_length

         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        tokens = torch.full((batch, total_len), self.tokenizer.
+        tokens = torch.full((batch, total_len), self.tokenizer.pad_token).to(device).long()
         for idx, t in enumerate(prompt_tokens):
             tokens[idx, : len(t)] = torch.tensor(t).long()
-        mask = tokens != self.tokenizer.
+        mask = tokens != self.tokenizer.pad_token
         start_pos = min_prompt_len
         prev_pos = 0
         continue_exsample = [i for i in range(batch)]
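Review note: this hunk fills the (batch, total_len) token buffer with the tokenizer's pad token, copies each prompt into the left edge of its row, and derives a mask of real tokens. A minimal runnable sketch of that pattern, assuming pad_token holds an integer id (the ids and shapes below are made up for illustration):

    import torch

    pad_token = 0                          # hypothetical pad id
    prompt_tokens = [[5, 6, 7], [8, 9]]    # two prompts of unequal length
    batch, total_len = len(prompt_tokens), 8

    # Fill the whole buffer with padding, then overwrite each row's
    # prefix with the actual prompt tokens.
    tokens = torch.full((batch, total_len), pad_token).long()
    for idx, t in enumerate(prompt_tokens):
        tokens[idx, : len(t)] = torch.tensor(t).long()

    # True wherever a real prompt token sits, False over the padding.
    mask = tokens != pad_token

Note that torch.full needs a numeric fill value, so this only works when pad_token is an id rather than a string.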
@@ -118,7 +118,7 @@ class LmGeneration:
             continue_exsample = []
             for i, t in enumerate(tokens.tolist()):
                 try:
-                    t.index(self.tokenizer.
+                    t.index(self.tokenizer.eos_token)
                 except ValueError:
                     if cut_off is not None:
                         if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
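The stop check here leans on list.index, which raises ValueError when the eos token is absent, so the except branch is the "still generating" path that keeps index i in continue_exsample. A small sketch with a hypothetical eos id:

    eos_token = 2                          # hypothetical eos id
    sequences = [[5, 6, 2, 0], [5, 6, 7, 8]]
    still_generating = []
    for i, t in enumerate(sequences):
        try:
            t.index(eos_token)             # found: sequence i is finished
        except ValueError:
            still_generating.append(i)     # no eos yet: keep sampling it

    print(still_generating)                # [1]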
@@ -134,8 +134,8 @@ class LmGeneration:
         for i, t in enumerate(tokens.tolist()):
             t = t[: args.seq_length]
             try:
-                t = t[: t.index(self.tokenizer.
-                t = t[: t.index(self.tokenizer.
+                t = t[: t.index(self.tokenizer.pad_token)]
+                t = t[: t.index(self.tokenizer.eos_token)]
             except ValueError:
                 pass
             decoder.append(self.tokenizer.decode(t))
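The final hunk trims each sequence at the first pad token, then at the first eos token, before decoding. Because both index() calls share one try block, a sequence with no padding raises before the eos trim runs and keeps its eos token. A sketch with hypothetical ids showing the happy path:

    pad_token, eos_token = 0, 2        # hypothetical ids
    t = [5, 6, 2, 0, 0]
    try:
        t = t[: t.index(pad_token)]    # drop right padding -> [5, 6, 2]
        t = t[: t.index(eos_token)]    # drop eos and after  -> [5, 6]
    except ValueError:
        pass                           # a missing pad token skips the eos trim too
    print(t)                           # [5, 6]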