File size: 6,239 Bytes
f4c3446 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from vllm import LLM, SamplingParams
import multiprocessing
import time
import gc
import torch
import pdb
import sqlite3
from concurrent.futures import ThreadPoolExecutor
#from openai_call import query_azure_openai_chatgpt_chat
def label_transform(label):
    """Translate a numeric NLI class id into its label name.

    The id is compared with ``==`` against 1, 0 and 2 (in that order);
    any other value yields ``None``.
    """
    id_to_name = ((1, 'neutral'), (0, 'entailment'), (2, 'contradiction'))
    for class_id, name in id_to_name:
        if label == class_id:
            return name
    return None
# Generation settings passed to the ``generate`` calls below.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=600,
    top_p=0.95,
)
def valid_results_collect(model_path, valid_data, task):
    """Load a model with vLLM, evaluate it, then release cached CUDA memory.

    Args:
        model_path: Passed to ``LLM(model=...)``.
        valid_data: Validation examples, forwarded unchanged to the
            task-specific evaluator.
        task: ``'sql'`` or ``'nli'``; selects ``sql_evaluation`` or
            ``nli_evaluation``.

    Returns:
        The ``(failed_cases, correct_cases)`` pair produced by the evaluator.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    # Reject unknown tasks before the expensive model load. Previously this
    # only surfaced as an UnboundLocalError at the final ``return``.
    if task not in ('sql', 'nli'):
        raise ValueError(
            "unsupported task {!r}; expected 'sql' or 'nli'".format(task)
        )

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)

    start_t = time.time()
    if task == 'sql':
        failed_cases, correct_cases = sql_evaluation(trained_model, valid_data)
    else:
        failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    # Elapsed wall-clock seconds (was start - end, i.e. always negative).
    print('time', end_t - start_t)

    # Drop the model reference first, then ask Python and CUDA to free
    # whatever they were still caching for it.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    # NOTE(review): presumably a grace period so GPU memory is really
    # released before the caller loads another model -- confirm.
    time.sleep(10)
    return failed_cases, correct_cases
def extract_answer_prediction_nli(predicted_output):
    """Pull the NLI label out of the model's "final" sentence.

    The text is split on ``'.'``; sentences containing the (case-sensitive)
    substring ``'final'`` are scanned in order and the first label that
    ``extract_answer`` finds in one of them is returned. Returns ``None``
    when no such sentence yields a label.
    """
    for sentence in predicted_output.split('.'):
        if 'final' not in sentence:
            continue
        answer = extract_answer(sentence)
        if answer:
            return answer
    return None
def extract_answer(text):
    """Return the NLI label mentioned in ``text``, ignoring case.

    Labels are tried in a fixed priority order -- neutral, contradiction,
    entailment -- regardless of where they appear in the text. Returns
    ``None`` when none of them occurs.
    """
    lowered = text.lower()
    for candidate in ('neutral', 'contradiction', 'entailment'):
        if candidate in lowered:
            return candidate
    return None
def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    """Generate predictions for one batch of NLI examples and grade them.

    Args:
        data_batch: Examples indexed with ``'Input'`` (the prompt) and
            ``'Output'`` (reference text; the part after its last ``'is'``
            is searched for the gold label).
        trained_model: Object whose ``generate(prompts, sampling_params)``
            returns one result per prompt, read as ``result.outputs[0].text``.
        failed_cases: List that receives incorrectly answered examples
            (appended to in place).
        correct_cases: List that receives correctly answered examples
            (appended to in place).

    Returns:
        The same ``(failed_cases, correct_cases)`` lists. Every record is
        ``(input, predicted_res, label, data)``.
    """
    if not data_batch:
        # Nothing to grade; skip the model call entirely.
        return failed_cases, correct_cases

    labels = ('entailment', 'contradiction', 'neutral')
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)

    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # No label found in a "final" sentence: grade the raw text.
            predicted_res = predicted_output

        record = (data['Input'], predicted_res, label, data)
        if label is None:
            # The gold label could not be parsed. ``None in str`` would raise
            # TypeError and abort the whole evaluation, so record the example
            # as failed (its label field stays None) and keep going.
            failed_cases.append(record)
            continue

        # Correct only if the gold label appears and no other label does.
        mentions_other = any(
            other in predicted_res for other in labels if other != label
        )
        if label in predicted_res and not mentions_other:
            correct_cases.append(record)
        else:
            failed_cases.append(record)
    return failed_cases, correct_cases
def nli_evaluation(trained_model, valid_data, batch_size=500):
    """Evaluate ``trained_model`` on NLI examples, one batch at a time.

    Args:
        trained_model: Model object forwarded to ``process_batch``.
        valid_data: Sliceable sequence of examples (``len`` and slicing
            are used to form the batches).
        batch_size: Number of examples per ``process_batch`` call.

    Returns:
        ``(failed_cases, correct_cases)``: two lists of
        ``(input, predicted_res, label, data)`` records accumulated over
        all batches.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(batch_size)
        )

    failed_cases = []
    correct_cases = []
    for start in range(0, len(valid_data), batch_size):
        batch = valid_data[start:start + batch_size]
        failed_cases, correct_cases = process_batch(
            batch, trained_model, failed_cases, correct_cases
        )
    return failed_cases, correct_cases
def sql_evaluation(trained_model, valid_data,
                   db_root='/dccstor/obsidian_llm/yiduo/AgentBench/DAMO-ConvAI/bird/data/train/train_databases'):
    """Evaluate generated SQL by executing it against per-example databases.

    For each ``(db_id, prompt, ground_truth)`` triple the model output is
    reduced to a single SELECT statement, run against
    ``<db_root>/<db_id>/<db_id>.sqlite`` together with the ground-truth
    query, and the two result sets are compared as unordered sets of rows.

    Args:
        trained_model: Object whose ``generate(prompt, sampling_params)``
            result is read as ``result[0].outputs[0].text``.
        valid_data: Iterable of ``(db_id, prompt, ground_truth)`` triples.
        db_root: Directory holding one ``<db_id>/<db_id>.sqlite`` per database.

    Returns:
        ``(failed_cases, correct_cases)``. Graded records are
        ``(index, prompt, reasoning + sql, triple, ground_truth,
        predicted_rows, ground_truth_rows)``. Examples whose SQL could not be
        extracted or executed go to ``failed_cases`` as
        ``(index, prompt, sql_or_raw_text, triple, ground_truth, error_text)``.
    """
    failed_cases = []
    correct_cases = []
    # enumerate() gives every record its real position; the old counter was
    # never incremented, so all records pointed at example 0.
    for idx, triple in enumerate(valid_data):
        db_id, prompt, ground_truth = triple
        prompt = prompt.replace('SELECT', '')
        prompt += ' To generate the SQL query to'
        db_path = '{0}/{1}/{1}.sqlite'.format(db_root, db_id)

        output = trained_model.generate(prompt, sampling_params)
        generated = output[0].outputs[0].text
        prior_pred, predicted_sql = _extract_predicted_sql(generated)
        if predicted_sql is None:
            failed_cases.append((
                idx, prompt, generated, triple, ground_truth,
                'no SELECT statement found in model output',
            ))
            continue

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(predicted_sql)
            predicted_res = cursor.fetchall()
            cursor.execute(ground_truth)
            ground_truth_res = cursor.fetchall()
            same_rows = set(predicted_res) == set(ground_truth_res)
        except Exception as e:
            # Deliberately broad: any execution problem counts as a failed
            # prediction and is recorded with the real exception type.
            failed_cases.append((
                idx, prompt, predicted_sql, triple, ground_truth,
                '{}: {}'.format(type(e).__name__, e),
            ))
            continue
        finally:
            # One connection per example; always close it.
            conn.close()

        record = (
            idx, prompt, prior_pred + predicted_sql, triple, ground_truth,
            predicted_res, ground_truth_res,
        )
        if same_rows:
            correct_cases.append(record)
        else:
            failed_cases.append(record)
    return failed_cases, correct_cases


def _extract_predicted_sql(generated):
    """Split model output into ``(reasoning_prefix, select_statement)``.

    The statement is ``None`` when no ``SELECT`` can be located.
    """
    parts = generated.split('final SQL')
    prior_pred = parts[0]
    if len(parts) > 1:
        candidate = parts[1].strip()
    else:
        select_parts = generated.split('SELECT')
        if len(select_parts) < 2:
            return prior_pred, None
        # NOTE(review): keeps only the text up to the next 'SELECT', so a
        # subquery gets truncated -- preserved from the original fallback.
        candidate = 'SELECT' + select_parts[1]
    candidate = candidate.split(';')[0]
    start = candidate.find('SELECT')
    if start == -1:
        return prior_pred, None
    return prior_pred, candidate[start:]
|