#!/usr/bin/env python3
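"""A/B test two prompts against structured test cases.

Inputs (a JSON payload from --input or stdin takes precedence; when one is
given, the prompt, cases, and runner flags below are ignored):
- --input: path to a JSON payload file, or a JSON payload piped on stdin
- --prompt-a/--prompt-b, or the --prompt-a-file/--prompt-b-file variants
- --cases-file: path to a JSON array of test cases
- --runner-cmd (optional): external command template with {prompt} and
  {input} placeholders

If no runner command is given, the script only scores each prompt statically:
the case input is substituted for any literal {{input}} marker in the prompt,
and that rendered text is scored in place of model output.
"""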

import argparse
import json
import re
import shlex
import subprocess
import sys
from dataclasses import dataclass, asdict
from pathlib import Path
from statistics import mean
from typing import Any, Dict, List, Optional


class CLIError(Exception):
    """Expected, user-facing failure (bad input, unreadable file, runner error).

    The script entry point catches it, prints ``ERROR: <message>`` to stderr,
    and exits with status 2 instead of showing a traceback.
    """


@dataclass
class CaseScore:
    """Score breakdown for one test case under one prompt variant.

    Field order is significant: it defines the positional constructor and the
    key order produced by ``dataclasses.asdict`` for the JSON report.
    """

    case_id: str  # the case's "id" (stringified; "case" when absent)
    prompt_variant: str  # variant label, "A" or "B"
    score: float  # final score, clamped to [0, 100]
    matched_expected: int  # expected substrings found in the output
    missed_expected: int  # expected substrings not found
    forbidden_hits: int  # forbidden substrings found
    regex_matches: int  # expected regexes that matched
    output_length: int  # len() of the scored output text


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``input``, ``prompt_a``, ``prompt_b``, ``prompt_a_file``,
        ``prompt_b_file``, ``cases_file`` and ``runner_cmd`` (each ``None`` when
        omitted) plus ``format`` (``"text"`` unless ``--format json``).
    """
    cli = argparse.ArgumentParser(description="A/B test prompts against test cases.")

    # Plain optional string flags, registered in their original order so the
    # generated --help output is unchanged.
    string_options = (
        ("--input", "JSON input file for full payload."),
        ("--prompt-a", "Prompt A text."),
        ("--prompt-b", "Prompt B text."),
        ("--prompt-a-file", "Path to prompt A file."),
        ("--prompt-b-file", "Path to prompt B file."),
        ("--cases-file", "Path to JSON test cases array."),
        (
            "--runner-cmd",
            "External command template, e.g. 'llm --prompt {prompt} --input {input}'.",
        ),
    )
    for flag, help_text in string_options:
        cli.add_argument(flag, help=help_text)

    cli.add_argument(
        "--format",
        default="text",
        choices=["text", "json"],
        help="Output format.",
    )
    return cli.parse_args()


def read_text_file(path: Optional[str]) -> Optional[str]:
    """Return the UTF-8 contents of *path*, or ``None`` when no path is given.

    An empty string is treated the same as ``None`` so an unset optional flag
    can be passed straight through.

    Raises:
        CLIError: if the file cannot be opened/read or is not valid UTF-8.
    """
    if not path:
        return None
    try:
        return Path(path).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # Only I/O and decoding failures are user errors; anything else is a
        # programming error and should surface with its traceback.
        raise CLIError(f"Failed reading file {path}: {exc}") from exc


def _ensure_object(payload: Any, source: str) -> Dict[str, Any]:
    """Return *payload* unchanged if it is a JSON object, else raise CLIError."""
    if not isinstance(payload, dict):
        raise CLIError(f"{source} must be a JSON object, got {type(payload).__name__}.")
    return payload


def load_payload(args: argparse.Namespace) -> Dict[str, Any]:
    """Assemble the run configuration dict.

    Precedence:
    1. ``--input``: the whole payload is read from that JSON file.
    2. Otherwise, non-empty JSON piped on stdin is the whole payload.
    3. Otherwise, a payload is built from the prompt, cases, and runner flags.

    With (1) or (2), the individual prompt/cases/runner flags are ignored.

    Raises:
        CLIError: if a file cannot be read, JSON is invalid, or a JSON payload
            from (1)/(2) is not an object.
    """
    if args.input:
        try:
            payload = json.loads(Path(args.input).read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            raise CLIError(f"Failed reading --input payload: {exc}") from exc
        return _ensure_object(payload, "--input payload")

    # sys.stdin is None when the interpreter was started with stdin closed.
    if sys.stdin is not None and not sys.stdin.isatty():
        raw = sys.stdin.read().strip()
        if raw:
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError as exc:
                raise CLIError(f"Invalid JSON from stdin: {exc}") from exc
            return _ensure_object(payload, "stdin payload")

    built: Dict[str, Any] = {}

    prompt_a = args.prompt_a or read_text_file(args.prompt_a_file)
    prompt_b = args.prompt_b or read_text_file(args.prompt_b_file)
    if prompt_a:
        built["prompt_a"] = prompt_a
    if prompt_b:
        built["prompt_b"] = prompt_b

    if args.cases_file:
        try:
            built["cases"] = json.loads(Path(args.cases_file).read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            raise CLIError(f"Failed reading --cases-file: {exc}") from exc

    if args.runner_cmd:
        built["runner_cmd"] = args.runner_cmd

    return built


def run_runner(runner_cmd: str, prompt: str, case_input: str) -> str:
    """Run the external runner for one prompt/case pair and return its stdout.

    The template is tokenized with shell-like rules *before* substitution, so
    each ``{prompt}`` / ``{input}`` placeholder expands inside a single argv
    entry regardless of spaces or quotes in the prompt or input. No shell is
    involved (the argv list is passed directly to ``subprocess.run``).

    Args:
        runner_cmd: command template, e.g.
            ``"llm --prompt {prompt} --input {input}"``. ``str.format`` syntax
            applies within each token (``{{`` / ``}}`` for literal braces).
        prompt: text substituted for ``{prompt}``.
        case_input: text substituted for ``{input}``.

    Returns:
        The runner's stdout with surrounding whitespace stripped.

    Raises:
        CLIError: if the template is malformed, empty, or uses an unknown
            placeholder; if the executable cannot be started; or if the
            command exits non-zero.
    """
    try:
        template_parts = shlex.split(runner_cmd)
    except ValueError as exc:
        raise CLIError(f"Invalid runner command template: {exc}") from exc
    if not template_parts:
        raise CLIError("Runner command template is empty.")

    try:
        # Substitute per token: replacement text is never re-tokenized, so
        # user-supplied prompt/input text cannot split or break the argv.
        parts = [part.format(prompt=prompt, input=case_input) for part in template_parts]
    except (KeyError, IndexError, AttributeError, ValueError) as exc:
        raise CLIError(f"Invalid placeholder in runner command template: {exc!r}") from exc

    try:
        proc = subprocess.run(parts, text=True, capture_output=True, check=True)
    except subprocess.CalledProcessError as exc:
        stderr = (exc.stderr or "").strip()
        raise CLIError(f"Runner command failed (exit {exc.returncode}): {stderr}") from exc
    except OSError as exc:
        # e.g. FileNotFoundError when the executable does not exist.
        raise CLIError(f"Could not start runner command {parts[0]!r}: {exc}") from exc
    return proc.stdout.strip()


def static_output(prompt: str, case_input: str) -> str:
    """Render *prompt* for static scoring by inlining the case input.

    Every literal ``{{input}}`` marker in *prompt* is replaced with
    *case_input*; text without the marker is returned unchanged.
    """
    marker = "{{input}}"
    pieces = prompt.split(marker)
    return case_input.join(pieces)


def _as_string_list(value: Any) -> List[str]:
    """Normalize a case expectation field to a list of non-empty strings.

    ``None`` (missing or JSON null) becomes ``[]``; a lone scalar such as a
    string is treated as a one-item list instead of being iterated
    character by character.
    """
    if value is None:
        return []
    if not isinstance(value, (list, tuple)):
        value = [value]
    return [text for text in (str(item) for item in value) if text]


def score_output(case: Dict[str, Any], output: str, prompt_variant: str) -> CaseScore:
    """Score *output* against the expectations declared in *case*.

    Recognized case keys (each a list of strings; a lone string is accepted
    as a one-item list, and null/missing means none):
    - ``expected_contains``: case-insensitive substrings, -15 per miss.
    - ``forbidden_contains``: case-insensitive substrings, -25 per hit.
    - ``expected_regex``: patterns searched with ``re.MULTILINE``, +8 per
      match. Non-matching patterns are not penalized; invalid patterns are
      skipped.

    Scoring starts at 100. Output longer than 4000 characters loses 10, and
    output shorter than 10 characters after stripping whitespace loses 10.
    The result is clamped to [0, 100].
    """
    case_id = str(case.get("id", "case"))
    expected = _as_string_list(case.get("expected_contains"))
    forbidden = _as_string_list(case.get("forbidden_contains"))
    regexes = _as_string_list(case.get("expected_regex"))

    haystack = output.lower()  # lowered once; substring checks are case-insensitive
    matched_expected = sum(1 for item in expected if item.lower() in haystack)
    missed_expected = len(expected) - matched_expected
    forbidden_hits = sum(1 for item in forbidden if item.lower() in haystack)

    regex_matches = 0
    for pattern in regexes:
        try:
            if re.search(pattern, output, flags=re.MULTILINE):
                regex_matches += 1
        except re.error:
            # Best-effort: one malformed pattern in the suite should not abort
            # the whole run; it simply earns no bonus.
            continue

    score = 100.0
    score -= missed_expected * 15
    score -= forbidden_hits * 25
    score += regex_matches * 8

    # Heuristic penalties for unbounded verbosity and near-empty output.
    if len(output) > 4000:
        score -= 10
    if len(output.strip()) < 10:
        score -= 10

    score = max(0.0, min(100.0, score))

    return CaseScore(
        case_id=case_id,
        prompt_variant=prompt_variant,
        score=score,
        matched_expected=matched_expected,
        missed_expected=missed_expected,
        forbidden_hits=forbidden_hits,
        regex_matches=regex_matches,
        output_length=len(output),
    )


def aggregate(scores: List[CaseScore]) -> Dict[str, Any]:
    """Summarize case scores as average/min/max (rounded to 2 places) and a count.

    An empty input yields an all-zero summary instead of raising.
    """
    values = [entry.score for entry in scores]
    if values:
        return {
            "average": round(mean(values), 2),
            "min": round(min(values), 2),
            "max": round(max(values), 2),
            "cases": len(values),
        }
    return {"average": 0.0, "min": 0.0, "max": 0.0, "cases": 0}


def _as_text(value: Any) -> str:
    """Return *value* as stripped text, treating ``None`` (JSON null) as empty."""
    return "" if value is None else str(value).strip()


def _print_text_report(
    mode: str,
    winner: str,
    agg_a: Dict[str, Any],
    agg_b: Dict[str, Any],
    case_scores: List[CaseScore],
) -> None:
    """Print the human-readable (``--format text``) report to stdout."""
    print("Prompt A/B test result")
    print(f"- mode: {mode}")
    print(f"- winner: {winner}")
    print(f"- prompt A avg: {agg_a['average']}")
    print(f"- prompt B avg: {agg_b['average']}")
    print("Case details:")
    for item in case_scores:
        print(
            f"- case={item.case_id} variant={item.prompt_variant} score={item.score} "
            f"expected+={item.matched_expected} forbidden={item.forbidden_hits} regex={item.regex_matches}"
        )


def main() -> int:
    """Score prompt A and prompt B on every case and print the comparison.

    The winner is the variant with the higher average score; equal averages
    are reported as ``"tie"``.

    Returns:
        Exit status 0 on success.

    Raises:
        CLIError: if the payload is not an object, a prompt is missing,
            ``cases`` is empty/not a list or has no object entries, or
            ``runner_cmd`` is not a string. Loading and runner helpers also
            raise it.
    """
    args = parse_args()
    payload = load_payload(args)
    if not isinstance(payload, dict):
        raise CLIError("Payload must be a JSON object.")

    # _as_text keeps an explicit JSON null from becoming the literal text "None".
    prompt_a = _as_text(payload.get("prompt_a"))
    prompt_b = _as_text(payload.get("prompt_b"))
    cases = payload.get("cases", [])
    runner_cmd = payload.get("runner_cmd")

    if not prompt_a or not prompt_b:
        raise CLIError("Both prompt_a and prompt_b are required (flags or JSON payload).")
    if not isinstance(cases, list) or not cases:
        raise CLIError("cases must be a non-empty array.")
    if runner_cmd is not None and not isinstance(runner_cmd, str):
        raise CLIError("runner_cmd must be a string.")

    # Non-object entries are skipped, but visibly; a suite with no usable case
    # must not silently produce an all-zero "result".
    valid_cases = [case for case in cases if isinstance(case, dict)]
    if not valid_cases:
        raise CLIError("cases must contain at least one JSON object.")
    skipped = len(cases) - len(valid_cases)
    if skipped:
        print(f"WARNING: skipped {skipped} case(s) that are not JSON objects.", file=sys.stderr)

    scores_a: List[CaseScore] = []
    scores_b: List[CaseScore] = []
    for case in valid_cases:
        case_input = _as_text(case.get("input"))
        if runner_cmd:
            output_a = run_runner(runner_cmd, prompt_a, case_input)
            output_b = run_runner(runner_cmd, prompt_b, case_input)
        else:
            output_a = static_output(prompt_a, case_input)
            output_b = static_output(prompt_b, case_input)
        scores_a.append(score_output(case, output_a, "A"))
        scores_b.append(score_output(case, output_b, "B"))

    agg_a = aggregate(scores_a)
    agg_b = aggregate(scores_b)
    if agg_a["average"] > agg_b["average"]:
        winner = "A"
    elif agg_b["average"] > agg_a["average"]:
        winner = "B"
    else:
        # Equal averages: do not silently award the win to prompt A.
        winner = "tie"
    mode = "runner" if runner_cmd else "static"

    result = {
        "summary": {
            "winner": winner,
            "prompt_a": agg_a,
            "prompt_b": agg_b,
            "mode": mode,
        },
        "case_scores": {
            "prompt_a": [asdict(item) for item in scores_a],
            "prompt_b": [asdict(item) for item in scores_b],
        },
    }

    if args.format == "json":
        print(json.dumps(result, indent=2))
    else:
        _print_text_report(mode, winner, agg_a, agg_b, scores_a + scores_b)

    return 0


if __name__ == "__main__":
    # Expected failures become a one-line stderr message and exit status 2;
    # any other exception propagates with its traceback.
    try:
        status = main()
    except CLIError as err:
        print(f"ERROR: {err}", file=sys.stderr)
        status = 2
    raise SystemExit(status)
