#!/usr/bin/env python3
"""
Vulnerability Assessor - Scan dependencies for known CVEs and security issues.

Table of Contents:
    Vulnerability         - Dataclass describing one CVE finding
    VulnerabilityAssessor - Main class for dependency vulnerability assessment
        __init__         - Initialize with target path and options
        assess()         - Run complete vulnerability assessment
        scan_npm()       - Scan package.json for npm vulnerabilities
        scan_npm_lock()  - Scan package-lock.json for transitive npm dependencies
        scan_python()    - Scan requirements*.txt / pyproject.toml for Python vulnerabilities
        scan_go()        - Scan go.mod for Go vulnerabilities
        _scan_pyproject() - Parse dependencies out of pyproject.toml
        _normalize_version() - Reduce a version specifier to a version string
        _check_vulnerability() - Check package against CVE database
        _version_lt()    - Compare two version strings
        _calculate_risk_score() - Calculate overall risk score
        _get_risk_level() - Map a risk score to a risk level
        _print_summary() - Print the assessment summary
    main() - CLI entry point

Usage:
    python vulnerability_assessor.py /path/to/project
    python vulnerability_assessor.py /path/to/project --severity high
    python vulnerability_assessor.py /path/to/project --output report.json --json
"""

import os
import sys
import json
import re
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime


@dataclass()
class Vulnerability:
    """One known CVE matched against one installed dependency version.

    Instances are turned into plain dicts with ``dataclasses.asdict`` for the
    report, so the field order below is also the key order of each entry.
    """

    # Advisory identifier and the affected package name.
    cve_id: str
    package: str
    # Version found in the scanned manifest, and the version that fixes the CVE.
    installed_version: str
    fixed_version: str
    # Severity bucket: one of "critical", "high", "medium", "low".
    severity: str
    # CVSS score of the advisory (the risk score assumes a 10.0 maximum).
    cvss_score: float
    # Short human-readable summary of the issue.
    description: str
    # Ecosystem the package was scanned from: "npm", "pypi" or "go".
    ecosystem: str
    # Suggested remediation text.
    recommendation: str


class VulnerabilityAssessor:
    """Assess project dependencies for known vulnerabilities.

    Looks for dependency manifests directly inside ``target_path``
    (package.json, package-lock.json, requirements*.txt, pyproject.toml and
    go.mod), checks every (package, version) pair it can extract against the
    bundled ``KNOWN_CVES`` table, and summarizes the findings.
    """

    # Known CVE database (simplified - real implementation would query NVD/OSV).
    # Each entry matches installed versions strictly below ``version_lt``.  The
    # optional ``ecosystem`` key restricts an entry to packages scanned from that
    # ecosystem ('npm', 'pypi' or 'go'), so a same-named package published in a
    # different registry is not reported with the wrong advisory.
    # NOTE(review): the advisory data is illustrative and some entries appear not
    # to match the published advisories (e.g. CVE-2020-14343 is listed upstream as
    # fixed in PyYAML 5.4) -- verify against OSV/NVD before relying on it.
    KNOWN_CVES = {
        # npm packages
        'lodash': [
            {'version_lt': '4.17.21', 'cve': 'CVE-2021-23337', 'cvss': 7.2,
             'severity': 'high', 'desc': 'Command injection in lodash',
             'fixed': '4.17.21', 'ecosystem': 'npm'},
            {'version_lt': '4.17.19', 'cve': 'CVE-2020-8203', 'cvss': 7.4,
             'severity': 'high', 'desc': 'Prototype pollution in lodash',
             'fixed': '4.17.19', 'ecosystem': 'npm'},
        ],
        'axios': [
            {'version_lt': '1.6.0', 'cve': 'CVE-2023-45857', 'cvss': 6.5,
             'severity': 'medium', 'desc': 'CSRF token exposure in axios',
             'fixed': '1.6.0', 'ecosystem': 'npm'},
        ],
        'express': [
            {'version_lt': '4.17.3', 'cve': 'CVE-2022-24999', 'cvss': 7.5,
             'severity': 'high', 'desc': 'Open redirect in express',
             'fixed': '4.17.3', 'ecosystem': 'npm'},
        ],
        'jsonwebtoken': [
            {'version_lt': '9.0.0', 'cve': 'CVE-2022-23529', 'cvss': 9.8,
             'severity': 'critical', 'desc': 'JWT algorithm confusion attack',
             'fixed': '9.0.0', 'ecosystem': 'npm'},
        ],
        'minimist': [
            {'version_lt': '1.2.6', 'cve': 'CVE-2021-44906', 'cvss': 9.8,
             'severity': 'critical', 'desc': 'Prototype pollution in minimist',
             'fixed': '1.2.6', 'ecosystem': 'npm'},
        ],
        'node-fetch': [
            {'version_lt': '2.6.7', 'cve': 'CVE-2022-0235', 'cvss': 8.8,
             'severity': 'high', 'desc': 'Information exposure in node-fetch',
             'fixed': '2.6.7', 'ecosystem': 'npm'},
        ],
        # Python packages
        'django': [
            {'version_lt': '4.2.8', 'cve': 'CVE-2023-46695', 'cvss': 7.5,
             'severity': 'high', 'desc': 'DoS via file uploads in Django',
             'fixed': '4.2.8', 'ecosystem': 'pypi'},
        ],
        'requests': [
            {'version_lt': '2.31.0', 'cve': 'CVE-2023-32681', 'cvss': 6.1,
             'severity': 'medium', 'desc': 'Proxy-Auth header leak in requests',
             'fixed': '2.31.0', 'ecosystem': 'pypi'},
        ],
        'pillow': [
            {'version_lt': '10.0.1', 'cve': 'CVE-2023-44271', 'cvss': 7.5,
             'severity': 'high', 'desc': 'DoS via crafted image in Pillow',
             'fixed': '10.0.1', 'ecosystem': 'pypi'},
        ],
        'cryptography': [
            {'version_lt': '41.0.4', 'cve': 'CVE-2023-38325', 'cvss': 7.5,
             'severity': 'high', 'desc': 'NULL pointer dereference in cryptography',
             'fixed': '41.0.4', 'ecosystem': 'pypi'},
        ],
        'pyyaml': [
            {'version_lt': '6.0.1', 'cve': 'CVE-2020-14343', 'cvss': 9.8,
             'severity': 'critical', 'desc': 'Arbitrary code execution in PyYAML',
             'fixed': '6.0.1', 'ecosystem': 'pypi'},
        ],
        'urllib3': [
            {'version_lt': '2.0.6', 'cve': 'CVE-2023-43804', 'cvss': 8.1,
             'severity': 'high', 'desc': 'Cookie header leak in urllib3',
             'fixed': '2.0.6', 'ecosystem': 'pypi'},
        ],
        # Go packages
        'golang.org/x/crypto': [
            {'version_lt': 'v0.17.0', 'cve': 'CVE-2023-48795', 'cvss': 5.9,
             'severity': 'medium', 'desc': 'SSH prefix truncation attack',
             'fixed': 'v0.17.0', 'ecosystem': 'go'},
        ],
        'golang.org/x/net': [
            {'version_lt': 'v0.17.0', 'cve': 'CVE-2023-44487', 'cvss': 7.5,
             'severity': 'high', 'desc': 'HTTP/2 rapid reset attack',
             'fixed': 'v0.17.0', 'ecosystem': 'go'},
        ],
    }

    SEVERITY_ORDER = {'critical': 0, 'high': 1, 'medium': 2, 'low': 3}

    # "name[extras] <op> version" prefix of a PEP 508 requirement, e.g.
    # "requests[socks]>=2.25.0; python_version>'3.8'" -> ("requests", ">=", "2.25.0").
    _REQUIREMENT_RE = re.compile(
        r'^([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*([=<>!~]+)\s*([0-9][0-9.]*)'
    )
    # A TOML basic ("...") or literal ('...') string; group 2 is its contents.
    _TOML_STRING_RE = re.compile(r"""(["'])(.*?)\1""")
    # pyproject.toml tables whose entries are ``name = "version spec"`` pairs.
    _TABLE_DEPS_SECTION_RE = re.compile(
        r'^(?:project\.dependencies|tool\.poetry\.(?:group\.[^.]+\.)?(?:dev-)?dependencies)$'
    )

    def __init__(
        self,
        target_path: str,
        severity_threshold: str = "low",
        verbose: bool = False
    ):
        """
        Initialize the vulnerability assessor.

        Args:
            target_path: Directory to scan for dependency files
            severity_threshold: Minimum severity to report
            verbose: Enable verbose output
        """
        self.target_path = Path(target_path)
        self.severity_threshold = severity_threshold
        self.verbose = verbose
        self.vulnerabilities: List[Vulnerability] = []
        self.packages_scanned = 0
        self.files_scanned = 0

    def assess(self) -> Dict:
        """
        Run complete vulnerability assessment.

        Prints a banner and a summary to stdout.

        Returns:
            Dict with assessment results.  ``status`` is "completed" on
            success, or "error" (with a ``message``) when ``target_path`` does
            not exist or is not a directory.
        """
        print(f"Vulnerability Assessor - Scanning: {self.target_path}")
        print(f"Severity threshold: {self.severity_threshold}")
        print()

        if not self.target_path.exists():
            return {"status": "error", "message": f"Path not found: {self.target_path}"}
        if not self.target_path.is_dir():
            # Manifests are looked up *inside* target_path; a file path would
            # silently scan nothing and look like a clean result.
            return {"status": "error", "message": f"Not a directory: {self.target_path}"}

        start_time = datetime.now()

        # Scan npm dependencies
        package_json = self.target_path / "package.json"
        if package_json.exists():
            self.scan_npm(package_json)
            self.files_scanned += 1

        # Scan Python dependencies
        requirements_files = [
            "requirements.txt",
            "requirements-dev.txt",
            "requirements-prod.txt",
            "pyproject.toml"
        ]
        for req_file in requirements_files:
            req_path = self.target_path / req_file
            if req_path.exists():
                self.scan_python(req_path)
                self.files_scanned += 1

        # Scan Go dependencies
        go_mod = self.target_path / "go.mod"
        if go_mod.exists():
            self.scan_go(go_mod)
            self.files_scanned += 1

        # Scan package-lock.json for transitive dependencies
        package_lock = self.target_path / "package-lock.json"
        if package_lock.exists():
            self.scan_npm_lock(package_lock)
            self.files_scanned += 1

        # Keep findings at or above the threshold (lower order = more severe).
        threshold_level = self.SEVERITY_ORDER.get(self.severity_threshold, 3)
        filtered_vulns = [
            v for v in self.vulnerabilities
            if self.SEVERITY_ORDER.get(v.severity, 3) <= threshold_level
        ]

        end_time = datetime.now()
        scan_duration = (end_time - start_time).total_seconds()

        # Group by severity
        severity_counts = {}
        for vuln in filtered_vulns:
            severity_counts[vuln.severity] = severity_counts.get(vuln.severity, 0) + 1

        # Calculate risk score
        risk_score = self._calculate_risk_score(filtered_vulns)

        result = {
            "status": "completed",
            "target": str(self.target_path),
            "files_scanned": self.files_scanned,
            "packages_scanned": self.packages_scanned,
            "scan_duration_seconds": round(scan_duration, 2),
            "total_vulnerabilities": len(filtered_vulns),
            "risk_score": risk_score,
            "risk_level": self._get_risk_level(risk_score),
            "severity_counts": severity_counts,
            "vulnerabilities": [asdict(v) for v in filtered_vulns]
        }

        self._print_summary(result)

        return result

    def scan_npm(self, package_json_path: Path):
        """Scan package.json ``dependencies`` and ``devDependencies``."""
        if self.verbose:
            print(f"  Scanning: {package_json_path}")

        try:
            with open(package_json_path, 'r', encoding='utf-8-sig') as f:
                data = json.load(f)

            deps = {}
            # ``or {}`` also covers a section explicitly set to JSON null.
            deps.update(data.get('dependencies') or {})
            deps.update(data.get('devDependencies') or {})

            for package, version_spec in deps.items():
                self.packages_scanned += 1
                version = self._normalize_version(str(version_spec))
                self._check_vulnerability(package.lower(), version, 'npm')

        except Exception as e:  # per-file boundary: report and keep scanning
            self._report_scan_error(package_json_path, e)

    def scan_npm_lock(self, package_lock_path: Path):
        """Scan package-lock.json (lockfile v1, v2 or v3) for transitive dependencies."""
        if self.verbose:
            print(f"  Scanning: {package_lock_path}")

        try:
            with open(package_lock_path, 'r', encoding='utf-8-sig') as f:
                data = json.load(f)

            packages = data.get('packages') or {}
            if packages:
                # lockfileVersion 2/3: flat map keyed by install path such as
                # "node_modules/a/node_modules/@scope/b"; "" is the root project.
                entries = [
                    (pkg_path.split('node_modules/')[-1], pkg_info)
                    for pkg_path, pkg_info in packages.items()
                    if pkg_path
                ]
            else:
                # lockfileVersion 1: nested "dependencies" keyed by package name.
                entries = self._iter_lock_v1(data.get('dependencies') or {})

            for package, pkg_info in entries:
                version = pkg_info.get('version', '') if isinstance(pkg_info, dict) else ''
                if package and version:
                    self.packages_scanned += 1
                    self._check_vulnerability(package.lower(), version, 'npm')

        except Exception as e:  # per-file boundary: report and keep scanning
            self._report_scan_error(package_lock_path, e)

    def _iter_lock_v1(self, dependencies: Dict) -> List[Tuple[str, Dict]]:
        """Flatten an npm v1 lockfile ``dependencies`` tree into (name, info) pairs.

        Entries are returned breadth-first (top level first), so the top-level
        copy of a package is checked before any nested copy.
        """
        entries: List[Tuple[str, Dict]] = []
        level = [dependencies]
        while level:
            next_level = []
            for deps in level:
                for name, info in deps.items():
                    entries.append((name, info))
                    nested = info.get('dependencies') if isinstance(info, dict) else None
                    if isinstance(nested, dict):
                        next_level.append(nested)
            level = next_level
        return entries

    def scan_python(self, requirements_path: Path):
        """Scan a requirements*.txt or pyproject.toml file for Python vulnerabilities."""
        if self.verbose:
            print(f"  Scanning: {requirements_path}")

        try:
            # utf-8-sig: tolerate a leading BOM, which would otherwise stick to
            # the first package name.
            content = requirements_path.read_text(encoding='utf-8-sig')

            # Handle pyproject.toml
            if requirements_path.name == 'pyproject.toml':
                self._scan_pyproject(content)
                return

            for raw_line in content.splitlines():
                line = raw_line.strip()
                # Skip blanks, comments and pip options (-r, -e, --hash, ...).
                if not line or line.startswith(('#', '-')):
                    continue
                self._check_requirement(line)

        except Exception as e:  # per-file boundary: report and keep scanning
            self._report_scan_error(requirements_path, e)

    def _check_requirement(self, requirement: str):
        """Check one PEP 508 requirement such as ``"requests[socks]>=2.25.0"``.

        Only requirements carrying a version comparison are counted and
        checked; bare names and URL references are skipped.  The first version
        in the specifier is used as the installed version.
        """
        match = self._REQUIREMENT_RE.match(requirement.strip())
        if match:
            self.packages_scanned += 1
            self._check_vulnerability(match.group(1).lower(), match.group(3), 'pypi')

    def _scan_pyproject(self, content: str):
        """Parse pyproject.toml text for dependencies.

        Line-based parsing (a real implementation would use a TOML library) of:
          * PEP 621 ``dependencies = [...]`` in ``[project]`` and every array in
            ``[project.optional-dependencies]``; items are PEP 508 strings.
          * ``name = "spec"`` / ``name = {version = "spec", ...}`` tables:
            ``[tool.poetry.dependencies]``, Poetry dev/group dependency tables
            and the ``[project.dependencies]`` table form.
        The ``python`` entry of those tables is the interpreter constraint, not
        a package, and is skipped.
        """
        section = ''
        in_array = False  # inside a multi-line PEP 621 dependency array
        for raw_line in content.splitlines():
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            if in_array:
                in_array = self._consume_dependency_array(line)
                continue

            if line.startswith('['):
                header = re.match(r'^\[+\s*([^\]]+?)\s*\]', line)
                section = header.group(1) if header else ''
                continue

            key, sep, value = line.partition('=')
            if not sep:
                continue
            key = key.strip().strip('"\'')
            value = value.strip()

            if section == 'project.optional-dependencies' or (
                    section == 'project' and key == 'dependencies'):
                if value.startswith('['):
                    in_array = self._consume_dependency_array(value[1:])
            elif self._TABLE_DEPS_SECTION_RE.match(section) and key.lower() != 'python':
                spec = self._table_dependency_spec(value)
                if spec:
                    self.packages_scanned += 1
                    version = self._normalize_version(spec)
                    self._check_vulnerability(key.lower(), version, 'pypi')

    def _consume_dependency_array(self, fragment: str) -> bool:
        """Check each quoted requirement in one line of a TOML array.

        Returns True while the array is still open, False once its closing
        ``]`` has been seen.
        """
        for _, requirement in self._TOML_STRING_RE.findall(fragment):
            self._check_requirement(requirement)
        # Drop quoted strings and trailing comments before looking for the
        # closing bracket, so extras like "requests[socks]" do not end the array.
        remainder = self._TOML_STRING_RE.sub('', fragment).split('#', 1)[0]
        return ']' not in remainder

    def _table_dependency_spec(self, value: str) -> str:
        """Return the version spec of a TOML table dependency value, or '' if none.

        Accepts ``"^2.28"`` and inline tables like ``{version = "^2.28", ...}``;
        git/path/url dependencies without a ``version`` yield ''.
        """
        if value.startswith('{'):
            match = re.search(r"""\bversion\s*=\s*(["'])(.*?)\1""", value)
        else:
            match = self._TOML_STRING_RE.match(value)
        return match.group(2) if match else ''

    def scan_go(self, go_mod_path: Path):
        """Scan the ``require`` directives of go.mod for Go vulnerabilities.

        Handles single-line ``require path version`` and parenthesized
        ``require ( ... )`` blocks; ``//`` comments (including ``// indirect``)
        are ignored.
        """
        if self.verbose:
            print(f"  Scanning: {go_mod_path}")

        try:
            content = go_mod_path.read_text(encoding='utf-8-sig')

            in_require = False
            for raw_line in content.splitlines():
                # Drop trailing "// indirect" markers and whole-line comments.
                line = raw_line.split('//', 1)[0].strip()
                if not line:
                    continue

                if in_require:
                    if line == ')':
                        in_require = False
                        continue
                    fields = line.split()
                elif re.match(r'require\s*\(', line):
                    in_require = True
                    continue
                elif line.startswith('require '):
                    fields = line[len('require '):].split()
                else:
                    continue

                if len(fields) >= 2:
                    self.packages_scanned += 1
                    self._check_vulnerability(fields[0], fields[1], 'go')

        except Exception as e:  # per-file boundary: report and keep scanning
            self._report_scan_error(go_mod_path, e)

    def _normalize_version(self, version_spec: str) -> str:
        """Extract the first concrete version from a version specification.

        Examples: "^4.17.20" -> "4.17.20", ">=4.2.10,<5" -> "4.2.10",
        ">=1.2.0 <2.0.0" -> "1.2.0", "1.2.3-beta.1" -> "1.2.3".
        Specs that name no version (e.g. "*", "latest") yield a string with no
        numeric component, which ``_version_lt`` treats as unknown.
        """
        # Remove surrounding whitespace and leading operators like ^, ~, >=, !=.
        version = re.sub(r'^[\^~>=<!\s]+', '', version_spec.strip())
        # Keep only the first clause of a range or list ("1.2 <2", "1.2,<2", "1 || 2").
        version = re.split(r'[\s,|;]', version)[0]
        # Remove suffixes like -alpha, -beta, +build, etc.
        version = re.split(r'[-+]', version)[0]
        return version.strip()

    def _check_vulnerability(self, package: str, version: str, ecosystem: str):
        """Record every known CVE affecting ``package`` at ``version``.

        Advisories tagged with a different ``ecosystem`` are ignored, and each
        (CVE, package) pair is recorded at most once per assessor.
        """
        for cve_info in self.KNOWN_CVES.get(package, []):
            advisory_ecosystem = cve_info.get('ecosystem')
            if advisory_ecosystem is not None and advisory_ecosystem != ecosystem:
                continue
            if not self._version_lt(version, cve_info['version_lt']):
                continue
            # Avoid duplicates (e.g. the same package seen in package.json and
            # package-lock.json); the first occurrence wins.
            if any(v.cve_id == cve_info['cve'] and v.package == package
                   for v in self.vulnerabilities):
                continue
            self.vulnerabilities.append(Vulnerability(
                cve_id=cve_info['cve'],
                package=package,
                installed_version=version,
                fixed_version=cve_info['fixed'],
                severity=cve_info['severity'],
                cvss_score=cve_info['cvss'],
                description=cve_info['desc'],
                ecosystem=ecosystem,
                recommendation=f"Upgrade {package} to {cve_info['fixed']} or later"
            ))

    def _version_lt(self, version: str, threshold: str) -> bool:
        """Return True if ``version`` is strictly lower than ``threshold`` (simplified).

        Only the purely numeric dot/dash-separated components are compared
        (a leading "v" is ignored).  A version with no numeric component at
        all (e.g. "*", "latest", a git URL) is unknown and never reported as
        vulnerable.
        """
        try:
            # Remove 'v' prefix for Go versions
            left = [int(x) for x in re.split(r'[.\-]', version.lstrip('v')) if x.isdigit()]
            right = [int(x) for x in re.split(r'[.\-]', threshold.lstrip('v')) if x.isdigit()]
        except (ValueError, AttributeError):
            return False

        if not left or not right:
            return False

        # Pad the shorter version with zeros so "1.2" compares like "1.2.0".
        width = max(len(left), len(right))
        left += [0] * (width - len(left))
        right += [0] * (width - len(right))
        return left < right

    def _calculate_risk_score(self, vulnerabilities: "List[Vulnerability]") -> float:
        """Calculate overall risk score (0-100).

        Each finding contributes ``cvss_score * severity_weight``; the sum is
        divided by the maximum possible (CVSS 10 x weight 4 per finding), so the
        score reflects the average severity of the findings, not their count.
        """
        if not vulnerabilities:
            return 0.0

        # Weight by severity and CVSS
        severity_weights = {'critical': 4.0, 'high': 3.0, 'medium': 2.0, 'low': 1.0}
        total_weight = 0.0

        for vuln in vulnerabilities:
            weight = severity_weights.get(vuln.severity, 1.0)
            total_weight += (vuln.cvss_score * weight)

        # Normalize to 0-100
        max_possible = len(vulnerabilities) * 10.0 * 4.0
        score = (total_weight / max_possible) * 100 if max_possible > 0 else 0

        return min(100.0, round(score, 1))

    def _get_risk_level(self, score: float) -> str:
        """Map a 0-100 risk score to CRITICAL/HIGH/MEDIUM/LOW/NONE."""
        if score >= 70:
            return "CRITICAL"
        elif score >= 50:
            return "HIGH"
        elif score >= 25:
            return "MEDIUM"
        elif score > 0:
            return "LOW"
        return "NONE"

    def _print_summary(self, result: Dict):
        """Print assessment summary and the five highest-CVSS findings to stdout."""
        print("\n" + "=" * 60)
        print("VULNERABILITY ASSESSMENT SUMMARY")
        print("=" * 60)
        print(f"Target: {result['target']}")
        print(f"Files scanned: {result['files_scanned']}")
        print(f"Packages scanned: {result['packages_scanned']}")
        print(f"Scan duration: {result['scan_duration_seconds']}s")
        print(f"Total vulnerabilities: {result['total_vulnerabilities']}")
        print(f"Risk score: {result['risk_score']}/100 ({result['risk_level']})")
        print()

        if result['severity_counts']:
            print("Vulnerabilities by severity:")
            for severity in ['critical', 'high', 'medium', 'low']:
                count = result['severity_counts'].get(severity, 0)
                if count > 0:
                    print(f"  {severity.upper()}: {count}")
        print("=" * 60)

        if result['total_vulnerabilities'] > 0:
            print("\nTop vulnerabilities:")
            # Sort by CVSS score
            sorted_vulns = sorted(
                result['vulnerabilities'],
                key=lambda x: x['cvss_score'],
                reverse=True
            )
            for vuln in sorted_vulns[:5]:
                print(f"\n  [{vuln['severity'].upper()}] {vuln['cve_id']}")
                print(f"  Package: {vuln['package']}@{vuln['installed_version']}")
                print(f"  CVSS: {vuln['cvss_score']}")
                print(f"  Fix: Upgrade to {vuln['fixed_version']}")

    def _report_scan_error(self, path: Path, error: Exception):
        """Warn on stderr that ``path`` could not be scanned.

        Always reported (not only in verbose mode): a manifest that fails to
        parse contributes no findings, which would otherwise look like a clean
        result.
        """
        print(f"  Error scanning {path}: {error}", file=sys.stderr)


def main():
    """Main entry point for CLI.

    Exit status:
        0 - scan completed with no critical/high findings at the chosen threshold
        1 - at least one high finding (and no critical ones)
        2 - at least one critical finding
        3 - the scan could not run (e.g. the target path does not exist)
    """
    import contextlib

    parser = argparse.ArgumentParser(
        description="Scan dependencies for known vulnerabilities",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s /path/to/project
  %(prog)s /path/to/project --severity high
  %(prog)s /path/to/project --output report.json --json
  %(prog)s . --verbose
        """
    )

    parser.add_argument(
        "target",
        help="Directory containing dependency files"
    )
    parser.add_argument(
        "--severity", "-s",
        choices=["critical", "high", "medium", "low"],
        default="low",
        help="Minimum severity to report (default: low)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output results as JSON"
    )
    parser.add_argument(
        "--output", "-o",
        help="Output file path"
    )

    args = parser.parse_args()

    assessor = VulnerabilityAssessor(
        target_path=args.target,
        severity_threshold=args.severity,
        verbose=args.verbose
    )

    # With --json and no --output the JSON document goes to stdout, so route
    # the human-readable banner/summary to stderr to keep stdout parseable.
    json_to_stdout = args.json and not args.output
    with contextlib.redirect_stdout(sys.stderr if json_to_stdout else sys.stdout):
        result = assessor.assess()

    # --output always writes JSON, whether or not --json was given.
    if args.json or args.output:
        output = json.dumps(result, indent=2)
        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(output)
            print(f"\nResults written to {args.output}")
        else:
            print(output)

    # A scan that could not run must not look like a clean pass.
    if result.get('status') != 'completed':
        print(f"Error: {result.get('message', 'assessment failed')}", file=sys.stderr)
        sys.exit(3)

    # Exit with error code if critical/high vulnerabilities
    severity_counts = result.get('severity_counts', {})
    if severity_counts.get('critical', 0) > 0:
        sys.exit(2)
    if severity_counts.get('high', 0) > 0:
        sys.exit(1)


if __name__ == "__main__":
    # main() returns None on success, which sys.exit maps to status 0;
    # non-zero statuses are raised from inside main() itself.
    sys.exit(main())
