Commit
·
41b98bc
1
Parent(s):
5e61c3e
truncate after bos
Browse files
app.py
CHANGED
@@ -17,6 +17,13 @@ def generate_response(prompt):
|
|
17 |
outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
|
18 |
input_length = inputs['input_ids'].shape[1]
|
19 |
new_token_ids = outputs[0][input_length:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
|
21 |
return new_tokens
|
22 |
|
|
|
17 |
outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
|
18 |
input_length = inputs['input_ids'].shape[1]
|
19 |
new_token_ids = outputs[0][input_length:]
|
20 |
+
bos_token_id = tokenizer.bos_token_id
|
21 |
+
if bos_token_id is not None:
|
22 |
+
bos_positions = (new_token_ids == bos_token_id).nonzero(as_tuple=True)[0]
|
23 |
+
if len(bos_positions) > 0:
|
24 |
+
# Truncate at first BOS token
|
25 |
+
first_bos_pos = bos_positions[0].item()
|
26 |
+
new_token_ids = new_token_ids[:first_bos_pos]
|
27 |
new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
|
28 |
return new_tokens
|
29 |
|