"""Test Generator - Convert Symbolic Execution Results to Unit Tests.

This module converts the results of Z3 symbolic execution (concrete input
values for each feasible path) into executable unit tests. Each path through
the code gets a test case with concrete inputs that trigger that specific path.

Example:
    >>> from code_scalpel.generators import TestGenerator
    >>> generator = TestGenerator()
    >>> result = generator.generate(code, function_name="classify")
    >>> print(result.pytest_code)
    # Generated pytest tests with concrete inputs for each path
"""

import ast
import math
import re
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TestCase:
    """A single generated test case."""

    path_id: int
    function_name: str
    inputs: dict[str, Any]
    expected_behavior: str
    path_conditions: list[str]
    description: str

    def to_pytest(self, index: int) -> str:
        """Convert test case to pytest function."""
        lines = [
            f"def test_{self.function_name}_path_{self.path_id}():",
            f'    """',
            f"    Path {self.path_id}: {self.description}",
            f"    Conditions: {', '.join(self.path_conditions) or 'No branches'}",
            f'    """',
        ]

        # Setup inputs
        if self.inputs:
            for var, value in self.inputs.items():
                lines.append(f"    {var} = {repr(value)}")
            lines.append("")

        # Call function with inputs
        args = ", ".join(f"{k}={k}" for k in self.inputs.keys())
        lines.append(f"    result = {self.function_name}({args})")
        lines.append("")

        # Assert (we can't know expected output without execution, but we verify no crash)
        lines.append("    # Path is reachable with these inputs")
        lines.append("    assert result is not None or result is None  # Executed successfully")

        return "\n".join(lines)


@dataclass
class GeneratedTestSuite:
    """A complete generated test suite.

    Holds the test cases for one function together with the source they were
    derived from, and renders them as pytest or unittest module text.
    """

    function_name: str  # function under test; also names the unittest class
    test_cases: list["TestCase"]  # rendered in list order
    source_code: str  # module source the function definition is copied from
    language: str = "python"  # stored only; the rendering properties do not read it
    framework: str = "pytest"  # stored only; the rendering properties do not read it

    @property
    def pytest_code(self) -> str:
        """Render the suite as a complete pytest module.

        The module embeds a copy of the function under test followed by one
        test function per test case.
        """
        lines = [
            '"""',
            f"Auto-generated tests for {self.function_name}",
            "",
            "Generated by Code Scalpel Test Generator using symbolic execution.",
            "Each test case represents a unique execution path through the code.",
            '"""',
            "",
            "import pytest",
            "",
            "# Original function under test",
            self._extract_function_code(),
            "",
            "",
            "# Generated test cases",
        ]

        for i, test_case in enumerate(self.test_cases):
            lines.extend((test_case.to_pytest(i), "", ""))

        return "\n".join(lines)

    @property
    def unittest_code(self) -> str:
        """Render the suite as a complete unittest module.

        The module embeds a copy of the function under test, a
        ``unittest.TestCase`` subclass with one method per test case, and a
        ``__main__`` guard that runs ``unittest.main()``.
        """
        lines = [
            '"""',
            f"Auto-generated tests for {self.function_name}",
            "",
            "Generated by Code Scalpel Test Generator using symbolic execution.",
            '"""',
            "",
            "import unittest",
            "",
            "# Original function under test",
            self._extract_function_code(),
            "",
            "",
            f"class Test{self._camel_case(self.function_name)}(unittest.TestCase):",
        ]

        methods = [
            self._to_unittest_method(test_case, i)
            for i, test_case in enumerate(self.test_cases)
        ]
        if not methods:
            # A class statement needs a body; without one the emitted module
            # would be a SyntaxError.
            lines.append("    pass")
        for method in methods:
            lines.extend((method, ""))

        lines.extend([
            "",
            'if __name__ == "__main__":',
            "    unittest.main()",
        ])

        return "\n".join(lines)

    def _extract_function_code(self) -> str:
        """Return the source of the target function, or a placeholder comment.

        Both ``def`` and ``async def`` definitions are matched by name; the
        first match in ``ast.walk`` order is used, so nested functions and
        methods are candidates too.
        """
        try:
            tree = ast.parse(self.source_code)
            for node in ast.walk(tree):
                if (
                    isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and node.name == self.function_name
                ):
                    return ast.unparse(node)
        except Exception:
            # Best effort: unparseable or unusable source falls back to the
            # placeholder below instead of aborting suite generation.
            pass
        return f"# Function {self.function_name} - include from source"

    def _to_unittest_method(self, test_case: "TestCase", index: int) -> str:
        """Render one test case as an indented unittest method.

        Args:
            test_case: Case supplying ``path_id``, ``description`` and ``inputs``.
            index: Position of the case in the suite. Currently unused.
        """
        # The docstring is a one-liner, so a description ending in '"' would
        # otherwise merge with the closing quotes and break the emitted code.
        doc = self._escape_docstring(f"Path {test_case.path_id}: {test_case.description}")
        lines = [
            f"    def test_path_{test_case.path_id}(self):",
            f'        """{doc}"""',
        ]
        lines.extend(
            f"        {var} = {value!r}" for var, value in test_case.inputs.items()
        )

        args = ", ".join(f"{name}={name}" for name in test_case.inputs)
        lines.append(f"        result = {self.function_name}({args})")
        lines.append("        # Verify path is reachable")
        lines.append("        self.assertTrue(True)  # Path executed successfully")

        return "\n".join(lines)

    @staticmethod
    def _escape_docstring(text: str) -> str:
        """Escape backslashes and double quotes for use inside a triple-quoted literal."""
        return text.replace("\\", "\\\\").replace('"', '\\"')

    @staticmethod
    def _camel_case(name: str) -> str:
        """Convert snake_case to CamelCase."""
        return "".join(word.capitalize() for word in name.split("_"))


class TestGenerator:
    """Generate unit tests from symbolic execution results.

    This generator takes the output of symbolic execution (paths with
    concrete input values) and produces executable test code.

    Supported frameworks:
    - pytest (default)
    - unittest

    Supported languages:
    - Python (full support)
    - JavaScript (planned)
    - Java (planned)
    """

    def __init__(self, framework: str = "pytest"):
        """Initialize the test generator.

        Args:
            framework: Test framework to generate for ("pytest" or "unittest")
        """
        if framework not in ("pytest", "unittest"):
            raise ValueError(f"Unsupported framework: {framework}")
        self.framework = framework

    def generate(
        self,
        code: str,
        function_name: str | None = None,
        symbolic_result: dict[str, Any] | None = None,
        language: str = "python",
    ) -> GeneratedTestSuite:
        """Generate test suite from code.

        Args:
            code: Source code to generate tests for
            function_name: Name of function to test (auto-detected if None)
            symbolic_result: Pre-computed symbolic execution result (optional)
            language: Source language ("python", "javascript", "java")

        Returns:
            GeneratedTestSuite with test cases for each execution path
        """
        # Auto-detect function name if not provided
        if function_name is None:
            function_name = self._detect_main_function(code, language)

        # Run symbolic execution if result not provided
        if symbolic_result is None:
            symbolic_result = self._run_symbolic_execution(code, language)

        # Extract test cases from paths
        test_cases = self._extract_test_cases(
            symbolic_result, function_name, code, language
        )

        return GeneratedTestSuite(
            function_name=function_name,
            test_cases=test_cases,
            source_code=code,
            language=language,
            framework=self.framework,
        )

    def generate_from_symbolic_result(
        self,
        symbolic_result: dict[str, Any],
        code: str,
        function_name: str,
        language: str = "python",
    ) -> GeneratedTestSuite:
        """Generate tests directly from a SymbolicResult dict.

        Args:
            symbolic_result: Dict with paths, symbolic_variables, constraints
            code: Original source code
            function_name: Name of function being tested
            language: Source language

        Returns:
            GeneratedTestSuite with test cases
        """
        test_cases = self._extract_test_cases(
            symbolic_result, function_name, code, language
        )

        return GeneratedTestSuite(
            function_name=function_name,
            test_cases=test_cases,
            source_code=code,
            language=language,
            framework=self.framework,
        )

    def _detect_main_function(self, code: str, language: str) -> str:
        """Detect the main function to test."""
        if language == "python":
            try:
                tree = ast.parse(code)
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        # Skip private/dunder methods
                        if not node.name.startswith("_"):
                            return node.name
            except SyntaxError:
                pass
        elif language == "javascript":
            # Simple regex for JS function detection
            match = re.search(r"function\s+(\w+)\s*\(", code)
            if match:
                return match.group(1)
            # Arrow function
            match = re.search(r"const\s+(\w+)\s*=\s*(?:async\s*)?\(", code)
            if match:
                return match.group(1)
        elif language == "java":
            # Java method detection
            match = re.search(r"(?:public|private|protected)?\s*\w+\s+(\w+)\s*\(", code)
            if match:
                return match.group(1)

        return "target_function"

    def _run_symbolic_execution(self, code: str, language: str) -> dict[str, Any]:
        """Run symbolic execution on the code."""
        try:
            from code_scalpel.symbolic import SymbolicExecutor

            executor = SymbolicExecutor(max_iterations=10)
            return executor.execute(code)
        except ImportError:
            # Fallback to basic path analysis
            return self._basic_path_analysis(code, language)

    def _basic_path_analysis(self, code: str, language: str) -> dict[str, Any]:
        """Basic path analysis fallback when symbolic execution unavailable."""
        paths = []
        symbolic_vars = []
        constraints = []

        if language == "python":
            try:
                tree = ast.parse(code)

                # Find function parameters (symbolic variables)
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        symbolic_vars = [arg.arg for arg in node.args.args]
                        break

                # Find branch conditions
                path_id = 0
                for node in ast.walk(tree):
                    if isinstance(node, ast.If):
                        condition = ast.unparse(node.test)
                        constraints.append(condition)

                        # True branch
                        paths.append({
                            "path_id": path_id,
                            "conditions": [condition],
                            "state": {var: self._generate_satisfying_value(condition, var, True) 
                                     for var in symbolic_vars},
                            "reachable": True,
                        })
                        path_id += 1

                        # False branch
                        paths.append({
                            "path_id": path_id,
                            "conditions": [f"not ({condition})"],
                            "state": {var: self._generate_satisfying_value(condition, var, False) 
                                     for var in symbolic_vars},
                            "reachable": True,
                        })
                        path_id += 1

                # If no branches, single path
                if not paths:
                    paths.append({
                        "path_id": 0,
                        "conditions": [],
                        "state": {var: 0 for var in symbolic_vars},
                        "reachable": True,
                    })

            except SyntaxError:
                pass

        return {
            "paths": paths,
            "symbolic_vars": symbolic_vars,
            "constraints": constraints,
        }

    def _generate_satisfying_value(
        self, condition: str, var: str, should_satisfy: bool
    ) -> Any:
        """Generate a value that satisfies (or doesn't satisfy) a condition."""
        # Parse common patterns
        # x > 0, x < 0, x == 0, x >= N, x <= N
        patterns = [
            (rf"{var}\s*>\s*(\d+)", lambda m: int(m.group(1)) + 1 if should_satisfy else int(m.group(1)) - 1),
            (rf"{var}\s*<\s*(\d+)", lambda m: int(m.group(1)) - 1 if should_satisfy else int(m.group(1)) + 1),
            (rf"{var}\s*>=\s*(\d+)", lambda m: int(m.group(1)) if should_satisfy else int(m.group(1)) - 1),
            (rf"{var}\s*<=\s*(\d+)", lambda m: int(m.group(1)) if should_satisfy else int(m.group(1)) + 1),
            (rf"{var}\s*==\s*(\d+)", lambda m: int(m.group(1)) if should_satisfy else int(m.group(1)) + 1),
            (rf"{var}\s*!=\s*(\d+)", lambda m: int(m.group(1)) + 1 if should_satisfy else int(m.group(1))),
        ]

        for pattern, value_fn in patterns:
            match = re.search(pattern, condition)
            if match:
                return value_fn(match)

        # Default values
        return 1 if should_satisfy else -1

    def _extract_test_cases(
        self,
        symbolic_result: dict[str, Any],
        function_name: str,
        code: str,
        language: str,
    ) -> list[TestCase]:
        """Extract test cases from symbolic execution result."""
        test_cases = []
        paths = symbolic_result.get("paths", [])
        symbolic_vars = symbolic_result.get("symbolic_vars", [])

        for path in paths:
            path_id = path.get("path_id", len(test_cases))
            conditions = path.get("conditions", [])
            state = path.get("state", {})
            reachable = path.get("reachable", True)

            if not reachable:
                continue

            # Extract reproduction inputs
            inputs = {}
            for var in symbolic_vars:
                if var in state:
                    value = state[var]
                    # Convert Z3 values to Python if needed
                    inputs[var] = self._to_python_value(value)

            # Generate description
            if conditions:
                description = f"Triggers path where {' and '.join(conditions[:2])}"
                if len(conditions) > 2:
                    description += f" (and {len(conditions) - 2} more conditions)"
            else:
                description = "Default/linear execution path"

            test_cases.append(TestCase(
                path_id=path_id,
                function_name=function_name,
                inputs=inputs,
                expected_behavior="Executes without error",
                path_conditions=conditions,
                description=description,
            ))

        # Ensure at least one test case
        if not test_cases:
            test_cases.append(TestCase(
                path_id=0,
                function_name=function_name,
                inputs={},
                expected_behavior="Executes without error",
                path_conditions=[],
                description="Basic execution test",
            ))

        return test_cases

    def _to_python_value(self, value: Any) -> Any:
        """Convert Z3 or other symbolic values to Python natives."""
        if hasattr(value, "as_long"):
            # Z3 IntNumRef
            return value.as_long()
        if hasattr(value, "as_string"):
            # Z3 StringVal
            return value.as_string()
        if hasattr(value, "is_true"):
            # Z3 BoolRef
            return bool(value)
        if isinstance(value, str):
            # Already a string
            try:
                return int(value)
            except ValueError:
                try:
                    return float(value)
                except ValueError:
                    return value
        return value
