"""Grade observation traces against compliance specs using LLM classification."""

from __future__ import annotations

from dataclasses import dataclass

from scripts.classifier import classify_events
from scripts.parser import ComplianceSpec, ObservationEvent, Step


@dataclass(frozen=True)
class StepResult:
    """Outcome of grading a single compliance-spec step against a trace."""

    step_id: str  # id of the spec step this result describes
    detected: bool  # True when a classified event satisfied the step's temporal constraints
    evidence: tuple[ObservationEvent, ...]  # accepted event(s); grade() stores at most one
    failure_reason: str | None  # why the step was not detected; grade() sets None when detected


@dataclass(frozen=True)
class ComplianceResult:
    """Aggregate grading result for one compliance spec over one trace."""

    spec_id: str  # id of the graded ComplianceSpec
    steps: tuple[StepResult, ...]  # one result per spec step, in spec order
    compliance_rate: float  # detected required steps / total required steps (0.0 if none required)
    recommend_hook_promotion: bool  # True when compliance_rate < spec.threshold_promote_to_hook
    # Raw output of classify_events: step id -> indices into the timestamp-sorted trace.
    classification: dict[str, list[int]]


def _check_temporal_order(
    step: Step,
    event: ObservationEvent,
    resolved: dict[str, list[ObservationEvent]],
    classified: dict[str, list[ObservationEvent]],
) -> str | None:
    """Check before_step/after_step constraints. Returns failure reason or None."""
    if step.detector.after_step is not None:
        after_events = resolved.get(step.detector.after_step)
        if after_events is None:
            after_events = classified.get(step.detector.after_step, [])
        if not after_events:
            return f"after_step '{step.detector.after_step}' not yet detected"
        latest_after = max(e.timestamp for e in after_events)
        if event.timestamp <= latest_after:
            return (
                f"must occur after '{step.detector.after_step}' "
                f"(last at {latest_after}), but found at {event.timestamp}"
            )

    if step.detector.before_step is not None:
        # Look ahead using LLM classification results
        before_events = resolved.get(step.detector.before_step)
        if before_events is None:
            before_events = classified.get(step.detector.before_step, [])
        if before_events:
            earliest_before = min(e.timestamp for e in before_events)
            if event.timestamp >= earliest_before:
                return (
                    f"must occur before '{step.detector.before_step}' "
                    f"(first at {earliest_before}), but found at {event.timestamp}"
                )

    return None


def grade(
    spec: ComplianceSpec,
    trace: list[ObservationEvent],
    classifier_model: str = "haiku",
) -> ComplianceResult:
    """Grade a trace against a compliance spec using LLM classification.

    The trace is sorted by timestamp and classified in a single batch call to
    ``classify_events``, which maps step ids to indices into the sorted trace.
    Each spec step is then checked deterministically: its candidate events are
    tried in chronological order and the first one satisfying the step's
    ``after_step``/``before_step`` constraints becomes its evidence.

    Args:
        spec: Compliance spec whose steps are graded, in spec order.
        trace: Observation events in any order; the list is not mutated.
        classifier_model: Model name forwarded to ``classify_events``.

    Returns:
        A ``ComplianceResult`` with one ``StepResult`` per spec step.
        ``compliance_rate`` is the fraction of required steps detected, and is
        0.0 when the spec has no required steps. ``recommend_hook_promotion``
        is True when that rate is below ``spec.threshold_promote_to_hook``.
        The raw classifier output is returned unchanged.
    """
    sorted_trace = sorted(trace, key=lambda e: e.timestamp)

    # Step 1: LLM classifies all events in one batch call
    classification = classify_events(spec, sorted_trace, model=classifier_model)

    # Convert indices to events. Indices are de-duplicated and sorted so that
    # candidates are examined chronologically (the trace is timestamp-sorted),
    # independent of the order or repetitions in the classifier's output.
    # Out-of-range indices are dropped.
    n_events = len(sorted_trace)
    classified: dict[str, list[ObservationEvent]] = {
        step_id: [sorted_trace[i] for i in sorted(set(indices)) if 0 <= i < n_events]
        for step_id, indices in classification.items()
    }

    # Step 2: Check temporal ordering (deterministic)
    resolved: dict[str, list[ObservationEvent]] = {}
    step_results: list[StepResult] = []

    for step in spec.steps:
        matched: list[ObservationEvent] = []
        failure_reason: str | None = None

        for event in classified.get(step.id, []):
            temporal_fail = _check_temporal_order(step, event, resolved, classified)
            if temporal_fail is None:
                # Earliest valid event wins; later steps see it via `resolved`.
                matched.append(event)
                break
            # Keep the most recent rejection as the reported reason.
            failure_reason = temporal_fail

        detected = bool(matched)
        if detected:
            resolved[step.id] = matched
        elif failure_reason is None:
            failure_reason = f"no matching event classified for step '{step.id}'"

        step_results.append(StepResult(
            step_id=step.id,
            detected=detected,
            evidence=tuple(matched),
            failure_reason=failure_reason if not detected else None,
        ))

    required_ids = {s.id for s in spec.steps if s.required}
    required_steps = [s for s in step_results if s.step_id in required_ids]
    detected_required = sum(1 for s in required_steps if s.detected)
    total_required = len(required_steps)

    compliance_rate = detected_required / total_required if total_required > 0 else 0.0

    return ComplianceResult(
        spec_id=spec.id,
        steps=tuple(step_results),
        compliance_rate=compliance_rate,
        recommend_hook_promotion=compliance_rate < spec.threshold_promote_to_hook,
        classification=classification,
    )
