Anonymous
format and clean code
d27fe32
import csv
import json
import multiprocessing as mp
import os
from typing import Any, Dict, List, NewType, Optional, Union
import numpy as np
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader
# Maps the language names used in configs to the short codes this module
# passes to the translator and uses as dataset configuration names.
# "chinese" used to be listed twice ("zh-CN", then "zh"); a dict literal keeps
# the last value, so the effective "zh" mapping is the one retained here.
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "chinese": "zh",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "turkish": "tr",
    "spanish": "es",
    "greek": "el",
    "german": "de",
}
# Integer class ids -> label names used when rendering few-shot examples.
NUMBER_TO_TAG = dict(enumerate(("entailment", "neutral", "contradiction")))

# Distinct static type for the parameter mapping returned by read_parameters.
PARAMS = NewType("PARAMS", Dict[str, Any])
def read_parameters(args_path) -> PARAMS:
    """Load run parameters from a YAML file.

    Args:
        args_path: Path of the YAML file to read.

    Returns:
        PARAMS: The deserialised YAML document (callers in this module index
        it like a dict).
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(args_path, encoding="utf-8") as f:
        args = yaml.load(f, Loader=SafeLoader)
    return args
def get_key(key_path):
    """Return the first line of the file at ``key_path`` without its newline."""
    with open(key_path) as handle:
        contents = handle.read()
    # Everything before the first newline (or the whole text if there is none).
    first_line, _, _ = contents.partition("\n")
    return first_line
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
) -> Optional[Dict[str, Any]]:
    """Translate the premise and hypothesis of a single example.

    Args:
        example: Mapping with "premise" and "hypothesis" text. A "label"
            entry, if present, is carried over unchanged.
        src_language: Language name of the input text (key of
            LANGUAGE_TO_SUFFIX).
        target_language: Language name to translate into (key of
            LANGUAGE_TO_SUFFIX).

    Returns:
        A new dict with translated "premise" and "hypothesis" plus the
        original "label" ("" if the example has none), or None when the
        translation raises (the error is printed).

    Raises:
        KeyError: If either language name is missing from LANGUAGE_TO_SUFFIX.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        return {
            "premise": translator.translate(example["premise"]),
            "hypothesis": translator.translate(example["hypothesis"]),
            # Keep the label: blanking it made translated few-shot
            # demonstrations render without an answer.
            "label": example.get("label", ""),
        }
    except Exception as e:
        # Best effort: report the failure and signal it with an explicit None.
        print(e)
        return None
def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Selects few-shot examples from training datasets

    Args:
        train_dataset (Dataset): Dataset the examples are drawn from
        few_shot_size (int): Number of few-shot examples to draw
        context (List[str]): Language name for each returned example; one
            example is returned per entry, so it must not be longer than
            few_shot_size
        selection_criteria (str): How to select few-shot examples.
            Choices: [random, first_k]
        lang (str): Language name of the text stored in train_dataset

    Returns:
        List[Dict[str, Union[str, int]]]: One example per entry of context,
        with the integer label mapped through NUMBER_TO_TAG. Examples whose
        context language differs from lang are passed through
        _translate_example, so an entry is None if that translation failed.

    Raises:
        ValueError: If selection_criteria is not one of the supported choices
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        pool_size = len(train_dataset)
        example_idxs = (
            np.random.choice(
                pool_size,
                size=few_shot_size,
                # Draw distinct examples whenever the pool is large enough;
                # only fall back to repeats when it is not.
                replace=few_shot_size > pool_size,
            )
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    rows = [train_dataset[idx] for idx in example_idxs]
    ic_examples = [
        {
            "premise": row["premise"],
            "hypothesis": row["hypothesis"],
            "label": NUMBER_TO_TAG[row["label"]],
        }
        for row in rows
    ]

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language != lang:
            # Present this demonstration in the requested context language.
            example = _translate_example(
                example=example,
                src_language=lang,
                target_language=ic_language,
            )
        selected_examples.append(example)
    return selected_examples
def load_xnli_dataset(
    dataset_name: str,
    lang: str,
    split: str,
    limit: int = 200,
) -> Union[Dataset, DatasetDict]:
    """Load the first ``limit`` rows of one split of an XNLI-style dataset.

    Args:
        dataset_name (str): "indicxnli" loads "Divyanshu/indicxnli"; any other
            value loads "xnli"
        lang (str): Language name (key of LANGUAGE_TO_SUFFIX) whose dataset
            configuration is loaded
        split (str): Name of the split to load
        limit (int): Maximum number of leading rows to keep. Defaults to 200

    Returns:
        Union[Dataset, DatasetDict]: huggingface dataset object holding at
        most ``limit`` rows

    Raises:
        KeyError: If lang is missing from LANGUAGE_TO_SUFFIX
    """
    if dataset_name == "indicxnli":  ##PJ:To add except hindi
        dataset = load_dataset("Divyanshu/indicxnli", LANGUAGE_TO_SUFFIX[lang])[split]
    else:
        dataset = load_dataset("xnli", LANGUAGE_TO_SUFFIX[lang])[split]
    # Clamp so a limit larger than the split does not index past its end.
    row_count = min(limit, len(dataset))
    return dataset.select(np.arange(row_count))
def construct_prompt(
    instruction: str, test_example: dict, ic_examples: List[dict], zero_shot: bool
):
    """Render the prompt for one test example and return it with its label.

    NOTE(review): this file contains a second module-level
    ``def construct_prompt`` further down with a different signature. The
    later definition rebinds the name, so this version is not the one reached
    through ``construct_prompt`` at call time - confirm which is intended.

    Args:
        instruction: Task instruction placed at the start of the prompt.
        test_example: Mapping with "premise", "hypothesis" and "label".
        ic_examples: Few-shot demonstrations, each with "premise",
            "hypothesis" and "label"; only used when zero_shot is False.
        zero_shot: If True, render the instruction and the test example only.

    Returns:
        Tuple of (formatted prompt string, test_example["label"]).
    """
    if zero_shot:
        # The separator previously contained a literal "+" left over from a
        # string concatenation; it is not meant to appear in the prompt.
        zero_shot_template = (
            f"""{instruction}""" + "\n hypothesis: {hypothesis} \n Premise: {premise}"
        )
        prompt = PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    else:
        example_prompt = PromptTemplate(
            input_variables=["premise", "hypothesis", "label"],
            # "Label: " now has the same separator as the other two fields.
            template="Premise: {premise}\n Hypothesis: {hypothesis} \n Label: {label}",
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Premise: {premise} \n Hypothesis: {hypothesis}",
            input_variables=["hypothesis", "premise"],
        )
    return (
        prompt.format(
            hypothesis=test_example["hypothesis"], premise=test_example["premise"]
        ),
        test_example["label"],
    )
def dump_metrics(
    lang: str,
    config: Dict[str, str],
    r1: float,
    r2: float,
    rL: float,
    metric_logger_path: str,
):
    """Append one row of R1/R2/RL scores for a configuration to a CSV log.

    A header row is written first when the log file is missing or empty.

    Args:
        lang: Value for the "Language" column.
        config: Configuration with "prefix", "input", "context" and "output"
            entries. Only the first element of "context" is logged; an empty
            context (zero-shot) is logged as "zero".
        r1: Value for the "R1" column.
        r2: Value for the "R2" column.
        rL: Value for the "RL" column.
        metric_logger_path: Path of the CSV file to append to.
    """
    # An existing but empty file still needs its header.
    needs_header = (
        not os.path.exists(metric_logger_path)
        or os.path.getsize(metric_logger_path) == 0
    )

    # Zero-shot configs have no context language; avoid indexing an empty list.
    context = config["context"]
    context_language = context[0] if context else "zero"

    with open(metric_logger_path, "a", newline="", encoding="utf-8") as f:
        csvwriter = csv.writer(f, delimiter=",")
        if needs_header:
            csvwriter.writerow(
                [
                    "Language",
                    "Prefix",
                    "Input",
                    "Context",
                    "Output",
                    "R1",
                    "R2",
                    "RL",
                ]
            )
        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                context_language,
                config["output"],
                r1,
                r2,
                rL,
            ]
        )
def dump_predictions(idx, response, label, response_logger_file):
    """Append one prediction record as a JSON line.

    Args:
        idx: Index of the test example, stored under "q_idx".
        response: Model output, stored under "prediction".
        label: Reference label, stored under "label".
        response_logger_file: Path of the file to append to.
    """
    obj = {"q_idx": idx, "prediction": response, "label": label}
    # ensure_ascii=False emits raw non-ASCII text, so the file must be opened
    # as UTF-8 rather than in the platform's locale encoding.
    with open(response_logger_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
def compute_rouge(scorer, pred, label):
    """Score ``pred`` against ``label`` and return the three ROUGE entries.

    Returns:
        Tuple of the scorer's "rouge1", "rouge2" and "rougeL" results.
    """
    # NOTE(review): arguments are forwarded as (pred, label). If `scorer` is a
    # rouge_score RougeScorer, its signature is score(target, prediction) -
    # confirm this order is intended.
    results = scorer.score(pred, label)
    return tuple(results[name] for name in ("rouge1", "rouge2", "rougeL"))
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into ``target_language``.

    Args:
        basic_instruction: Instruction text, treated as English ("en").
        target_language: Language name (key of LANGUAGE_TO_SUFFIX).

    Returns:
        The translator's output for the instruction.
    """
    target_code = LANGUAGE_TO_SUFFIX[target_language]
    translator = EasyGoogleTranslate(
        source_language="en", target_language=target_code, timeout=10
    )
    translated = translator.translate(basic_instruction)
    return translated
def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    """Translate a prediction from its own language into the output language.

    Args:
        prediction: Text to translate.
        prediction_language: Language name of ``prediction`` (key of
            LANGUAGE_TO_SUFFIX).
        output_language: Language name to translate into (key of
            LANGUAGE_TO_SUFFIX).

    Returns:
        The translator's output for the prediction.
    """
    source_code = LANGUAGE_TO_SUFFIX[prediction_language]
    target_code = LANGUAGE_TO_SUFFIX[output_language]
    translator = EasyGoogleTranslate(
        source_language=source_code, target_language=target_code, timeout=10
    )
    translated = translator.translate(prediction)
    return translated
def create_instruction(lang: str):
    """Return the NLI task instruction in the requested language.

    The English text is returned as is; for any other language it is passed
    through _translate_instruction.
    """
    basic_instruction = """
You are an NLP assistant whose purpose is to solve Natural Language Inference (NLI) problems.
NLI is the task of determining the inference relation between two texts: entailment,
contradiction, or neutral.
Your answer should be one word of the following - entailment, contradiction, or neutral.
Pay attention: The output should be only one word!!!!
"""
    if lang == "english":
        return basic_instruction
    return _translate_instruction(basic_instruction, target_language=lang)
def run_one_configuration(params: Optional[PARAMS] = None, zero: bool = False):
    """Run one prompting configuration over the test split with a worker pool.

    Args:
        params: Run parameters; read from "../../parameters.yaml" when not
            given. Keys used here: "selected_language", "config",
            "dataset_name" and "limit".
        zero: Force a zero-shot run. A configuration with an empty "context"
            is treated as zero-shot as well.
    """
    if not params:
        params = read_parameters("../../parameters.yaml")
    lang = params["selected_language"]
    config = params["config"]

    # Keep the flag and the derived mode in agreement, so the output name
    # always matches what is run and an empty context is never indexed.
    zero_shot = zero or len(config["context"]) == 0
    if zero_shot:
        config_header = f"{config['input']}_{config['prefix']}_zero"
    else:
        config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}"

    test_data = load_xnli_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="test",
        limit=params["limit"],
    )

    # The context manager tears the pool down even if submission fails.
    with mp.Pool(processes=3) as pool:
        pending = [
            pool.apply_async(
                process_test_example,
                args=(
                    test_data,
                    config_header,
                    idx,
                    test_example,
                    config,
                    zero_shot,
                    lang,
                    params,
                ),
            )
            for idx, test_example in enumerate(test_data)
        ]
        pool.close()
        # Advance the bar as tasks finish rather than as they are queued, and
        # report failures that never reached the worker's own error handler.
        for idx, result in enumerate(tqdm(pending, total=len(pending))):
            try:
                result.get()
            except Exception as e:
                print(f"Error processing example {idx}: {e}")
        pool.join()
def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    """Build the prompt for one test example, query the model, log the result.

    Any exception is caught and printed, so one failing example does not stop
    the others.

    Args:
        test_data: Dataset the few-shot demonstrations are drawn from.
        config_header: Base name (without extension) of the output file.
        idx: Index of test_example within test_data.
        test_example: Row with "premise", "hypothesis" and "label".
        config: Configuration; "prefix" and "context" are read here.
        zero_shot: If True, no few-shot demonstrations are selected.
        lang: Language name used in the output directory path.
        params: Run parameters; "selected_language", "response_logger_root"
            and "model" are read here.
    """
    try:
        instruction = create_instruction(lang=config["prefix"])
        text_example = {
            "premise": test_example["premise"],
            "hypothesis": test_example["hypothesis"],
            "label": test_example["label"],
        }
        ic_examples = []
        if not zero_shot:
            # Demonstrations come from the same data as the query, so drop the
            # query itself; otherwise it can appear in its own prompt together
            # with its gold label.
            candidate_pool = test_data
            if len(test_data) > 1:
                candidate_pool = test_data.select(
                    [i for i in range(len(test_data)) if i != idx]
                )
            ic_examples = choose_few_shot_examples(
                train_dataset=candidate_pool,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )
        # NOTE(review): `construct_prompt` is defined twice in this file and
        # the later definition (no `ic_examples` parameter, returns only the
        # prompt string) is the one bound at call time - confirm this call
        # reaches the intended version.
        prompt, label = construct_prompt(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
        )
        # NOTE(review): `get_prediction` is not defined or imported in the
        # visible part of this file - verify it is available at runtime.
        pred = get_prediction(
            prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572
        )
        print(pred)
        output_dir = f"{params['response_logger_root']}/{params['model']}/{lang}"
        os.makedirs(output_dir, exist_ok=True)
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{output_dir}/{config_header}.csv",
        )
    except Exception as e:
        # Handle exceptions here
        print(f"Error processing example {idx}: {e}")
def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
    dataset_name: str = "xnli",
):
    """Build the prompt string for one test example.

    NOTE(review): an earlier module-level ``def construct_prompt`` in this
    file has a different signature and return shape; this later definition is
    the one bound to the name after import - confirm that is intended.

    Args:
        instruction: Task instruction; created with create_instruction(lang)
            when empty.
        test_example: Mapping with "premise" and "hypothesis".
        zero_shot: If True, no few-shot demonstrations are included.
        num_examples: Number of few-shot demonstrations to draw.
        lang: Language name of the dataset text.
        config: Configuration; config["context"] is used as the language of
            every demonstration and config["input"] as the language the test
            example is rendered in.
        dataset_name: Dataset the demonstrations are loaded from.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If lang is not available for dataset_name (few-shot only).
        RuntimeError: If the test example could not be translated into
            config["input"].
    """
    if not instruction:
        instruction = create_instruction(lang)

    if zero_shot:
        # The separator previously contained a literal "+" left over from a
        # string concatenation; it is not meant to appear in the prompt.
        zero_shot_template = (
            f"""{instruction}""" + "\n Hypothesis: {hypothesis} \n Premise: {premise}"
        )
        prompt = PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    else:
        try:
            test_data = load_xnli_dataset(dataset_name, lang, split="test", limit=100)
        except KeyError as e:
            # Chain the original error so the failing lookup stays visible.
            raise KeyError(
                f"{lang} is not supported in {dataset_name} dataset, choose supported language in few-shot"
            ) from e
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        example_prompt = PromptTemplate(
            input_variables=["premise", "hypothesis", "label"],
            template="Premise {premise}\n Hypothesis {hypothesis} \n{label}",
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="{premise} \n{hypothesis}",
            input_variables=["hypothesis", "premise"],
        )

    if config["input"] != lang:
        translated = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
        if translated is None:
            # _translate_example reports a failed translation by returning None.
            raise RuntimeError(
                f"Could not translate test example from {lang} to {config['input']}"
            )
        test_example = translated
    return prompt.format(
        hypothesis=test_example["hypothesis"], premise=test_example["premise"]
    )