Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
·
5dd37ab
1
Parent(s):
bb74d9f
Ensure `input_ids` are on the correct device when predicting
Browse files- src/predict.py +1 -1
src/predict.py
CHANGED
|
@@ -171,7 +171,7 @@ DEFAULT_TOKEN_PREFIX = 'summarize: '
|
|
| 171 |
def predict_sponsor_text(text, model, tokenizer):
|
| 172 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
| 173 |
input_ids = tokenizer(
|
| 174 |
-
f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids
|
| 175 |
|
| 176 |
# Can't be longer than input length + SAFETY_TOKENS or model input dim
|
| 177 |
max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
|
|
|
|
| 171 |
def predict_sponsor_text(text, model, tokenizer):
|
| 172 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
| 173 |
input_ids = tokenizer(
|
| 174 |
+
f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids.to(device())
|
| 175 |
|
| 176 |
# Can't be longer than input length + SAFETY_TOKENS or model input dim
|
| 177 |
max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
|