"""Classify tool calls against compliance steps using an LLM."""

from __future__ import annotations

import json
import logging
import re
import subprocess
from pathlib import Path

from scripts.parser import ComplianceSpec, ObservationEvent

logger = logging.getLogger(__name__)

# Directory holding the prompt templates read by this module (e.g.
# "classifier.md"): the "prompts" folder next to this file's parent directory.
PROMPTS_DIR = Path(__file__).parent.parent / "prompts"


def classify_events(
    spec: ComplianceSpec,
    trace: list[ObservationEvent],
    model: str = "haiku",
    timeout: float = 60,
) -> dict[str, list[int]]:
    """Classify which tool calls in *trace* match which compliance steps.

    Builds one prompt from ``PROMPTS_DIR / "classifier.md"`` listing every step
    of *spec* and every event of *trace*, runs it through the ``claude`` CLI in
    print mode, and parses the reply.

    Args:
        spec: Compliance spec; each entry of ``spec.steps`` contributes its
            ``id`` and ``detector.description`` to the prompt.
        trace: Observed tool-call events. An event's position in this list is
            the index the model refers to.
        model: Model name forwarded to ``claude --model``.
        timeout: Seconds to wait for the CLI before giving up.

    Returns:
        ``{step_id: [event_indices]}``. Empty when *trace* or ``spec.steps`` is
        empty, or when the reply cannot be parsed. Indices that do not point
        into *trace* are dropped.

    Raises:
        RuntimeError: The CLI exited with a non-zero status.
        subprocess.TimeoutExpired: The CLI did not finish within *timeout*.
        OSError: The ``claude`` executable could not be started.
    """
    if not trace:
        return {}

    steps = list(spec.steps)
    if not steps:
        # Nothing to match against, so skip the LLM round-trip entirely.
        return {}

    steps_desc = "\n".join(
        f"- {step.id}: {step.detector.description}" for step in steps
    )

    # Tool inputs/outputs are truncated so the prompt size stays bounded per event.
    tool_calls = "\n".join(
        f"[{i}] {event.tool}: input={event.input[:500]} output={event.output[:200]}"
        for i, event in enumerate(trace)
    )

    prompt_template = (PROMPTS_DIR / "classifier.md").read_text(encoding="utf-8")
    prompt = _fill_placeholders(
        prompt_template,
        {"steps_description": steps_desc, "tool_calls": tool_calls},
    )

    # NOTE(review): the whole prompt is passed as one argv entry. On Linux a
    # single argument is typically capped at 128 KiB, so very long traces can
    # fail to launch with OSError (E2BIG). Consider feeding it via stdin.
    result = subprocess.run(
        ["claude", "-p", prompt, "--model", model, "--output-format", "text"],
        capture_output=True,
        text=True,
        timeout=timeout,
    )

    if result.returncode != 0:
        raise RuntimeError(
            f"classifier subprocess failed (rc={result.returncode}): "
            f"{result.stderr[:500]}"
        )

    classification = _parse_classification(result.stdout)
    n_events = len(trace)
    # The model may invent indices; keep only those that address a real event.
    return {
        step_id: [i for i in indices if 0 <= i < n_events]
        for step_id, indices in classification.items()
    }


def _fill_placeholders(template: str, values: dict[str, str]) -> str:
    """Replace each ``{name}`` in *template* with ``values[name]`` in one pass.

    Only the listed names are substituted, so other braces in the template
    (e.g. JSON examples) are left alone. Inserted text is never rescanned, so a
    value that happens to contain another placeholder is not expanded again.
    """
    if not values:
        return template
    pattern = re.compile("|".join(re.escape("{" + name + "}") for name in values))
    # m.group(0) is "{name}"; strip the braces to look up the replacement.
    return pattern.sub(lambda m: values[m.group(0)[1:-1]], template)


def _parse_classification(text: str) -> dict[str, list[int]]:
    """Parse LLM classification output into {step_id: [event_indices]}."""
    text = text.strip()
    # Strip markdown fences
    lines = text.splitlines()
    if lines and lines[0].startswith("```"):
        lines = lines[1:]
    if lines and lines[-1].startswith("```"):
        lines = lines[:-1]
    cleaned = "\n".join(lines)

    try:
        parsed = json.loads(cleaned)
        if not isinstance(parsed, dict):
            logger.warning("Classifier returned non-dict JSON: %s", type(parsed).__name__)
            return {}
        return {
            k: [int(i) for i in v]
            for k, v in parsed.items()
            if isinstance(v, list)
        }
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        logger.warning("Failed to parse classification output: %s", e)
        return {}
