#!/usr/bin/env python3
"""
Pen Test Report Generator - Generate structured penetration testing reports from findings.

Table of Contents:
    PentestReportGenerator - Main class for report generation
        __init__               - Initialize with findings data
        generate_markdown()    - Generate markdown report
        generate_json()        - Generate structured JSON report
        _executive_summary()   - Build executive summary section
        _findings_table()      - Build severity-sorted findings table
        _detailed_findings()   - Build detailed findings with evidence
        _remediation_matrix()  - Build effort vs. impact remediation matrix
        _calculate_risk_score() - Calculate overall risk score
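    load_findings()            - Load findings and report metadata from a JSON file
    generate_sample_findings() - Produce a sample findings JSON template
    main()                     - CLI entry point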

Usage:
    python pentest_report_generator.py --findings findings.json --format md --output report.md
    python pentest_report_generator.py --findings findings.json --format json
    python pentest_report_generator.py --findings findings.json --format md
    python pentest_report_generator.py --sample > sample_findings.json
"""

import argparse
import json
import sys
from dataclasses import dataclass, asdict, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional


@dataclass
class Finding:
    """A single pen test finding.

    The field declaration order below is the positional constructor order
    generated by ``@dataclass``: the first eight fields are required, the
    last three have defaults.
    """
    title: str
    severity: str  # critical, high, medium, low, info (keys of SEVERITY_ORDER); other values sort last
    cvss_score: float  # CVSS base score, printed as-is in the report tables
    category: str  # free-form grouping label, e.g. an OWASP Top 10 category
    description: str
    evidence: str  # rendered verbatim inside a fenced code block in markdown output
    impact: str
    remediation: str
    cvss_vector: str = ""  # optional CVSS vector string; omitted from the report when empty
    references: List[str] = field(default_factory=list)  # default_factory avoids a shared mutable default
    effort: str = "medium"  # low, medium, high — remediation effort; other values are treated as medium for priority


SEVERITY_ORDER = dict(critical=5, high=4, medium=3, low=2, info=1)  # higher rank = more severe; used for sorting


class PentestReportGenerator:
    """Generate professional penetration testing reports from structured findings.

    Findings are sorted once, most severe first, according to ``SEVERITY_ORDER``
    (severities missing from that mapping sort last). Every report section
    renders them in that order, so finding numbers agree across sections.
    """

    def __init__(self, findings: List[Finding], metadata: Optional[Dict] = None):
        """Store the severity-sorted findings and the report metadata.

        Args:
            findings: Findings to report. ``sorted`` builds a new list, so the
                caller's list is not reordered.
            metadata: Optional report metadata. Keys read by this class are
                ``title``, ``target``, ``tester``, ``date_range``, ``scope``,
                ``exclusions`` and ``test_type``; missing keys fall back to
                placeholder text.
        """
        self.findings = sorted(findings, key=lambda f: SEVERITY_ORDER.get(f.severity, 0), reverse=True)
        self.metadata = metadata or {}
        self.generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def generate_markdown(self) -> str:
        """Generate a complete markdown pen test report.

        Returns:
            All report sections joined by blank lines.
        """
        sections = [
            self._header(),
            self._executive_summary(),
            self._scope_section(),
            self._findings_table(),
            self._detailed_findings(),
            self._remediation_matrix(),
            self._methodology_section(),
            self._appendix(),
        ]
        return "\n\n".join(sections)

    def generate_json(self) -> Dict:
        """Generate a structured, JSON-serializable report.

        Returns:
            A dict with ``report_metadata``, ``summary`` (total plus
            per-severity counts), ``findings`` (each finding as a dict, most
            severe first) and ``remediation_priority``.
        """
        return {
            "report_metadata": {
                "title": self.metadata.get("title", "Penetration Test Report"),
                "target": self.metadata.get("target", "Not specified"),
                "tester": self.metadata.get("tester", "Not specified"),
                "date_range": self.metadata.get("date_range", "Not specified"),
                "generated_at": self.generated_at,
                "overall_risk_score": self._calculate_risk_score(),
                "overall_risk_level": self._risk_level(),
            },
            "summary": {
                "total_findings": len(self.findings),
                **self._severity_counts(),
            },
            "findings": [asdict(f) for f in self.findings],
            "remediation_priority": self._remediation_priority_list(),
        }

    def _header(self) -> str:
        """Build the report title and the metadata table."""
        title = self.metadata.get("title", "Penetration Test Report")
        # Metadata values land in table cells, so pipes/newlines must be neutralized.
        target = self._md_cell(self.metadata.get("target", "Not specified"))
        tester = self._md_cell(self.metadata.get("tester", "Not specified"))
        date_range = self._md_cell(self.metadata.get("date_range", "Not specified"))
        lines = [
            f"# {title}",
            "",
            "| Field | Value |",
            "|-------|-------|",
            f"| **Target** | {target} |",
            f"| **Tester** | {tester} |",
            f"| **Date Range** | {date_range} |",
            f"| **Report Generated** | {self.generated_at} |",
            f"| **Overall Risk** | {self._risk_level()} (Score: {self._calculate_risk_score():.1f}/10) |",
            f"| **Total Findings** | {len(self.findings)} |",
        ]
        return "\n".join(lines)

    def _executive_summary(self) -> str:
        """Build the executive summary: totals, severity table, top findings, risk verdict."""
        counts = self._severity_counts()
        total = len(self.findings)
        noun = "finding" if total == 1 else "findings"
        risk_score = self._calculate_risk_score()
        risk_level = self._risk_level()

        lines = [
            "## Executive Summary",
            "",
            f"This penetration test identified **{total} {noun}** across the target application. "
            f"The overall risk level is **{risk_level}** with a score of **{risk_score:.1f}/10**.",
            "",
            "### Finding Severity Distribution",
            "",
            "| Severity | Count |",
            "|----------|-------|",
            f"| Critical | {counts['critical']} |",
            f"| High | {counts['high']} |",
            f"| Medium | {counts['medium']} |",
            f"| Low | {counts['low']} |",
            f"| Informational | {counts['info']} |",
        ]

        # Top 3 findings (self.findings is already sorted most severe first)
        if self.findings:
            lines.append("")
            lines.append("### Top Priority Findings")
            lines.append("")
            for i, f in enumerate(self.findings[:3], 1):
                lines.append(f"{i}. **{f.title}** ({f.severity.upper()}, CVSS {f.cvss_score}) — {f.impact[:120]}")

        # Risk verdict is driven by the most severe finding present.
        lines.append("")
        if counts["critical"] > 0:
            lines.append("> **CRITICAL RISK**: Immediate remediation required. Critical vulnerabilities "
                         "allow attackers to compromise the system with minimal effort.")
        elif counts["high"] > 0:
            lines.append("> **HIGH RISK**: Prompt remediation recommended. High-severity vulnerabilities "
                         "pose significant risk of exploitation.")
        elif counts["medium"] > 0:
            lines.append("> **MODERATE RISK**: Remediation should be planned within the next sprint. "
                         "Medium findings may be chained for greater impact.")
        else:
            lines.append("> **LOW RISK**: The application has a reasonable security posture. "
                         "Address low-severity findings during regular maintenance.")

        return "\n".join(lines)

    def _scope_section(self) -> str:
        """Build the scope section from the ``scope``/``exclusions``/``test_type`` metadata."""
        scope = self.metadata.get("scope", "Full application security assessment")
        exclusions = self.metadata.get("exclusions", "None specified")
        test_type = self.metadata.get("test_type", "Gray box")
        lines = [
            "## Scope",
            "",
            f"- **In Scope**: {scope}",
            f"- **Exclusions**: {exclusions}",
            f"- **Test Type**: {test_type}",
        ]
        return "\n".join(lines)

    def _findings_table(self) -> str:
        """Build the severity-sorted findings overview table."""
        lines = [
            "## Findings Overview",
            "",
            "| # | Severity | CVSS | Title | Category |",
            "|---|----------|------|-------|----------|",
        ]
        for i, f in enumerate(self.findings, 1):
            lines.append(
                f"| {i} | {f.severity.upper()} | {f.cvss_score} | "
                f"{self._md_cell(f.title)} | {self._md_cell(f.category)} |"
            )
        return "\n".join(lines)

    def _detailed_findings(self) -> str:
        """Build one subsection per finding with description, evidence, impact and remediation."""
        lines = ["## Detailed Findings"]
        for i, f in enumerate(self.findings, 1):
            lines.append("")
            lines.append(f"### {i}. {f.title}")
            lines.append("")
            lines.append(f"**Severity:** {f.severity.upper()} | **CVSS:** {f.cvss_score}"
                         + (f" | **Vector:** `{f.cvss_vector}`" if f.cvss_vector else ""))
            lines.append(f"**Category:** {f.category}")
            lines.append("")
            lines.append("#### Description")
            lines.append("")
            lines.append(f"{f.description}")
            lines.append("")
            lines.append("#### Evidence")
            lines.append("")
            # Evidence often contains payloads with backticks; a fence longer than
            # any backtick run in it cannot be closed early by the evidence itself.
            fence = self._code_fence(str(f.evidence))
            lines.append(fence)
            lines.append(f"{f.evidence}")
            lines.append(fence)
            lines.append("")
            lines.append("#### Impact")
            lines.append("")
            lines.append(f"{f.impact}")
            lines.append("")
            lines.append("#### Remediation")
            lines.append("")
            lines.append(f"{f.remediation}")
            if f.references:
                lines.append("")
                lines.append("#### References")
                lines.append("")
                for ref in f.references:
                    lines.append(f"- {ref}")
        return "\n".join(lines)

    def _remediation_matrix(self) -> str:
        """Build the severity x effort remediation priority table (P1-P4)."""
        lines = [
            "## Remediation Priority Matrix",
            "",
            "Prioritize remediation based on severity and effort:",
            "",
            "| # | Finding | Severity | Effort | Priority |",
            "|---|---------|----------|--------|----------|",
        ]
        for i, f in enumerate(self.findings, 1):
            priority = self._compute_priority(f)
            lines.append(
                f"| {i} | {self._md_cell(f.title)} | {f.severity.upper()} | "
                f"{self._md_cell(f.effort)} | {priority} |"
            )

        lines.append("")
        lines.append("**Priority Key:** P1 = Fix immediately, P2 = Fix this sprint, "
                     "P3 = Fix this quarter, P4 = Backlog")
        return "\n".join(lines)

    def _methodology_section(self) -> str:
        """Build the fixed methodology section."""
        lines = [
            "## Methodology",
            "",
            "Testing followed the OWASP Testing Guide v4.2 and PTES (Penetration Testing Execution Standard):",
            "",
            "1. **Reconnaissance** — Mapped attack surface, identified endpoints and technologies",
            "2. **Vulnerability Discovery** — Automated scanning + manual testing for OWASP Top 10",
            "3. **Exploitation** — Validated findings with proof-of-concept (non-destructive)",
            "4. **Post-Exploitation** — Assessed lateral movement and data access potential",
            "5. **Reporting** — Documented findings with evidence and remediation guidance",
        ]
        return "\n".join(lines)

    def _appendix(self) -> str:
        """Build the appendix: CVSS score bands, disclaimer and generation footer."""
        lines = [
            "## Appendix",
            "",
            "### CVSS Scoring Reference",
            "",
            "| Score Range | Severity |",
            "|-------------|----------|",
            "| 9.0 - 10.0 | Critical |",
            "| 7.0 - 8.9 | High |",
            "| 4.0 - 6.9 | Medium |",
            "| 0.1 - 3.9 | Low |",
            "| 0.0 | Informational |",
            "",
            "### Disclaimer",
            "",
            "This report represents a point-in-time assessment. New vulnerabilities may emerge after "
            "the testing period. Regular security assessments are recommended.",
            "",
            f"---\n\n*Report generated on {self.generated_at}*",
        ]
        return "\n".join(lines)

    def _calculate_risk_score(self) -> float:
        """Calculate the overall risk score (0.0-10.0) from the findings.

        The score is the larger of:

        * the weight of the single worst finding, and
        * a volume-scaled score, ``sum(weights) / max(n * 0.5, 1)``, which is
          roughly twice the mean weight once there are two or more findings.

        The result is capped at 10 and rounded to one decimal place. The
        worst-finding floor keeps low/info findings from diluting the score.
        Without it, one critical plus nine informational findings scored 2.9
        ("LOW").

        Returns:
            0.0 when there are no findings, otherwise the rounded score.
        """
        if not self.findings:
            return 0.0
        weights = {"critical": 10, "high": 7, "medium": 4, "low": 1.5, "info": 0.5}
        finding_weights = [weights.get(f.severity, 0) for f in self.findings]
        volume_score = sum(finding_weights) / max(len(finding_weights) * 0.5, 1)
        worst = float(max(finding_weights))
        score = min(10.0, max(worst, volume_score))
        return round(score, 1)

    def _risk_level(self) -> str:
        """Map the risk score to CRITICAL (>=9), HIGH (>=7), MEDIUM (>=4), LOW (>0) or NONE."""
        score = self._calculate_risk_score()
        if score >= 9.0:
            return "CRITICAL"
        elif score >= 7.0:
            return "HIGH"
        elif score >= 4.0:
            return "MEDIUM"
        elif score > 0:
            return "LOW"
        return "NONE"

    def _compute_priority(self, finding: Finding) -> str:
        """Compute remediation priority P1-P4 from severity rank x effort factor.

        Lower effort yields a higher factor (low=3, medium=2, high=1); unknown
        effort values count as medium and unknown severities as rank 0.
        """
        sev = SEVERITY_ORDER.get(finding.severity, 0)
        effort_map = {"low": 3, "medium": 2, "high": 1}
        effort_val = effort_map.get(finding.effort, 2)
        score = sev * effort_val
        if score >= 12:
            return "P1"
        elif score >= 8:
            return "P2"
        elif score >= 4:
            return "P3"
        return "P4"

    def _remediation_priority_list(self) -> List[Dict]:
        """Return per-finding remediation priorities for JSON output, in severity order."""
        return [
            {
                "title": f.title,
                "severity": f.severity,
                "effort": f.effort,
                "priority": self._compute_priority(f),
                "remediation": f.remediation,
            }
            for f in self.findings
        ]

    def _severity_counts(self) -> Dict[str, int]:
        """Return the number of findings per known severity, ordered critical -> info.

        Findings whose severity is not one of the five known levels are not
        counted in any bucket.
        """
        counts = {sev: 0 for sev in ("critical", "high", "medium", "low", "info")}
        for f in self.findings:
            if f.severity in counts:
                counts[f.severity] += 1
        return counts

    @staticmethod
    def _md_cell(value) -> str:
        """Render *value* safely inside a markdown table cell.

        An unescaped ``|`` starts a new column and a line break ends the row.
        Pipes are therefore backslash-escaped and line breaks collapsed to
        spaces.
        """
        text = " ".join(str(value).splitlines())
        return text.replace("|", "\\|")

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence (at least 3 long) longer than any backtick run in *text*.

        A code fence is only closed by a run of backticks at least as long as
        the opening fence. Being strictly longer than every run in *text*
        guarantees the content cannot terminate the block.
        """
        longest = run = 0
        for ch in text:
            run = run + 1 if ch == "`" else 0
            if run > longest:
                longest = run
        return "`" * max(3, longest + 1)


def load_findings(path: str) -> tuple:
    """Load findings and report metadata from a JSON file.

    Two layouts are accepted:

    * a JSON array of finding objects, or
    * an object ``{"metadata": {...}, "findings": [...]}``.

    Missing fields (or JSON nulls) fall back to defaults. ``severity`` and
    ``effort`` are stripped and lower-cased, so ``"Critical"`` is ranked and
    counted like ``"critical"``. A scalar ``references`` value is wrapped in
    a one-element list instead of being iterated character by character.

    Args:
        path: Path to the findings JSON file (read as UTF-8).

    Returns:
        A ``(findings, metadata)`` tuple: a list of ``Finding`` objects and
        a metadata dict (empty when the file provides none).

    Exits:
        Prints ``Error loading findings: ...`` to stderr and calls
        ``sys.exit(1)`` when the file cannot be read or decoded, or its
        structure does not match the layouts above.
    """
    def fail(message):
        print(f"Error loading findings: {message}", file=sys.stderr)
        sys.exit(1)

    try:
        content = Path(path).read_text(encoding="utf-8")
        data = json.loads(content)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError, so it needs its own entry.
        fail(e)

    # Support both list-of-findings and object-with-metadata formats
    metadata = {}
    findings_data = data
    if isinstance(data, dict):
        # `or` also covers explicit JSON nulls for either key.
        metadata = data.get("metadata") or {}
        findings_data = data.get("findings") or []
    if not isinstance(metadata, dict):
        fail('"metadata" must be a JSON object')
    if not isinstance(findings_data, list):
        fail('expected a JSON array of findings or an object with a "findings" array')

    def text(item, key, default):
        # JSON null would otherwise flow through as None and crash later string handling.
        value = item.get(key)
        return default if value is None else str(value)

    findings = []
    for index, item in enumerate(findings_data, 1):
        if not isinstance(item, dict):
            fail(f"finding #{index} is not a JSON object")
        raw_score = item.get("cvss_score")
        try:
            cvss_score = float(raw_score) if raw_score is not None else 0.0
        except (TypeError, ValueError):
            fail(f"finding #{index} has a non-numeric cvss_score: {raw_score!r}")
        references = item.get("references") or []
        if not isinstance(references, list):
            references = [references]
        findings.append(Finding(
            title=text(item, "title", "Untitled Finding"),
            severity=text(item, "severity", "medium").strip().lower() or "medium",
            cvss_score=cvss_score,
            category=text(item, "category", "Uncategorized"),
            description=text(item, "description", ""),
            evidence=text(item, "evidence", "No evidence provided"),
            impact=text(item, "impact", ""),
            remediation=text(item, "remediation", ""),
            cvss_vector=text(item, "cvss_vector", ""),
            references=[str(ref) for ref in references],
            effort=text(item, "effort", "medium").strip().lower() or "medium",
        ))
    return findings, metadata


def generate_sample_findings() -> str:
    """Generate a sample findings JSON (array layout) for reference.

    Each sample's ``cvss_score`` and ``severity`` match its ``cvss_vector``
    under CVSS v3.1. The stored-XSS vector
    ``AV:N/AC:L/PR:L/UI:R/S:C/C:L/I:L/A:N`` scores 5.4, which is Medium.

    Returns:
        The sample findings serialized with ``json.dumps(..., indent=2)``.
    """
    sample = [
        {
            "title": "SQL Injection in Login Endpoint",
            "severity": "critical",
            "cvss_score": 9.8,
            "cvss_vector": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H",
            "category": "A03:2021 - Injection",
            "description": "The /api/login endpoint is vulnerable to SQL injection via the email parameter.",
            "evidence": "Request: POST /api/login {\"email\": \"' OR 1=1--\", \"password\": \"x\"}\nResponse: 200 OK with admin session token",
            "impact": "Full database access, authentication bypass, potential remote code execution.",
            "remediation": "Use parameterized queries. Replace string concatenation with prepared statements.",
            "references": ["https://cwe.mitre.org/data/definitions/89.html"],
            "effort": "low"
        },
        {
            "title": "Stored XSS in User Profile",
            "severity": "medium",
            "cvss_score": 5.4,
            "cvss_vector": "CVSS:3.1/AV:N/AC:L/PR:L/UI:R/S:C/C:L/I:L/A:N",
            "category": "A03:2021 - Injection",
            "description": "The user profile 'bio' field does not sanitize HTML input.",
            "evidence": "Submitted <img src=x onerror=alert(document.cookie)> in bio field.\nVisiting the profile page executes the payload.",
            "impact": "Session hijacking, account takeover, phishing via stored malicious content.",
            "remediation": "Sanitize all user input with DOMPurify. Implement Content-Security-Policy.",
            "references": ["https://cwe.mitre.org/data/definitions/79.html"],
            "effort": "low"
        }
    ]
    return json.dumps(sample, indent=2)


def main():
    """CLI entry point: parse arguments, load findings and emit the report.

    The report goes to stdout unless ``--output`` is given. Exits with
    status 1 (after printing to stderr) when the findings file is missing,
    yields no findings, or the ``--output`` file cannot be written. Usage
    errors are reported by argparse via ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description="Pen Test Report Generator — Generate professional penetration testing reports from structured findings.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --findings findings.json --format md --output report.md
  %(prog)s --findings findings.json --format json
  %(prog)s --sample > sample_findings.json

Findings JSON format:
  Either a JSON array of finding objects, or an object of the form
  {"metadata": {...}, "findings": [...]}.

  Finding fields: title, severity, cvss_score, category, description,
  evidence, impact, remediation, cvss_vector, references, effort.
  Metadata keys: title, target, tester, date_range, scope, exclusions,
  test_type.

  Use --sample to generate a template.
        """,
    )
    parser.add_argument("--findings", metavar="FILE",
                        help="Path to findings JSON file")
    parser.add_argument("--format", choices=["md", "json"], default="md",
                        help="Output format (default: md)")
    parser.add_argument("--output", metavar="FILE",
                        help="Output file path (default: stdout)")
    parser.add_argument("--json", action="store_true", dest="json_shortcut",
                        help="Shortcut for --format json")
    parser.add_argument("--sample", action="store_true",
                        help="Print sample findings JSON and exit")
    args = parser.parse_args()

    if args.sample:
        print(generate_sample_findings())
        return

    if not args.findings:
        parser.error("--findings is required (use --sample to generate a template)")

    if not Path(args.findings).exists():
        print(f"Error: File not found: {args.findings}", file=sys.stderr)
        sys.exit(1)

    # --json wins over --format when both are given.
    output_format = "json" if args.json_shortcut else args.format
    findings, metadata = load_findings(args.findings)

    if not findings:
        print("No findings loaded. Check the JSON file format.", file=sys.stderr)
        sys.exit(1)

    generator = PentestReportGenerator(findings=findings, metadata=metadata)

    if output_format == "json":
        result = json.dumps(generator.generate_json(), indent=2)
    else:
        result = generator.generate_markdown()

    if args.output:
        try:
            Path(args.output).write_text(result, encoding="utf-8")
        except OSError as e:
            # e.g. missing parent directory or permission denied: report cleanly, no traceback.
            print(f"Error writing report: {e}", file=sys.stderr)
            sys.exit(1)
        print(f"Report written to {args.output}")
    else:
        print(result)


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()
