#!/usr/bin/env python3 """ tests/rag/run.py — Golden RAG evaluation. Reads tests/rag/golden.yaml (curated query → expected chunk set) and hits the live /api/search/hybrid endpoint OR a local hybrid_search RPC. Computes Recall@5 and MRR per query, plus aggregates. Writes a JSON report to tests/rag/last_run.json and compares with tests/rag/baseline.json. CI gate: if Recall@5 drops more than --max-recall-drop (default 0.05) from baseline, exit 1. Usage: python3 tests/rag/run.py # uses prod URL python3 tests/rag/run.py --url http://localhost:3000 # local dev python3 tests/rag/run.py --refresh-baseline # accept current as baseline python3 tests/rag/run.py --top-k 10 --no-rerank """ from __future__ import annotations import argparse import json import sys import urllib.parse import urllib.request import urllib.error from pathlib import Path try: import yaml except ImportError: sys.exit("pip install pyyaml") ROOT = Path(__file__).resolve().parent GOLDEN = ROOT / "golden.yaml" BASELINE = ROOT / "baseline.json" LAST_RUN = ROOT / "last_run.json" def search(base_url: str, q: str, lang: str, top_k: int, rerank: str) -> list[dict]: params = {"q": q, "lang": lang, "top_k": str(top_k)} if rerank == "never": params["rerank"] = "never" elif rerank == "always": params["rerank"] = "always" qs = urllib.parse.urlencode(params) url = f"{base_url.rstrip('/')}/api/search/hybrid?{qs}" try: with urllib.request.urlopen(url, timeout=30) as r: data = json.loads(r.read()) return data.get("hits", []) except urllib.error.HTTPError as e: sys.stderr.write(f" ! HTTP {e.code} on {q!r}\n") return [] except Exception as e: sys.stderr.write(f" ! {e} on {q!r}\n") return [] def evaluate(golden: list[dict], hits_by_id: dict[str, list[dict]], k: int) -> dict: """Per-query Recall@k + MRR. Negative-set queries (no expected chunks) pass when no hits are returned within the top-k.""" per_query: list[dict] = [] pos_recalls: list[float] = [] pos_mrrs: list[float] = [] neg_pass = 0 neg_total = 0 for q in golden: qid = q["id"] expected = {(e["doc"], e["chunk"]) for e in (q.get("expected_chunks") or [])} hits = hits_by_id.get(qid, []) topk = hits[:k] if not expected: # Negative-set: pass when fewer than k hits, OR when first hit is # weak enough that the model wouldn't latch onto it. We accept # any non-zero result count as failure to keep the metric strict. neg_total += 1 ok = len(topk) == 0 per_query.append({ "id": qid, "negative": True, "ok": ok, "n_hits": len(topk), }) if ok: neg_pass += 1 continue present = sum(1 for h in topk if (h.get("doc_id"), h.get("chunk_id")) in expected) recall = present / len(expected) # MRR — first matching position (1-indexed). 0 if none. rr = 0.0 for i, h in enumerate(topk, start=1): if (h.get("doc_id"), h.get("chunk_id")) in expected: rr = 1.0 / i break per_query.append({ "id": qid, "negative": False, "recall_at_k": round(recall, 4), "mrr": round(rr, 4), "n_expected": len(expected), "n_present": present, }) pos_recalls.append(recall) pos_mrrs.append(rr) return { "k": k, "n_queries": len(per_query), "n_positive": len(pos_recalls), "n_negative": neg_total, "recall_at_k": round(sum(pos_recalls) / len(pos_recalls), 4) if pos_recalls else 0.0, "mrr": round(sum(pos_mrrs) / len(pos_mrrs), 4) if pos_mrrs else 0.0, "negative_pass_rate": round(neg_pass / neg_total, 4) if neg_total else 1.0, "per_query": per_query, } def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--url", default="https://disclosure.top", help="Base URL of the deployment to evaluate") ap.add_argument("--top-k", type=int, default=5) ap.add_argument("--rerank", choices=["always", "when_top_k_gt", "never"], default="when_top_k_gt") ap.add_argument("--refresh-baseline", action="store_true", help="Overwrite baseline.json with this run (acknowledged regression).") ap.add_argument("--max-recall-drop", type=float, default=0.05) args = ap.parse_args() data = yaml.safe_load(GOLDEN.read_text()) queries = data["queries"] print(f"= running {len(queries)} queries against {args.url} (k={args.top_k}, rerank={args.rerank})") hits_by_id = {} for q in queries: hits = search(args.url, q["question"], q.get("lang", "pt"), top_k=max(args.top_k, 10), rerank=args.rerank) hits_by_id[q["id"]] = hits first = hits[0].get("chunk_id") if hits else "-" print(f" {q['id']:24s} → {len(hits):2d} hits (first={first})") report = evaluate(queries, hits_by_id, k=args.top_k) report["url"] = args.url report["top_k"] = args.top_k report["rerank"] = args.rerank LAST_RUN.write_text(json.dumps(report, indent=2)) print(f"\n— wrote {LAST_RUN}") print(f" Recall@{args.top_k} = {report['recall_at_k']:.4f}") print(f" MRR = {report['mrr']:.4f}") print(f" Negative pass = {report['negative_pass_rate']:.4f}") if args.refresh_baseline: BASELINE.write_text(json.dumps({ "url": args.url, "top_k": args.top_k, "rerank": args.rerank, "recall_at_k": report["recall_at_k"], "mrr": report["mrr"], "negative_pass_rate": report["negative_pass_rate"], }, indent=2)) print(f"\n✓ baseline refreshed: {BASELINE}") return 0 if not BASELINE.exists(): print("\n! no baseline yet — run with --refresh-baseline to create one") return 0 baseline = json.loads(BASELINE.read_text()) drop = baseline["recall_at_k"] - report["recall_at_k"] print(f"\n baseline Recall@{args.top_k} = {baseline['recall_at_k']:.4f} (Δ {-drop:+.4f})") if drop > args.max_recall_drop: print(f"\n✗ GATE FAILED: Recall@{args.top_k} dropped {drop:.4f} > {args.max_recall_drop}") return 1 print(f"\n✓ gate passed (drop ≤ {args.max_recall_drop})") return 0 if __name__ == "__main__": sys.exit(main())