Spaces:
Sleeping
Sleeping
import json | |
import pathlib | |
from copy import deepcopy | |
from typing import Callable | |
from functools import partial | |
import click | |
import pandas as pd | |
import pandera.pandas as pa | |
from tqdm.auto import tqdm | |
from langchain_core.runnables import Runnable | |
from src.common.data import load_dataset | |
from src.common.schema import DatasetSchema | |
from src.generate.config import GenerationConfig | |
from src.generate.schema import GeneratedDatasetSchema | |
from src.generate.answer import make_root_model, matches_type, string_to_type | |
from src.generate.generators import GenerationAnswer, GENERATORS_NAME_TO_FACTORY | |
def _save_temp_file(
    row: dict,
    result: GenerationAnswer,
    temp_path: pathlib.Path,
) -> None:
    """Persist one generated answer as ``<temp_path>/<row id>.json``.

    The file holds a JSON object with two entries: the row id (keyed by
    ``DatasetSchema.id_``) and ``result.model_dump()`` (keyed by
    ``GeneratedDatasetSchema.generated_answer``).

    Args:
        row: Dataset row; only ``row[DatasetSchema.id_]`` is read.
        result: Generated answer to store.
        temp_path: Directory that receives the file.
    """
    row_id = row[DatasetSchema.id_]
    payload = {
        DatasetSchema.id_: row_id,
        GeneratedDatasetSchema.generated_answer: result.model_dump(),
    }
    # ensure_ascii=False writes raw non-ASCII characters, so the encoding is
    # pinned instead of depending on the platform default. The context
    # manager guarantees the handle is flushed and closed.
    with open(temp_path / f"{row_id}.json", "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)


def _generate_single_answer(
    row: dict,
    build_chain: Callable[[type], Runnable],
    temp_path: pathlib.Path = None,
) -> GenerationAnswer:
    """Return the generated answer for one row, reusing a cached file if present.

    When ``temp_path`` is given and ``<temp_path>/<row id>.json`` exists, the
    answer stored in that file is validated and returned without invoking the
    chain. Otherwise a chain is built for the row's answer type, invoked, and
    (if ``temp_path`` is given) the result is written to that file.

    Args:
        row: Dataset row (a mapping; it is converted with ``dict(row)``).
        build_chain: Factory that receives the answer model built by
            ``make_root_model`` and returns the runnable to invoke.
        temp_path: Optional directory used as a per-row result cache.

    Returns:
        The cached or freshly generated ``GenerationAnswer``.
    """
    if temp_path:
        cache_file = temp_path / f"{row[DatasetSchema.id_]}.json"
        if cache_file.exists():
            # Same encoding as _save_temp_file, and the handle is closed.
            with open(cache_file, "r", encoding="utf-8") as handle:
                cached = json.load(handle)
            return GenerationAnswer.model_validate(
                cached[GeneratedDatasetSchema.generated_answer]
            )

    answer_type = make_root_model(row[DatasetSchema.answer_type])
    chain = build_chain(answer_type)

    # Invoke the chain on a plain-dict copy with the correct answer removed;
    # the caller's row object is left untouched.
    chain_input = dict(row)
    chain_input.pop(DatasetSchema.correct_answer, None)
    result: GenerationAnswer = chain.invoke(chain_input)

    if temp_path:
        _save_temp_file(chain_input, result, temp_path)
    return result
def _generate_answers(
    df: pd.DataFrame,
    build_chain: Callable[[type], Runnable],
    use_tqdm: bool = True,
    temp_path: pathlib.Path = None,
) -> pd.DataFrame:
    """Generate an answer for every row of ``df``.

    Args:
        df: Dataset rows. The frame is not modified.
        build_chain: Chain factory forwarded to ``_generate_single_answer``.
        use_tqdm: Show a progress bar while rows are processed.
        temp_path: Optional per-row cache directory forwarded to
            ``_generate_single_answer``.

    Returns:
        A new frame holding the generated answers in the
        ``GeneratedDatasetSchema.generated_answer`` column, restricted to the
        keys of ``GeneratedDatasetSchema._collect_fields()``.
    """
    answer_for_row = partial(
        _generate_single_answer,
        build_chain=build_chain,
        temp_path=temp_path,
    )
    if use_tqdm:
        # tqdm.pandas() registers DataFrame.progress_apply.
        tqdm.pandas()
        answers = df.progress_apply(answer_for_row, axis=1)
    else:
        answers = df.apply(answer_for_row, axis=1)

    # Attach the answers to a copy so the caller's frame does not silently
    # gain a column.
    result = df.copy()
    result[GeneratedDatasetSchema.generated_answer] = answers

    schema_columns = list(GeneratedDatasetSchema._collect_fields().keys())
    # Return an independent frame: callers assign into it afterwards, which
    # would otherwise be a write to a slice of another frame.
    return result[schema_columns].copy()
def generate(
    config_path: pathlib.Path = pathlib.Path("configs/ollama.yaml"),
    output_path: pathlib.Path = pathlib.Path("./gemma3:4b.jsonl"),
    temp_path: pathlib.Path = pathlib.Path("./tmp_gemma3:4b/"),
    use_tqdm: bool = True,
):
    """Generate answers for the whole dataset and write them as JSON Lines."""
    # Normalise both locations to Path objects and make sure the output's
    # parent directory and the temp directory exist before any work starts.
    output_path = pathlib.Path(output_path)
    temp_path = pathlib.Path(temp_path)
    for directory in (output_path.parent, temp_path):
        directory.mkdir(parents=True, exist_ok=True)

    config = GenerationConfig.from_file(config_path)
    dataset = load_dataset()

    # Pre-bind everything the config specifies so that only the answer type
    # remains to be supplied per row.
    chain_factory = partial(
        GENERATORS_NAME_TO_FACTORY[config.build_function],
        llm_class=config.llm_class,
        structured_output_method=config.structured_output_method,
        **config.kwargs,
    )

    generated = _generate_answers(
        dataset,
        chain_factory,
        use_tqdm=use_tqdm,
        temp_path=temp_path,
    )

    # Replace the answer objects with their model_dump() output before the
    # frame is written out.
    answer_column = GeneratedDatasetSchema.generated_answer
    generated[answer_column] = generated[answer_column].apply(
        lambda answer: answer.model_dump()
    )

    generated.to_json(
        output_path,
        lines=True,
        orient="records",
        force_ascii=False,
    )
def _type_sanitycheck(
    generated_df: pd.DataFrame,
) -> tuple[bool, str]:
    """Check that every generated answer matches its row's declared answer type.

    The generated rows are joined onto the dataset returned by
    ``load_dataset()`` by id; dataset rows without a generated answer are
    dropped before checking.

    Args:
        generated_df: Frame with the ``GeneratedDatasetSchema.id_`` and
            ``GeneratedDatasetSchema.generated_answer`` columns. Answers may
            be ``GenerationAnswer`` instances or raw values accepted by
            ``GenerationAnswer.model_validate``. The frame is not modified.

    Returns:
        ``(ok, message)``: ``ok`` is ``False`` when there is nothing to check
        or when at least one answer has the wrong type (the message then
        lists the offending ids).
    """
    # A frame without rows (e.g. parsed from an empty file) may not even have
    # the expected columns; report it instead of failing with a KeyError.
    if generated_df.empty:
        return False, "No valid predictions found."

    answer_column = GeneratedDatasetSchema.generated_answer

    # Work on a copy so the caller's frame keeps its original answer values.
    generated_df = generated_df.copy()
    generated_df[answer_column] = generated_df[answer_column].apply(
        lambda x: x if isinstance(x, GenerationAnswer) else GenerationAnswer.model_validate(deepcopy(x))
    )

    dataset_df = load_dataset()
    predicted_df = dataset_df.join(
        generated_df.set_index(GeneratedDatasetSchema.id_),
        on=DatasetSchema.id_,
        rsuffix="_generated",
    ).dropna(subset=[answer_column])
    if predicted_df.empty:
        return False, "No valid predictions found."

    type_match = predicted_df.apply(
        lambda row: matches_type(
            row[answer_column].answer,
            string_to_type(row[DatasetSchema.answer_type]),
        ),
        axis=1,
    )
    if not type_match.all():
        mismatched_ids = predicted_df.loc[~type_match, DatasetSchema.id_].tolist()
        return False, f"Type mismatch found for {mismatched_ids}."
    return True, f"All matched. Predicted count: {len(predicted_df)} of {len(dataset_df)}"
def type_sanitycheck(
    file: pathlib.Path = pathlib.Path("./gemma3:4b.jsonl"),
):
    """Check that the answers in a generated JSON Lines file have the expected types.

    Prints the outcome and exits with status 1 when the check fails.
    """
    df = pd.read_json(file, lines=True)
    types_correct, message = _type_sanitycheck(df)
    if not types_correct:
        click.echo(f"❌ Type sanity check failed: {message}")
        # The builtin exit() is injected by the site module for interactive
        # use and is not guaranteed to exist; raising SystemExit always works.
        raise SystemExit(1)
    click.echo(f"✅ Type sanity check passed: {message}")
def cli():
    """Root of the command-line interface; it performs no work itself."""
# Register the two sub-commands on the root command.
# NOTE(review): add_command is a click.Group method that expects click
# commands; this presumably relies on click decorators on cli, generate and
# type_sanitycheck that are not visible in this view. Confirm they exist.
cli.add_command(generate)
cli.add_command(type_sanitycheck)


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    cli()