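# Greedy-decoding evaluation with vLLM: build one ChatQA-2-style prompt per
# test sample, generate an answer for each, and write the answers (one per
# line) under <model_folder>/outputs.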
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from arguments import get_args
from dataset_conv import get_chatqa2_input, preprocess
from tqdm import tqdm
import torch
import os

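# Silence fork-related warnings from HF tokenizers and point vLLM at an
# explicit NCCL shared library.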
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
os.environ['VLLM_NCCL_SO_PATH'] = '/usr/local/lib/python3.8/dist-packages/nvidia/nccl/lib/libnccl.so.2' |
|
|
|
def get_prompt_list(args):
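    # The tokenizer is passed to get_chatqa2_input so prompts can be fit
    # within max_seq_length.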
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

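    # Load the evaluation samples, optionally with retrieved context passages.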
    data_list = preprocess(args.sample_input_file, inference_only=True, retrieved_neighbours=args.use_retrieved_neighbours)
    print("number of total data_list:", len(data_list))

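    # Optionally evaluate only a [start_idx, end_idx) slice, e.g. to split the
    # run across several jobs.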
    if args.start_idx != -1 and args.end_idx != -1:
        print("getting data from %d to %d" % (args.start_idx, args.end_idx))
        data_list = data_list[args.start_idx:args.end_idx]

    print("number of test samples in the dataset:", len(data_list))

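    # Build the final prompt strings, reserving room for max_tokens of output
    # within the max_seq_length context window.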
    prompt_list = get_chatqa2_input(data_list, args.eval_dataset, tokenizer, num_ctx=args.num_ctx, max_output_len=args.max_tokens, max_seq_length=args.max_seq_length)

    return prompt_list


def main():
    args = get_args()

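    # BOS token for Llama-3-style checkpoints; prepended manually to every
    # prompt below.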
bos_token = "<|begin_of_text|>" |
|
|
|
model_path = args.model_folder |
|
|
|
|
|
    prompt_list = get_prompt_list(args)

    output_path = os.path.join(model_path, "outputs")
    if not os.path.exists(output_path):
        os.mkdir(output_path)

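    # Name the output file after the dataset, adding the slice range and/or
    # the retrieved-context count when those options are in use.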
    if args.start_idx != -1 and args.end_idx != -1:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d_ctx%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx))
    else:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_ctx%d.txt" % (args.eval_dataset, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output.txt" % args.eval_dataset)

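    # temperature=0 with top_k=1 gives fully greedy decoding.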
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=args.max_tokens)

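    # Load the model in bfloat16, sharded across 8 GPUs via tensor parallelism.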
    model_vllm = LLM(model_path, tensor_parallel_size=8, dtype=torch.bfloat16)
    print(model_vllm)

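    # Prompts are generated one at a time here; vLLM's generate() also accepts
    # the whole prompt list in a single call, which would allow batched decoding.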
    output_list = []
    for prompt in tqdm(prompt_list):
        prompt = bos_token + prompt
        output = model_vllm.generate([prompt], sampling_params)[0]
        generated_text = output.outputs[0].text
        generated_text = generated_text.strip().replace("\n", " ")

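        # Truncate at the first Llama-3-style terminator, since SamplingParams
        # sets no stop strings.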
if "<|eot_id|>" in generated_text: |
|
idx = generated_text.index("<|eot_id|>") |
|
generated_text = generated_text[:idx] |
|
if "<|end_of_text|>" in generated_text: |
|
idx = generated_text.index("<|end_of_text|>") |
|
generated_text = generated_text[:idx] |
|
|
|
print("="*80) |
|
print("prompt:", prompt) |
|
print("-"*80) |
|
print("generated_text:", generated_text) |
|
print("="*80) |
|
output_list.append(generated_text) |
|
|
|
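    # One answer per line; newlines inside answers were replaced with spaces above.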
print("writing to %s" % output_datapath) |
|
with open(output_datapath, "w", encoding="utf-8") as f: |
|
for output in output_list: |
|
f.write(output + "\n") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|