Upload predict.py
predict.py  +10 -6 (changed)
@@ -310,6 +310,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
 
     example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
 
+    if history is None or not isinstance(history, list):
+        history = []
+
     while(len(history) > 0 and (len(example) < max_token)):
         tmp = history.pop()
         if tmp[0] == 'ASSISTANT':
@@ -333,7 +336,7 @@ def predict(model, text, tokenizer=None,
             sft=True, convo_template = "",
             device = "cuda",
             model_name="AquilaChat2-7B",
-            history=
+            history=None,
             **kwargs):
 
     vocab = tokenizer.get_vocab()
@@ -344,7 +347,7 @@ def predict(model, text, tokenizer=None,
     template_map = {"AquilaChat2-7B": "aquila-v1",
                     "AquilaChat2-34B": "aquila-legacy",
                     "AquilaChat2-7B-16K": "aquila",
-                    "AquilaChat2-34B-16K": "aquila
+                    "AquilaChat2-34B-16K": "aquila"}
     if not convo_template:
         convo_template=template_map.get(model_name, "aquila-chat")
 
@@ -353,7 +356,7 @@ def predict(model, text, tokenizer=None,
         topk = 1
         temperature = 1.0
     if sft:
-        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=
+        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
         tokens = torch.tensor(tokens)[None,].to(device)
     else :
         tokens = tokenizer.encode_plus(text)["input_ids"]
@@ -435,8 +438,9 @@ def predict(model, text, tokenizer=None,
         convert_tokens = convert_tokens[1:]
         probs = probs[1:]
 
-
-
-
+        if isinstance(history, list):
+            # Update history
+            history.insert(0, ('ASSISTANT', out))
+            history.insert(0, ('USER', text))
 
         return out
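
A minimal usage sketch of the patched predict(). The model/tokenizer loading and the BAAI/AquilaChat2-7B repo id below are assumptions for illustration and are not part of this diff; only the history keyword and its in-place update come from the change above.

# Usage sketch (assumes predict.py is importable next to this script and that
# the checkpoint loads via the usual transformers remote-code path).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from predict import predict  # the function patched above

tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat2-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "BAAI/AquilaChat2-7B", torch_dtype=torch.float16, trust_remote_code=True
).eval().to("cuda")

history = []  # caller-owned list; predict() now mutates it in place

# First turn: history is empty, so only the current prompt is templated.
answer = predict(model, "What is the capital of France?",
                 tokenizer=tokenizer, model_name="AquilaChat2-7B", history=history)

# predict() has inserted ('USER', text) and ('ASSISTANT', answer) at the front
# of history; the next call pops those entries back into the conversation
# template (up to max_token=2048) before generating.
answer = predict(model, "And of Germany?",
                 tokenizer=tokenizer, model_name="AquilaChat2-7B", history=history)

# Calling without history is also safe now: the new guard in
# covert_prompt_to_input_ids_with_history falls back to an empty list.
answer = predict(model, "Hello", tokenizer=tokenizer, model_name="AquilaChat2-7B")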