# RAG_Eval/scripts/prep_annotations.py
"""
Runs RAG pipeline over dataset(s) and saves partial results
for manual annotation.
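
Example invocation (the config and dataset paths are illustrative, not from the repo):

    python scripts/prep_annotations.py --config configs/base.yaml --datasets data/dev.jsonl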
"""
import argparse
import json
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict

import yaml

from evaluation import (PipelineConfig, RetrieverConfig, GeneratorConfig,
                        CrossEncoderConfig, StatsConfig, LoggingConfig, RAGPipeline)
from evaluation.utils.logger import init_logging


def merge_dataclass(dc_cls, override: Dict[str, Any]):
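    """Build an instance of ``dc_cls`` from its default values, updated with the
    non-None entries of ``override`` (typically one section of the YAML config)."""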
    base = asdict(dc_cls())
    base.update({k: v for k, v in override.items() if v is not None})
    return dc_cls(**base)


def load_pipeline_config(yaml_path: Path) -> PipelineConfig:
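    """Parse a YAML config file and map its top-level sections onto the pipeline's config dataclasses."""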
    data = yaml.safe_load(yaml_path.read_text()) or {}  # tolerate an empty config file
return PipelineConfig(
retriever=merge_dataclass(RetrieverConfig, data.get("retriever", {})),
generator=merge_dataclass(GeneratorConfig, data.get("generator", {})),
reranker=merge_dataclass(CrossEncoderConfig, data.get("reranker", {})),
stats=merge_dataclass(StatsConfig, data.get("stats", {})),
logging=merge_dataclass(LoggingConfig, data.get("logging", {})),
    )


def read_jsonl(path: Path) -> list[dict]:
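    """Read a JSONL file: one JSON object per non-empty line."""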
with path.open() as f:
        return [json.loads(line) for line in f if line.strip()]  # skip blank lines


def write_jsonl(path: Path, rows: list[dict]) -> None:
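    """Write rows as JSONL, creating parent directories if needed."""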
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w") as f:
for row in rows:
            f.write(json.dumps(row) + "\n")


def main(argv=None):
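    """CLI entry point: build a RAGPipeline from --config and run it over every file in --datasets."""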
ap = argparse.ArgumentParser()
ap.add_argument("--config", type=Path, required=True)
ap.add_argument("--datasets", nargs="+", type=Path, required=True)
ap.add_argument("--outdir", type=Path, default=Path("outputs/for_annotation"))
args = ap.parse_args(argv)
init_logging(log_dir=args.outdir / "logs")
cfg = load_pipeline_config(args.config)
pipe = RAGPipeline(cfg)
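    # Process each dataset with the same pipeline; outputs are grouped by dataset and config name.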
for dataset in args.datasets:
queries = read_jsonl(dataset)
output_dir = args.outdir / dataset.stem / args.config.stem
output_path = output_dir / "unlabeled_results.jsonl"
if output_path.exists():
print(f"Skipping {dataset.name} – already exists.")
continue
rows = []
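        # Run each question through the pipeline and keep only the fields annotators will review.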
for q in queries:
result = pipe.run(q["question"])
entry = {
"question": q["question"],
"retrieved_docs": result.get("retrieved_docs", []),
"generated_answer": result.get("generated_answer", ""),
"metrics": result.get("metrics", {}),
# Human annotators will add these
"human_correct": None,
"human_faithful": None
}
rows.append(entry)
write_jsonl(output_path, rows)
print(f"Wrote {len(rows)} results to {output_path}")
if __name__ == "__main__":
main()