"""
Vulnerability Scanner for A06:2021 - Vulnerable and Outdated Components.

This module scans dependency files (package.json, pom.xml, build.gradle,
requirements.txt, pyproject.toml) and checks the declared versions against the
Google OSV (Open Source Vulnerabilities) database.

[20251219_FEATURE] v3.0.4 - A06 Vulnerable Components detection
"""

import json
import re
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any

import httpx

# Google OSV API endpoints.
# Single-package query endpoint (used by OSVClient.query_single).
OSV_API_URL = "https://api.osv.dev/v1/query"
# Batch query endpoint (used by OSVClient.query_batch).
# NOTE(review): per the OSV API docs, batch results list only vulnerability
# IDs / modified timestamps rather than full records — confirm before relying
# on fields such as severity or aliases from this endpoint.
OSV_BATCH_URL = "https://api.osv.dev/v1/querybatch"


class Ecosystem(str, Enum):
    """Supported package ecosystems.

    Each member's value is sent verbatim as ``package.ecosystem`` in OSV
    queries (see OSVClient.query_batch); presumably it must match OSV's
    ecosystem names exactly, so do not change the spelling or case.
    """
    NPM = "npm"
    MAVEN = "Maven"
    PYPI = "PyPI"
    # NOTE(review): no parser in this module currently produces GO dependencies.
    GO = "Go"


@dataclass
class Dependency:
    """Represents a project dependency.

    Produced by the DependencyParser.parse_* methods and looked up in OSV by
    name, version and ecosystem.
    """
    # Package identifier; Maven entries use "groupId:artifactId".
    name: str
    # Single version string sent to OSV (range specifiers are reduced by the parsers).
    version: str
    ecosystem: Ecosystem
    # Path of the manifest this dependency was read from.
    source_file: str
    # True for development/test-only dependencies (e.g. npm devDependencies,
    # Maven "test"/"provided" scope).
    is_dev: bool = False


@dataclass
class VulnerabilityFinding:
    """Represents a vulnerability found in a dependency."""
    id: str  # OSV ID (e.g., GHSA-xxxx-xxxx-xxxx)
    aliases: list[str] = field(default_factory=list)  # CVE IDs
    summary: str = ""
    details: str = ""
    severity: str = "UNKNOWN"
    package_name: str = ""
    package_version: str = ""
    ecosystem: str = ""
    source_file: str = ""
    fixed_versions: list[str] = field(default_factory=list)
    references: list[str] = field(default_factory=list)

    @property
    def cve_id(self) -> str | None:
        """Return the first CVE ID if available."""
        for alias in self.aliases:
            if alias.startswith("CVE-"):
                return alias
        return None


@dataclass
class ScanResult:
    """Result of a vulnerability scan."""
    # Number of dependencies that were looked up in OSV.
    dependencies_scanned: int
    # Number of findings (one per dependency/advisory pair), i.e. len(findings).
    vulnerabilities_found: int
    findings: list[VulnerabilityFinding]
    # Non-fatal problems, e.g. manifests that could not be parsed.
    errors: list[str] = field(default_factory=list)


class DependencyParser:
    """Parse dependency manifests into :class:`Dependency` records.

    Each ``parse_*`` static method reads one manifest and returns the
    dependencies it declares. OSV matches a single concrete version, so range
    specifiers are reduced to one representative version (the first
    alternative or the lower bound); entries without a usable version are
    skipped. Every parser raises ``ValueError`` (chained to the underlying
    error) when the file cannot be read, decoded or parsed.
    """

    # npm specifiers that resolve outside the registry; OSV cannot match them.
    _NPM_NON_REGISTRY_PREFIXES = (
        "workspace:", "file:", "git:", "link:", "http:", "https:",
    )

    # Gradle configuration keywords. Longer names precede their prefixes
    # ("compileOnly" before "compile") and \b keeps "api" from matching inside
    # other identifiers.
    _GRADLE_CONFIG = (
        r"\b(?P<config>testImplementation|testCompileOnly|testRuntimeOnly"
        r"|testCompile|implementation|runtimeOnly|compileOnly|compile|api)\b"
    )
    _GRADLE_PATTERNS = (
        # implementation 'group:artifact:version[:classifier][@ext]', optionally
        # in parentheses. Coordinates exclude ':' so a classifier cannot shift
        # the artifact/version fields.
        re.compile(
            _GRADLE_CONFIG
            + r"\s*\(?\s*['\"](?P<group>[^'\":\s]+):(?P<artifact>[^'\":\s]+)"
            r":(?P<version>[^'\":@\s]+)(?::[^'\"@\s]*)?(?:@[^'\"\s]*)?['\"]"
        ),
        # implementation group: 'g', name: 'a', version: 'v'
        re.compile(
            _GRADLE_CONFIG
            + r"\s*\(?\s*group\s*:\s*['\"](?P<group>[^'\"]+)['\"]\s*,"
            r"\s*name\s*:\s*['\"](?P<artifact>[^'\"]+)['\"]\s*,"
            r"\s*version\s*:\s*['\"](?P<version>[^'\"]+)['\"]"
        ),
    )

    # PEP 508 requirement: name, optional [extras], then one or more version
    # specifiers up to an environment marker (;) or comment (#).
    _PEP508_RE = re.compile(
        r"^\s*(?P<name>[A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*"
        r"(?P<specs>(?:===|==|~=|!=|>=|<=|>|<)[^;#]*)"
    )
    # One comma-separated specifier clause, e.g. ">= 1.2".
    _PEP440_CLAUSE_RE = re.compile(r"^\s*(===|==|~=|!=|>=|<=|>|<)\s*([^\s,;#]+)")
    # Operators whose version is a plausible installed version (pin / lower bound).
    _PEP440_LOWER_BOUND_OPS = ("===", "==", "~=", ">=", ">")

    # "<key>dependencies = [" at the start of a line ("dependencies = [",
    # "dev-dependencies = [", ...).
    _TOML_DEPS_ARRAY_RE = re.compile(
        r"""^\s*["']?(?P<key>[A-Za-z0-9_.-]*dependencies)["']?\s*=\s*\["""
    )
    # Tokens on a TOML array line: a "double" or 'single' quoted string, a
    # comment start, or the closing bracket.
    _TOML_ARRAY_TOKEN_RE = re.compile(r"\"([^\"]*)\"|'([^']*)'|(#)|(\])")

    @staticmethod
    def _read_text(file_path: Path) -> str:
        """Return the file's text decoded as UTF-8 (a leading BOM is ignored).

        Raises:
            ValueError: if the file cannot be opened or is not valid UTF-8.
        """
        try:
            with open(file_path, "r", encoding="utf-8-sig") as f:
                return f.read()
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"Failed to parse {file_path}: {e}") from e

    @staticmethod
    def parse_package_json(file_path: Path) -> list[Dependency]:
        """Parse an npm ``package.json`` file.

        ``dependencies`` are production dependencies; ``devDependencies`` are
        marked ``is_dev``. Specifiers that do not reduce to a concrete
        registry version are skipped (see ``_clean_npm_version``).

        Raises:
            ValueError: if the file is unreadable, is not valid JSON, or its
                top-level value is not an object.
        """
        text = DependencyParser._read_text(file_path)
        try:
            data = json.loads(text)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse {file_path}: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(
                f"Failed to parse {file_path}: top-level JSON value is not an object"
            )

        deps = []
        for section, is_dev in (("dependencies", False), ("devDependencies", True)):
            entries = data.get(section)
            if not isinstance(entries, dict):
                continue
            for name, spec in entries.items():
                version = DependencyParser._clean_npm_version(spec)
                if version:
                    deps.append(Dependency(
                        name=name,
                        version=version,
                        ecosystem=Ecosystem.NPM,
                        source_file=str(file_path),
                        is_dev=is_dev,
                    ))
        return deps

    @staticmethod
    def _clean_npm_version(version: str) -> str | None:
        """Reduce an npm version specifier to one concrete version, or None.

        Examples: "^1.2.3" -> "1.2.3", "1.0.0 - 2.0.0" -> "1.0.0",
        "^1.0.0 || ^2.0.0" -> "1.0.0", ">=1.0.0 <2.0.0" -> "1.0.0".
        Non-registry specifiers (workspace:, file:, git:, link:, URLs), tags
        such as "latest" or "*", and non-string values yield None.
        """
        if not isinstance(version, str):
            return None
        spec = version.strip()
        if not spec or spec.startswith(DependencyParser._NPM_NON_REGISTRY_PREFIXES):
            return None
        # First alternative of an "a || b" union, then the low end of "a - b".
        spec = spec.split("||", 1)[0].split(" - ", 1)[0]
        # Drop leading comparison operators and any whitespace after them.
        spec = re.sub(r"^[\s\^~>=<]*", "", spec)
        # A comparator set such as "1.0.0 <2.0.0": keep only the first comparator.
        tokens = spec.split()
        if tokens and re.match(r"\d", tokens[0]):
            return tokens[0]
        return None

    @staticmethod
    def _xml_local_name(tag: object) -> str:
        """Return an element tag without its ``{namespace}`` prefix ("" for non-str tags)."""
        if not isinstance(tag, str):
            return ""
        return tag.rsplit("}", 1)[-1]

    @staticmethod
    def _xml_child_text(element: ET.Element, name: str) -> str:
        """Return the stripped text of *element*'s first direct child named *name*, or ""."""
        for child in element:
            if DependencyParser._xml_local_name(child.tag) == name:
                return (child.text or "").strip()
        return ""

    @staticmethod
    def _clean_maven_version(version: str) -> str | None:
        """Return a concrete Maven version, or None when it cannot be resolved.

        Values containing a ``${...}`` property reference are unresolvable
        here. Version ranges such as "[1.0,2.0)" are reduced to their lower
        bound; a range with no lower bound ("(,1.0]") yields None.
        """
        version = version.strip()
        if not version or "${" in version:
            return None
        if version[0] in "[(":
            version = version[1:].split(",", 1)[0].rstrip("])").strip()
        return version or None

    @staticmethod
    def parse_pom_xml(file_path: Path) -> list[Dependency]:
        """Parse a Maven ``pom.xml`` file.

        Every ``<dependency>`` element with a groupId, artifactId and a
        resolvable version is returned, whatever XML namespace the POM uses.
        ``test`` and ``provided`` scopes are marked ``is_dev``.

        Raises:
            ValueError: if the file cannot be read or is not well-formed XML.
        """
        try:
            # NOTE(security): manifests may be untrusted; see the stdlib "XML
            # vulnerabilities" notes regarding ElementTree on hostile input.
            root = ET.parse(file_path).getroot()
        except (ET.ParseError, OSError) as e:
            raise ValueError(f"Failed to parse {file_path}: {e}") from e

        local_name = DependencyParser._xml_local_name
        child_text = DependencyParser._xml_child_text
        deps = []
        for element in root.iter():
            if local_name(element.tag) != "dependency":
                continue
            group_id = child_text(element, "groupId")
            artifact_id = child_text(element, "artifactId")
            version = DependencyParser._clean_maven_version(child_text(element, "version"))
            # Skip incomplete coordinates and unresolved ${property} references.
            if not (group_id and artifact_id and version):
                continue
            if "${" in group_id or "${" in artifact_id:
                continue
            deps.append(Dependency(
                name=f"{group_id}:{artifact_id}",
                version=version,
                ecosystem=Ecosystem.MAVEN,
                source_file=str(file_path),
                is_dev=child_text(element, "scope") in ("test", "provided"),
            ))
        return deps

    @staticmethod
    def parse_build_gradle(file_path: Path) -> list[Dependency]:
        """Parse a Gradle ``build.gradle`` file (regex based, no Groovy evaluation).

        Recognises string notation ``'group:artifact:version'`` and map
        notation ``group: ..., name: ..., version: ...`` for common
        configurations. Versions containing ``$`` (Groovy interpolation such
        as ``"$springVersion"``) are skipped because their value is unknown.
        Dependencies in ``test*`` configurations are marked ``is_dev``.

        Raises:
            ValueError: if the file cannot be read or decoded.
        """
        content = DependencyParser._read_text(file_path)
        deps = []
        for pattern in DependencyParser._GRADLE_PATTERNS:
            for match in pattern.finditer(content):
                version = match.group("version").strip()
                if not version or "$" in version:
                    continue
                deps.append(Dependency(
                    name=f"{match.group('group')}:{match.group('artifact')}",
                    version=version,
                    ecosystem=Ecosystem.MAVEN,
                    source_file=str(file_path),
                    # Decided by the configuration name only, so coordinates that
                    # merely contain "test" (e.g. org.testcontainers) stay production.
                    is_dev=match.group("config").startswith("test"),
                ))
        return deps

    @staticmethod
    def _parse_pep508(spec: str) -> tuple[str, str] | None:
        """Extract ``(lowercased name, version)`` from a PEP 508 requirement string.

        Handles dotted names, extras (``pkg[extra]``), multiple clauses and
        environment markers. Prefers the first pinned or lower-bound clause
        (``==``, ``===``, ``~=``, ``>=``, ``>``) and otherwise falls back to
        the first clause's version. Returns None when there is no version
        specifier (bare names, URL requirements).
        """
        match = DependencyParser._PEP508_RE.match(spec)
        if match is None:
            return None
        clauses = [
            clause.groups()
            for clause in map(
                DependencyParser._PEP440_CLAUSE_RE.match,
                match.group("specs").split(","),
            )
            if clause
        ]
        if not clauses:
            return None
        name = match.group("name").lower()
        for op, version in clauses:
            if op in DependencyParser._PEP440_LOWER_BOUND_OPS:
                return name, version
        return name, clauses[0][1]

    @staticmethod
    def parse_requirements_txt(file_path: Path) -> list[Dependency]:
        """Parse a Python ``requirements.txt`` file.

        Blank lines, comments and option lines (``-r``, ``-e``, ``--hash``
        ...) are skipped, as are requirements without a version specifier.

        Raises:
            ValueError: if the file cannot be read or decoded.
        """
        deps = []
        for raw_line in DependencyParser._read_text(file_path).splitlines():
            line = raw_line.strip()
            if not line or line.startswith(("#", "-")):
                continue
            parsed = DependencyParser._parse_pep508(line)
            if parsed is None:
                continue
            name, version = parsed
            deps.append(Dependency(
                name=name,
                version=version,
                ecosystem=Ecosystem.PYPI,
                source_file=str(file_path),
                is_dev=False,
            ))
        return deps

    @staticmethod
    def _scan_toml_array_line(line: str) -> tuple[list[str], bool]:
        """Return the quoted strings on one TOML array line and whether the array closes.

        Scanning stops at a ``#`` comment or at a ``]`` outside quotes; the
        second element is True only in the latter case.
        """
        strings: list[str] = []
        for token in DependencyParser._TOML_ARRAY_TOKEN_RE.finditer(line):
            double, single, comment, close = token.groups()
            if comment is not None:
                return strings, False
            if close is not None:
                return strings, True
            strings.append(double if double is not None else single)
        return strings, False

    @staticmethod
    def parse_pyproject_toml(file_path: Path) -> list[Dependency]:
        """Parse a Python ``pyproject.toml`` file (basic line-based parsing).

        Reads PEP 508 strings from any array whose key ends in
        ``dependencies`` (``dependencies = [...]``, ``dev-dependencies =
        [...]``, ...), written on one line or across several. Arrays whose key
        contains ``dev`` are marked ``is_dev``. Table-style sections such as
        ``[tool.poetry.dependencies]`` are not understood.

        Raises:
            ValueError: if the file cannot be read or decoded.
        """
        content = DependencyParser._read_text(file_path)
        deps = []
        in_array = False
        is_dev = False
        for line in content.splitlines():
            if not in_array:
                header = DependencyParser._TOML_DEPS_ARRAY_RE.match(line)
                if header is None:
                    continue
                in_array = True
                is_dev = "dev" in header.group("key").lower()
                # Entries may follow the opening bracket on the same line.
                line = line[header.end():]
            elif line.lstrip().startswith("["):
                # A table header before the closing "]": stop rather than read
                # the next table as dependencies.
                in_array = False
                continue

            specs, closed = DependencyParser._scan_toml_array_line(line)
            for spec in specs:
                parsed = DependencyParser._parse_pep508(spec)
                if parsed is None:
                    continue
                name, version = parsed
                deps.append(Dependency(
                    name=name,
                    version=version,
                    ecosystem=Ecosystem.PYPI,
                    source_file=str(file_path),
                    is_dev=is_dev,
                ))
            if closed:
                in_array = False
        return deps


class OSVClient:
    """Client for the Google OSV (Open Source Vulnerabilities) API.

    Network and decoding failures never raise. The query methods return empty
    results for the part that failed and record a description in
    ``last_error``, so callers can tell "no vulnerabilities" from "lookup
    failed".
    """

    # Endpoint returning the complete record for a single vulnerability ID.
    VULN_URL = "https://api.osv.dev/v1/vulns/{vuln_id}"
    # Maximum number of queries sent in one /v1/querybatch request
    # (OSV's documented per-request limit).
    BATCH_SIZE = 1000

    def __init__(self, timeout: float = 30.0):
        self.timeout = timeout
        self._client: httpx.Client | None = None
        # Full vulnerability records already fetched, keyed by OSV ID.
        self._vuln_cache: dict[str, dict[str, Any]] = {}
        # Description of the most recent failure during the last query, or None.
        self.last_error: str | None = None

    def _get_client(self) -> httpx.Client:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.Client(timeout=self.timeout)
        return self._client

    def query_single(self, package_name: str, version: str, ecosystem: str) -> list[dict[str, Any]]:
        """Return the OSV vulnerability records affecting one package version.

        Follows ``next_page_token`` pagination. Returns ``[]`` if any request
        fails; the failure is recorded in ``last_error``.
        """
        self.last_error = None
        client = self._get_client()
        payload: dict[str, Any] = {
            "package": {"name": package_name, "ecosystem": ecosystem},
            "version": version,
        }
        vulns: list[dict[str, Any]] = []
        try:
            while True:
                response = client.post(OSV_API_URL, json=payload)
                response.raise_for_status()
                data = response.json()
                if not isinstance(data, dict):
                    break
                vulns.extend(data.get("vulns") or [])
                token = data.get("next_page_token")
                if not token:
                    break
                payload["page_token"] = token
        except (httpx.HTTPError, ValueError) as e:  # ValueError: undecodable JSON
            self.last_error = f"OSV query for {package_name}@{version} failed: {e}"
            return []
        return vulns

    def query_batch(self, dependencies: list[Dependency]) -> dict[str, list[dict[str, Any]]]:
        """Query OSV for many packages at once.

        Dependencies are sent in chunks of ``BATCH_SIZE``. The batch endpoint
        returns only vulnerability IDs (and modification times), so each ID
        is then fetched from ``VULN_URL`` to obtain the full record needed for
        severity, aliases and fixed versions; fetched records are cached on
        this client. A failed chunk is skipped and a failed lookup keeps the
        batch stub; both are noted in ``last_error``.

        Returns:
            Mapping of ``"name@version"`` to the vulnerability records of each
            dependency that has at least one.
        """
        self.last_error = None
        if not dependencies:
            return {}

        results: dict[str, list[dict[str, Any]]] = {}
        for start in range(0, len(dependencies), self.BATCH_SIZE):
            chunk = dependencies[start:start + self.BATCH_SIZE]
            # OSV returns results in the same order as the queries.
            for dep, result in zip(chunk, self._post_batch(chunk)):
                vulns = result.get("vulns") if isinstance(result, dict) else None
                if vulns:
                    results[f"{dep.name}@{dep.version}"] = [
                        self._fetch_full_record(vuln) for vuln in vulns
                    ]
        return results

    def _post_batch(self, chunk: list[Dependency]) -> list[Any]:
        """Send one /v1/querybatch request and return its ``results`` ([] on failure)."""
        queries = [
            {
                "package": {"name": dep.name, "ecosystem": dep.ecosystem.value},
                "version": dep.version,
            }
            for dep in chunk
        ]
        try:
            response = self._get_client().post(OSV_BATCH_URL, json={"queries": queries})
            response.raise_for_status()
            data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            self.last_error = f"OSV batch query failed: {e}"
            return []
        results = data.get("results") if isinstance(data, dict) else None
        return results if isinstance(results, list) else []

    def _fetch_full_record(self, vuln: dict[str, Any]) -> dict[str, Any]:
        """Return the full OSV record for a batch result entry.

        Entries that already carry details are returned unchanged. If the
        lookup fails, the entry itself is returned so the finding is kept.
        """
        vuln_id = vuln.get("id") if isinstance(vuln, dict) else None
        if not vuln_id or "affected" in vuln or "summary" in vuln or "details" in vuln:
            return vuln
        cached = self._vuln_cache.get(vuln_id)
        if cached is not None:
            return cached
        try:
            response = self._get_client().get(self.VULN_URL.format(vuln_id=vuln_id))
            response.raise_for_status()
            record = response.json()
        except (httpx.HTTPError, ValueError) as e:
            self.last_error = f"OSV lookup for {vuln_id} failed: {e}"
            return vuln
        if not isinstance(record, dict):
            return vuln
        self._vuln_cache[vuln_id] = record
        return record

    def close(self):
        """Close the HTTP client."""
        if self._client:
            self._client.close()
            self._client = None


class VulnerabilityScanner:
    """
    Scans project dependencies for known vulnerabilities.

    Supports:
    - npm (package.json)
    - Maven (pom.xml, build.gradle)
    - PyPI (requirements.txt, pyproject.toml)

    Use as a context manager (or call ``close()``) to release the HTTP client.
    """

    # Manifest file names scan_directory looks for, in reporting order.
    DEPENDENCY_FILES = (
        "package.json",
        "pom.xml",
        "build.gradle",
        "requirements.txt",
        "pyproject.toml",
    )
    # Directory names holding installed third-party code; never descended into.
    SKIP_DIRS = frozenset({"node_modules", ".venv"})

    # CVSS v3.x base metric weights (FIRST CVSS v3.1 specification).
    _CVSS3_WEIGHTS = {
        "AV": {"N": 0.85, "A": 0.62, "L": 0.55, "P": 0.2},
        "AC": {"L": 0.77, "H": 0.44},
        "UI": {"N": 0.85, "R": 0.62},
        "CIA": {"H": 0.56, "L": 0.22, "N": 0.0},
    }
    # Privileges Required weights depend on Scope: Unchanged ("U") / Changed ("C").
    _CVSS3_PR_WEIGHTS = {
        "U": {"N": 0.85, "L": 0.62, "H": 0.27},
        "C": {"N": 0.85, "L": 0.68, "H": 0.5},
    }

    def __init__(self, timeout: float = 30.0):
        self.osv_client = OSVClient(timeout=timeout)
        self.parser = DependencyParser()

    def scan_file(self, file_path: str | Path) -> ScanResult:
        """Scan a single dependency file for vulnerabilities.

        A missing file or a parse failure is reported in ``errors`` rather
        than raised.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            return ScanResult(
                dependencies_scanned=0,
                vulnerabilities_found=0,
                findings=[],
                errors=[f"File not found: {file_path}"]
            )

        try:
            deps = self._parse_file(file_path)
        except ValueError as e:
            return ScanResult(
                dependencies_scanned=0,
                vulnerabilities_found=0,
                findings=[],
                errors=[str(e)]
            )

        if not deps:
            return ScanResult(
                dependencies_scanned=0,
                vulnerabilities_found=0,
                findings=[],
                errors=[]
            )

        return self._scan_dependencies(deps)

    def scan_directory(self, dir_path: str | Path, recursive: bool = True) -> ScanResult:
        """Scan every recognised manifest under *dir_path*.

        Directories named in ``SKIP_DIRS`` are pruned below *dir_path* only,
        so scanning a root whose own path contains such a name still works.
        Parse failures are collected in ``errors`` instead of aborting.
        """
        dir_path = Path(dir_path)
        if not dir_path.is_dir():
            return ScanResult(
                dependencies_scanned=0,
                vulnerabilities_found=0,
                findings=[],
                errors=[f"Not a directory: {dir_path}"]
            )

        all_deps: list[Dependency] = []
        errors: list[str] = []
        for file_path in self._find_dependency_files(dir_path, recursive):
            try:
                all_deps.extend(self._parse_file(file_path))
            except ValueError as e:
                errors.append(str(e))

        if not all_deps:
            return ScanResult(
                dependencies_scanned=0,
                vulnerabilities_found=0,
                findings=[],
                errors=errors
            )

        result = self._scan_dependencies(all_deps)
        result.errors.extend(errors)
        return result

    def _find_dependency_files(self, dir_path: Path, recursive: bool) -> list[Path]:
        """Return recognised manifests under *dir_path*, grouped by file name.

        Groups follow ``DEPENDENCY_FILES`` order; each group is sorted so the
        result does not depend on directory listing order.
        """
        import os  # only the recursive walk needs it

        found: dict[str, list[Path]] = {name: [] for name in self.DEPENDENCY_FILES}
        if recursive:
            # One walk for all file names; pruning dirnames in place stops
            # os.walk from descending into vendored dependency trees at all.
            for root, dirnames, filenames in os.walk(dir_path):
                dirnames[:] = [d for d in dirnames if d not in self.SKIP_DIRS]
                for filename in filenames:
                    if filename in found:
                        found[filename].append(Path(root) / filename)
        else:
            for filename, paths in found.items():
                candidate = dir_path / filename
                if candidate.is_file():
                    paths.append(candidate)
        return [path for name in self.DEPENDENCY_FILES for path in sorted(found[name])]

    def _parse_file(self, file_path: Path) -> list[Dependency]:
        """Parse a dependency file with the parser matching its (case-insensitive) name.

        Raises:
            ValueError: for unsupported file names or parse failures.
        """
        parsers = {
            "package.json": self.parser.parse_package_json,
            "pom.xml": self.parser.parse_pom_xml,
            "build.gradle": self.parser.parse_build_gradle,
            "requirements.txt": self.parser.parse_requirements_txt,
            "pyproject.toml": self.parser.parse_pyproject_toml,
        }
        parse = parsers.get(file_path.name.lower())
        if parse is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return parse(file_path)

    def _scan_dependencies(self, deps: list[Dependency]) -> ScanResult:
        """Look *deps* up in OSV and build one finding per (dependency, vulnerability).

        Identical (name, version, ecosystem) triples are queried once; the
        results are still reported against every manifest that declared them.
        """
        unique = list({(d.name, d.version, d.ecosystem): d for d in deps}.values())
        vuln_map = self.osv_client.query_batch(unique)

        findings = [
            self._parse_vulnerability(vuln, dep)
            for dep in deps
            for vuln in vuln_map.get(f"{dep.name}@{dep.version}", [])
        ]
        return ScanResult(
            dependencies_scanned=len(deps),
            vulnerabilities_found=len(findings),
            findings=findings,
            errors=[]
        )

    def _parse_vulnerability(self, vuln: dict[str, Any], dep: Dependency) -> VulnerabilityFinding:
        """Convert one OSV record into a VulnerabilityFinding for *dep*.

        Fixed versions are taken from the ``affected`` entries naming *dep*'s
        package (all entries if none do), de-duplicated in order. ``details``
        is truncated to 500 characters and ``references`` to 5 URLs.
        """
        affected_entries = [a for a in (vuln.get("affected") or []) if isinstance(a, dict)]
        # An advisory can cover several packages; their fix versions don't apply here.
        relevant = [
            a for a in affected_entries if self._affects_package(a, dep.name)
        ] or affected_entries

        fixed_versions: list[str] = []
        for affected in relevant:
            for range_info in affected.get("ranges") or []:
                for event in range_info.get("events") or []:
                    fixed = event.get("fixed")
                    if fixed and fixed not in fixed_versions:
                        fixed_versions.append(fixed)

        references = [
            ref["url"]
            for ref in (vuln.get("references") or [])
            if isinstance(ref, dict) and ref.get("url")
        ]

        return VulnerabilityFinding(
            id=vuln.get("id", "UNKNOWN"),
            aliases=list(vuln.get("aliases") or []),
            summary=vuln.get("summary") or "",
            details=(vuln.get("details") or "")[:500],  # Truncate long details
            severity=self._extract_severity(vuln),
            package_name=dep.name,
            package_version=dep.version,
            ecosystem=dep.ecosystem.value,
            source_file=dep.source_file,
            fixed_versions=fixed_versions,
            references=references[:5]  # Limit references
        )

    @staticmethod
    def _affects_package(affected: dict[str, Any], package_name: str) -> bool:
        """Return True when an OSV ``affected`` entry names *package_name*.

        Comparison is case-insensitive and treats runs of ``-``, ``_`` and
        ``.`` as equal (PEP 503 style); exact matches always compare equal.
        """
        package = affected.get("package")
        name = package.get("name") if isinstance(package, dict) else None
        if not isinstance(name, str):
            return False

        def normalize(value: str) -> str:
            return re.sub(r"[-_.]+", "-", value).lower()

        return normalize(name) == normalize(package_name)

    def _extract_severity(self, vuln: dict[str, Any]) -> str:
        """Derive a severity label for an OSV record.

        A ``CVSS_V3`` entry in ``severity`` wins. OSV stores its ``score`` as a
        CVSS vector string, which is scored with the v3.1 base formula; a
        plain number is also accepted. The score is bucketed into
        LOW/MEDIUM/HIGH/CRITICAL. Otherwise ``database_specific.severity`` is
        used, with GitHub's "MODERATE" normalised to "MEDIUM". Returns
        "UNKNOWN" when neither source yields a value.
        """
        for entry in vuln.get("severity") or []:
            if isinstance(entry, dict) and entry.get("type") == "CVSS_V3":
                score = self._cvss_score(entry.get("score"))
                if score is not None:
                    return self._severity_from_score(score)

        database_specific = vuln.get("database_specific")
        if isinstance(database_specific, dict):
            label = database_specific.get("severity")
            if isinstance(label, str) and label.strip():
                label = label.strip().upper()
                return "MEDIUM" if label == "MODERATE" else label
        return "UNKNOWN"

    @classmethod
    def _cvss_score(cls, score: Any) -> float | None:
        """Return a numeric base score from an OSV ``score`` value, or None."""
        if isinstance(score, (int, float)):
            return float(score)
        if not isinstance(score, str):
            return None
        try:
            return float(score)
        except ValueError:
            return cls._cvss3_base_score(score.strip())

    @classmethod
    def _cvss3_base_score(cls, vector: str) -> float | None:
        """Compute the CVSS v3.x base score of a vector like "CVSS:3.1/AV:N/...".

        Returns None for non-v3 vectors or vectors missing a base metric.
        """
        if not vector.startswith("CVSS:3."):
            return None
        metrics = dict(part.split(":", 1) for part in vector.split("/")[1:] if ":" in part)
        weights = cls._CVSS3_WEIGHTS
        try:
            scope = metrics["S"]
            pr = cls._CVSS3_PR_WEIGHTS[scope][metrics["PR"]]
            av = weights["AV"][metrics["AV"]]
            ac = weights["AC"][metrics["AC"]]
            ui = weights["UI"][metrics["UI"]]
            c = weights["CIA"][metrics["C"]]
            i = weights["CIA"][metrics["I"]]
            a = weights["CIA"][metrics["A"]]
        except KeyError:
            return None

        iss = 1 - (1 - c) * (1 - i) * (1 - a)
        if scope == "U":
            impact = 6.42 * iss
        else:
            impact = 7.52 * (iss - 0.029) - 3.25 * (iss - 0.02) ** 15
        if impact <= 0:
            return 0.0
        exploitability = 8.22 * av * ac * pr * ui
        total = impact + exploitability
        if scope != "U":
            total *= 1.08
        return cls._cvss_roundup(min(total, 10.0))

    @staticmethod
    def _cvss_roundup(value: float) -> float:
        """CVSS v3.1 Roundup: smallest one-decimal number >= value, avoiding float error."""
        scaled = round(value * 100000)
        if scaled % 10000 == 0:
            return scaled / 100000.0
        return (scaled // 10000 + 1) / 10.0

    @staticmethod
    def _severity_from_score(score: float) -> str:
        """Bucket a CVSS base score into CRITICAL/HIGH/MEDIUM/LOW."""
        if score >= 9.0:
            return "CRITICAL"
        if score >= 7.0:
            return "HIGH"
        if score >= 4.0:
            return "MEDIUM"
        return "LOW"

    def close(self):
        """Close the scanner and release resources."""
        self.osv_client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def scan_dependencies(path: str | Path) -> ScanResult:
    """
    Convenience function to scan a file or directory for vulnerable dependencies.

    Args:
        path: Path to a dependency file or directory containing dependency files

    Returns:
        ScanResult with findings. Problems such as a missing path or an
        unparsable manifest are reported in ``ScanResult.errors``.
    """
    path = Path(path)
    with VulnerabilityScanner() as scanner:
        if path.is_dir():
            return scanner.scan_directory(path)
        # Files, and paths that do not exist: scan_file reports the latter as
        # "File not found" instead of the misleading "Not a directory".
        return scanner.scan_file(path)
