"""Runs RAG pipeline over dataset(s) and saves partial results
for manual annotation.
"""
import argparse
import json
from pathlib import Path
from typing import Any, Dict
from evaluation import PipelineConfig, RetrieverConfig, GeneratorConfig, CrossEncoderConfig, StatsConfig, LoggingConfig, RAGPipeline
from evaluation.utils.logger import init_logging
import yaml
def merge_dataclass(dc_cls, override: Dict[str, Any]):
    """Instantiate *dc_cls* from its defaults, overlaid with *override*.

    Keys in *override* whose value is None are ignored, so a sparse
    override dict never clobbers a default with None.
    """
    from dataclasses import asdict

    defaults = asdict(dc_cls())
    effective = {key: val for key, val in override.items() if val is not None}
    return dc_cls(**{**defaults, **effective})
def load_pipeline_config(yaml_path: Path) -> PipelineConfig:
    """Build a PipelineConfig from a YAML file.

    Each top-level section (retriever, generator, reranker, stats, logging)
    overrides the matching dataclass defaults via merge_dataclass; a missing
    section falls back entirely to the dataclass defaults.
    """
    # safe_load returns None for an empty/comment-only file; treat that as
    # "all defaults" instead of crashing on data.get below.
    data = yaml.safe_load(yaml_path.read_text()) or {}
    return PipelineConfig(
        retriever=merge_dataclass(RetrieverConfig, data.get("retriever", {})),
        generator=merge_dataclass(GeneratorConfig, data.get("generator", {})),
        reranker=merge_dataclass(CrossEncoderConfig, data.get("reranker", {})),
        stats=merge_dataclass(StatsConfig, data.get("stats", {})),
        logging=merge_dataclass(LoggingConfig, data.get("logging", {})),
    )
def read_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of dicts.

    Blank lines (e.g. a stray trailing newline or spacer) are skipped
    instead of raising JSONDecodeError. Reads as UTF-8 explicitly so the
    result does not depend on the platform's default encoding.
    """
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def write_jsonl(path: Path, rows: list[dict]) -> None:
    """Serialize *rows* to *path* as JSON Lines, creating parent dirs.

    Writes UTF-8 explicitly: the platform default encoding (e.g. cp1252
    on Windows) must not influence the output file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)
def main(argv=None):
    """CLI entry point: run the RAG pipeline over each dataset and write
    unlabeled results (JSONL) for manual annotation.

    Args:
        argv: Optional argument list (useful for testing); defaults to
            sys.argv[1:] via argparse.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=Path, required=True)
    ap.add_argument("--datasets", nargs="+", type=Path, required=True)
    ap.add_argument("--outdir", type=Path, default=Path("outputs/for_annotation"))
    args = ap.parse_args(argv)

    init_logging(log_dir=args.outdir / "logs")
    cfg = load_pipeline_config(args.config)
    pipe = RAGPipeline(cfg)

    for dataset in args.datasets:
        output_dir = args.outdir / dataset.stem / args.config.stem
        output_path = output_dir / "unlabeled_results.jsonl"
        # Check for existing output BEFORE reading the dataset, so already
        # processed datasets are skipped without any parsing work.
        if output_path.exists():
            print(f"Skipping {dataset.name} – already exists.")
            continue
        queries = read_jsonl(dataset)
        rows = []
        for q in queries:
            result = pipe.run(q["question"])
            rows.append({
                "question": q["question"],
                "retrieved_docs": result.get("retrieved_docs", []),
                "generated_answer": result.get("generated_answer", ""),
                "metrics": result.get("metrics", {}),
                # Human annotators will add these
                "human_correct": None,
                "human_faithful": None,
            })
        write_jsonl(output_path, rows)
        print(f"Wrote {len(rows)} results to {output_path}")


if __name__ == "__main__":
    main()