AddieFoote commited on
Commit
41b98bc
·
1 Parent(s): 5e61c3e

truncate after bos

Browse files
Files changed (1) hide show
  1. app.py +7 -0
app.py CHANGED
@@ -17,6 +17,13 @@ def generate_response(prompt):
17
  outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
18
  input_length = inputs['input_ids'].shape[1]
19
  new_token_ids = outputs[0][input_length:]
 
 
 
 
 
 
 
20
  new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
21
  return new_tokens
22
 
 
17
  outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
18
  input_length = inputs['input_ids'].shape[1]
19
  new_token_ids = outputs[0][input_length:]
20
+ bos_token_id = tokenizer.bos_token_id
21
+ if bos_token_id is not None:
22
+ bos_positions = (new_token_ids == bos_token_id).nonzero(as_tuple=True)[0]
23
+ if len(bos_positions) > 0:
24
+ # Truncate at first BOS token
25
+ first_bos_pos = bos_positions[0].item()
26
+ new_token_ids = new_token_ids[:first_bos_pos]
27
  new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
28
  return new_tokens
29