import ast
import os
from pathlib import Path
from typing import Dict, List, Set, Optional
from dataclasses import dataclass, field
from collections import deque


# [20251213_FEATURE] v1.5.0 - Enhanced call graph with line numbers and Mermaid support


@dataclass
class CallNode:
    """One function (or unresolved callee) in the call graph, with its location."""

    # Bare function name, e.g. "main".
    name: str
    # Path relative to the project root, or "<external>" for unresolved callees.
    file: str
    # Line of the definition as reported by ``ast``; 0 when it is not known.
    line: int
    # Last line of the definition, when the parser provides it.
    end_line: Optional[int] = field(default=None)
    # Set when the entry-point heuristics matched this function.
    is_entry_point: bool = field(default=False)


@dataclass
class CallEdge:
    """A directed link meaning "``caller`` calls ``callee``"."""

    # Key of the calling function.
    caller: str
    # Key or dotted name of the called function.
    callee: str
    # Line of the call site, if one was recorded.
    call_line: Optional[int] = field(default=None)


@dataclass
class CallGraphResult:
    """Bundle of nodes, edges and rendered diagram for one call-graph query."""

    # Functions (and unresolved callees) included in the result.
    nodes: List[CallNode] = field(default_factory=list)
    # Caller -> callee links between included functions.
    edges: List[CallEdge] = field(default_factory=list)
    # The entry point the query started from, if any.
    entry_point: Optional[str] = field(default=None)
    # Maximum traversal depth that was requested.
    depth_limit: Optional[int] = field(default=None)
    # Mermaid flowchart source for the graph.
    mermaid: str = field(default="")


class CallGraphBuilder:
    """
    Builds a static call graph for a Python project.

    Every ``.py`` file under ``root_path`` is parsed with :mod:`ast`.
    Functions are keyed as ``"<path relative to root>:<function name>"``.
    Files that cannot be read or parsed are skipped (best effort).
    """

    def __init__(self, root_path: Path):
        self.root_path = root_path
        # file_path -> set of defined functions/classes (methods also as "Class.method")
        self.definitions: Dict[str, Set[str]] = {}
        # "file:function" -> list of called function names.
        # NOTE(review): no method in this class writes to this attribute.
        self.calls: Dict[str, List[str]] = {}
        # file_path -> { alias -> full_name }
        self.imports: Dict[str, Dict[str, str]] = {}

    def build(self) -> Dict[str, List[str]]:
        """
        Build the call graph.

        Returns:
            An adjacency list ``{"file:caller": [callee, ...]}``. A callee that
            resolves through an import is returned in dotted form (e.g.
            ``"os.path.join"``); one defined in the same file is returned as
            ``"file:name"``; anything else is returned unchanged.
        """
        # 1. First pass: collect definitions and imports for every file.
        for file_path in self._iter_python_files():
            rel_path = self._rel_path(file_path)
            try:
                self._analyze_definitions(self._parse_file(file_path), rel_path)
            except Exception:
                # Deliberate best effort: skip unreadable/unparsable files.
                continue

        # 2. Second pass: analyze calls and resolve them.
        graph: Dict[str, List[str]] = {}
        for file_path in self._iter_python_files():
            rel_path = self._rel_path(file_path)
            try:
                tree = self._parse_file(file_path)
                graph.update(self._analyze_calls(tree, rel_path))
            except Exception:
                continue

        return graph

    def _rel_path(self, file_path: Path) -> str:
        """Return ``file_path`` relative to the project root, as a native path string."""
        return str(file_path.relative_to(self.root_path))

    @staticmethod
    def _parse_file(file_path: Path) -> ast.AST:
        """Read ``file_path`` as UTF-8 and parse it; I/O and syntax errors propagate."""
        with open(file_path, "r", encoding="utf-8") as f:
            return ast.parse(f.read())

    def _iter_python_files(self):
        """
        Yield every ``.py`` file under the root, skipping hidden and build/vendor dirs.

        Directories and files are visited in sorted order so that repeated runs
        produce the same graph ordering.
        """
        skip_dirs = {
            ".git",
            ".venv",
            "venv",
            "__pycache__",
            "node_modules",
            "dist",
            "build",
        }
        for root, dirs, files in os.walk(self.root_path):
            # Prune in place so os.walk does not descend into skipped directories.
            dirs[:] = sorted(
                d for d in dirs if d not in skip_dirs and not d.startswith(".")
            )
            for file in sorted(files):
                if file.endswith(".py"):
                    yield Path(root) / file

    def _analyze_definitions(self, tree: ast.AST, rel_path: str):
        """
        Record the functions, classes and imports found anywhere in ``tree``.

        Every (async) function, at any nesting level, is recorded by its bare
        name. Methods are additionally recorded as ``"Class.method"``. Imports map
        the local alias to the imported dotted name.
        """
        self.definitions[rel_path] = set()
        self.imports[rel_path] = {}

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.definitions[rel_path].add(node.name)
            elif isinstance(node, ast.ClassDef):
                self.definitions[rel_path].add(node.name)
                # Also add methods
                for item in node.body:
                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        self.definitions[rel_path].add(f"{node.name}.{item.name}")

            # Collect imports for resolution
            elif isinstance(node, ast.Import):
                for alias in node.names:
                    name = alias.name
                    asname = alias.asname or name
                    self.imports[rel_path][asname] = name
            elif isinstance(node, ast.ImportFrom):
                module = node.module or ""
                for alias in node.names:
                    name = alias.name
                    asname = alias.asname or name
                    full_name = f"{module}.{name}" if module else name
                    self.imports[rel_path][asname] = full_name

    def _analyze_calls(self, tree: ast.AST, rel_path: str) -> Dict[str, List[str]]:
        """
        Extract the calls made inside each function of ``tree`` and resolve them.

        Returns ``{"rel_path:func": [resolved callee, ...]}``. Methods and nested
        functions are keyed by their bare name. Calls outside any function are
        ignored.
        """
        file_graph: Dict[str, List[str]] = {}

        class CallVisitor(ast.NodeVisitor):
            def __init__(self, builder, current_file):
                self.builder = builder
                self.current_file = current_file
                self.current_scope = None
                self.calls = []

            def visit_FunctionDef(self, node):
                # Save and restore BOTH the scope and the call list. Otherwise a
                # nested def leaves self.calls pointing at the inner function's
                # list, and the enclosing function's remaining calls land there.
                old_scope, old_calls = self.current_scope, self.calls
                self.current_scope = node.name
                self.calls = []
                self.generic_visit(node)

                # Store calls for this function
                file_graph[f"{self.current_file}:{node.name}"] = self.calls

                self.current_scope, self.calls = old_scope, old_calls

            def visit_AsyncFunctionDef(self, node):
                self.visit_FunctionDef(node)

            def visit_Call(self, node):
                if self.current_scope:
                    callee = self._get_callee_name(node)
                    if callee:
                        self.calls.append(self._resolve_callee(callee))
                self.generic_visit(node)

            def _get_callee_name(self, node):
                if isinstance(node.func, ast.Name):
                    return node.func.id
                elif isinstance(node.func, ast.Attribute):
                    # Handle obj.method() - simplified
                    value = self._get_attribute_value(node.func.value)
                    if value:
                        return f"{value}.{node.func.attr}"
                return None

            def _get_attribute_value(self, node):
                if isinstance(node, ast.Name):
                    return node.id
                elif isinstance(node, ast.Attribute):
                    val = self._get_attribute_value(node.value)
                    return f"{val}.{node.attr}" if val else None
                return None

            def _resolve_callee(self, callee):
                # 1. Check local imports
                imports = self.builder.imports.get(self.current_file, {})

                # Case: alias.method() where alias is imported
                parts = callee.split(".")
                if parts[0] in imports:
                    # e.g. "utils.hash" where "import my_utils as utils" -> "my_utils.hash"
                    # or "hash" where "from utils import hash" -> "utils.hash"
                    resolved_base = imports[parts[0]]
                    if len(parts) > 1:
                        return f"{resolved_base}.{'.'.join(parts[1:])}"
                    return resolved_base

                # 2. Check if it's a local definition in the same file
                if callee in self.builder.definitions.get(self.current_file, set()):
                    return f"{self.current_file}:{callee}"

                # 3. Fallback: return as is (likely external lib or built-in)
                return callee

        visitor = CallVisitor(self, rel_path)
        visitor.visit(tree)
        return file_graph

    # [20251213_FEATURE] v1.5.0 - Enhanced call graph methods

    def build_with_details(
        self,
        entry_point: Optional[str] = None,
        depth: int = 10,
    ) -> "CallGraphResult":
        """
        Build call graph with detailed node and edge information.

        Args:
            entry_point: Starting function (e.g., "main" or "src/app.py:main").
                        If None, includes all functions.
            depth: Maximum depth to traverse from entry point (default: 10)

        Returns:
            CallGraphResult with nodes, edges, and Mermaid diagram. When an
            entry point is given, every edge connects two nodes that appear in
            ``nodes``. Nodes are sorted by key, so output is stable across runs.
        """
        # Build the base graph first
        base_graph = self.build()

        # Collect node information with line numbers, keyed like build()'s keys.
        node_info: Dict[str, "CallNode"] = {}
        for file_path in self._iter_python_files():
            rel_path = self._rel_path(file_path)
            try:
                tree = self._parse_file(file_path)
                for node in ast.walk(tree):
                    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        node_info[f"{rel_path}:{node.name}"] = CallNode(
                            name=node.name,
                            file=rel_path,
                            line=node.lineno,
                            end_line=getattr(node, "end_lineno", None),
                            is_entry_point=self._is_entry_point(node, tree),
                        )
            except Exception:
                continue

        # If entry_point is specified, filter to reachable nodes
        if entry_point:
            reachable = self._get_reachable_nodes(base_graph, entry_point, depth)
        else:
            reachable = set(base_graph)

        # Build nodes list (sorted: set order would make Mermaid IDs unstable)
        nodes = []
        for key in sorted(reachable):
            if key in node_info:
                nodes.append(node_info[key])
                continue
            # No parsed definition: external/built-in call or unresolved "file:name".
            parts = key.split(":")
            nodes.append(
                CallNode(
                    name=parts[-1],
                    file=parts[0] if len(parts) > 1 else "<external>",
                    line=0,
                )
            )

        # Build edges list
        edges = []
        for caller, callees in base_graph.items():
            if caller not in reachable:
                continue
            for callee in callees:
                # With an entry point, callers at the depth limit still list
                # callees that were never visited; drop those so that no edge
                # points at a node missing from ``nodes``.
                if entry_point and callee not in reachable:
                    continue
                edges.append(CallEdge(caller=caller, callee=callee))

        # Generate Mermaid diagram
        mermaid = self._generate_mermaid(nodes, edges, entry_point)

        return CallGraphResult(
            nodes=nodes,
            edges=edges,
            entry_point=entry_point,
            depth_limit=depth,
            mermaid=mermaid,
        )

    def _is_entry_point(self, func_node: ast.AST, tree: ast.AST) -> bool:
        """
        Detect if a function is likely an entry point.

        Entry point heuristics:
        - Function named "main"
        - Function decorated with CLI/web decorators (click.command, app.route,
          ...), matched on the decorator's last name component
        - Function called by bare name inside a top-level
          ``if __name__ == "__main__":`` block of ``tree``
        """
        if func_node.name == "main":
            return True

        # Check for CLI decorators; unwrap "@deco(...)" to "deco" first.
        for decorator in getattr(func_node, "decorator_list", []):
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            if isinstance(target, ast.Attribute):
                dec_name = target.attr
            elif isinstance(target, ast.Name):
                dec_name = target.id
            else:
                dec_name = ""

            if dec_name in ("command", "main", "cli", "app", "route", "get", "post"):
                return True

        return func_node.name in self._main_guard_calls(tree)

    @staticmethod
    def _main_guard_calls(tree: ast.AST) -> Set[str]:
        """Return names called as ``name(...)`` inside top-level ``__main__`` guards."""

        def is_main_guard(test: ast.AST) -> bool:
            # Accept both `__name__ == "__main__"` and `"__main__" == __name__`.
            if not (
                isinstance(test, ast.Compare)
                and len(test.ops) == 1
                and isinstance(test.ops[0], ast.Eq)
            ):
                return False
            sides = (test.left, test.comparators[0])
            has_name = any(
                isinstance(side, ast.Name) and side.id == "__name__" for side in sides
            )
            has_main = any(
                isinstance(side, ast.Constant) and side.value == "__main__"
                for side in sides
            )
            return has_name and has_main

        called: Set[str] = set()
        # Only module-level statements can be the __main__ guard.
        for stmt in getattr(tree, "body", []):
            if isinstance(stmt, ast.If) and is_main_guard(stmt.test):
                for inner in stmt.body:
                    for node in ast.walk(inner):
                        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                            called.add(node.func.id)
        return called

    def _get_reachable_nodes(
        self,
        graph: Dict[str, List[str]],
        entry_point: str,
        max_depth: int,
    ) -> Set[str]:
        """
        Get all nodes reachable from entry point within depth limit using BFS.

        A bare name (no ":") is resolved to the first graph key ending in
        ``":<name>"``. Nodes at ``max_depth`` are included but not expanded.
        Cycles (including recursion) are handled by skipping visited nodes.
        """
        # Normalize entry point (might be just "main" or "file:main")
        if ":" not in entry_point:
            for key in graph:
                if key.endswith(f":{entry_point}"):
                    entry_point = key
                    break

        reachable: Set[str] = set()
        queue = deque([(entry_point, 0)])

        while queue:
            node, depth = queue.popleft()

            if node in reachable:
                continue  # Already visited (handles cycles)

            reachable.add(node)

            if depth >= max_depth:
                continue  # Don't explore further

            for callee in graph.get(node, []):
                if callee not in reachable:
                    queue.append((callee, depth + 1))

        return reachable

    def _generate_mermaid(
        self,
        nodes: List["CallNode"],
        edges: List["CallEdge"],
        entry_point: Optional[str],
    ) -> str:
        """
        Generate a Mermaid flowchart (``graph TD``) for the given nodes and edges.

        Nodes get positional IDs ``N0, N1, ...``. Entry points use the subroutine
        shape ``[[...]]``, external callees the rounded shape ``(...)``, and
        internal functions a rectangle. Duplicate caller->callee arrows (one per
        call site) are emitted only once.
        """
        lines = ["graph TD"]

        def is_requested_entry(full_name: str, short_name: str) -> bool:
            # A "file:name" entry matches exactly; a bare name matches by name.
            if not entry_point:
                return False
            if ":" in entry_point:
                return entry_point == full_name
            return entry_point == short_name

        # Create node ID mapping (Mermaid doesn't like special chars)
        node_ids: Dict[str, str] = {}
        full_names: List[str] = []
        for i, node in enumerate(nodes):
            full_name = (
                f"{node.file}:{node.name}" if node.file != "<external>" else node.name
            )
            full_names.append(full_name)
            node_id = f"N{i}"
            node_ids[full_name] = node_id
            # Also map short names for external refs (first node wins)
            node_ids.setdefault(node.name, node_id)

        # Add nodes with labels
        for i, node in enumerate(nodes):
            node_id = f"N{i}"
            label = f"{node.name}:L{node.line}" if node.line > 0 else node.name

            if node.is_entry_point or is_requested_entry(full_names[i], node.name):
                lines.append(f'    {node_id}[["{label}"]]')  # Subroutine shape for entry
            elif node.file == "<external>":
                lines.append(f'    {node_id}("{label}")')  # Rounded for external
            else:
                lines.append(f'    {node_id}["{label}"]')  # Rectangle for internal

        # Add edges, once per distinct caller/callee pair
        seen_edges: Set[tuple] = set()
        for edge in edges:
            caller_id = node_ids.get(edge.caller) or node_ids.get(
                edge.caller.split(":")[-1]
            )
            callee_id = node_ids.get(edge.callee) or node_ids.get(
                edge.callee.split(":")[-1]
            )
            if caller_id and callee_id and (caller_id, callee_id) not in seen_edges:
                seen_edges.add((caller_id, callee_id))
                lines.append(f"    {caller_id} --> {callee_id}")

        return "\n".join(lines)

    @staticmethod
    def _import_candidates(tree: ast.AST, posix_rel: str) -> Set[str]:
        """
        Return root-relative POSIX paths that the imports in ``tree`` may refer to.

        A dotted module ``a.b`` maps to both ``a/b.py`` and ``a/b/__init__.py``.
        ``from X import y`` also considers ``y`` as a submodule of ``X``.
        Relative imports are resolved against the importing file's directory
        (``posix_rel``). Those that climb above the project root are ignored.
        """

        def module_files(parts: List[str]) -> List[str]:
            if not parts:
                return ["__init__.py"]  # the root directory treated as a package
            stem = "/".join(parts)
            return [f"{stem}.py", f"{stem}/__init__.py"]

        package_dir = posix_rel.split("/")[:-1]
        found: Set[str] = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    found.update(module_files(alias.name.split(".")))
            elif isinstance(node, ast.ImportFrom):
                if node.level:
                    # "." is the file's own package; each extra dot goes up one.
                    up = node.level - 1
                    if up > len(package_dir):
                        continue
                    base = package_dir[: len(package_dir) - up]
                else:
                    base = []
                if node.module:
                    base = base + node.module.split(".")
                found.update(module_files(base))
                for alias in node.names:
                    if alias.name != "*":
                        found.update(module_files(base + [alias.name]))
        return found

    def detect_circular_imports(self) -> List[List[str]]:
        """
        Detect circular import cycles in the project.

        [20251213_FEATURE] P1 - Circular dependency detection

        Absolute imports are resolved relative to the project root. Packages
        (``__init__.py``), submodule imports and relative imports are
        understood. A module importing itself is not reported. Paths are matched
        in POSIX form, so detection works with any OS path separator.

        Returns:
            List of cycles, where each cycle is a list of module paths
            e.g., [["a.py", "b.py", "a.py"], ["c.py", "d.py", "e.py", "c.py"]]
        """
        # Pass 1: parse each file and collect its candidate import targets.
        candidates_by_file: Dict[str, Set[str]] = {}
        posix_to_rel: Dict[str, str] = {}
        for file_path in self._iter_python_files():
            rel = file_path.relative_to(self.root_path)
            rel_path = str(rel)
            try:
                tree = self._parse_file(file_path)
                candidates = self._import_candidates(tree, rel.as_posix())
            except Exception:
                # Deliberate best effort: skip unreadable/unparsable files.
                continue
            posix_to_rel[rel.as_posix()] = rel_path
            candidates_by_file[rel_path] = candidates

        # Build import graph: module -> internal modules it imports.
        import_graph: Dict[str, Set[str]] = {}
        for rel_path, candidates in candidates_by_file.items():
            targets = {posix_to_rel[c] for c in candidates if c in posix_to_rel}
            targets.discard(rel_path)  # e.g. pkg/__init__.py doing "from . import x"
            import_graph[rel_path] = targets

        # Find cycles using DFS with coloring
        cycles: List[List[str]] = []
        WHITE, GRAY, BLACK = 0, 1, 2
        color = dict.fromkeys(import_graph, WHITE)
        path: List[str] = []

        def dfs(node: str):
            if color[node] == GRAY:
                # Back edge: the cycle is the path suffix starting at ``node``.
                cycle_start = path.index(node)
                cycles.append(path[cycle_start:] + [node])
                return

            if color[node] == BLACK:
                return

            color[node] = GRAY
            path.append(node)

            # Targets are internal by construction; sort for deterministic output.
            for neighbor in sorted(import_graph[node]):
                dfs(neighbor)

            path.pop()
            color[node] = BLACK

        for node in import_graph:
            if color[node] == WHITE:
                dfs(node)

        return cycles
