#!/usr/bin/env python3
"""
Database Migration Tool

Analyzes SQL schema files, detects potential issues, suggests indexes,
and generates migration scripts with rollback support.

Usage:
    python database_migration_tool.py schema.sql --analyze
    python database_migration_tool.py old.sql --compare new.sql --output migrations/
    python database_migration_tool.py schema.sql --suggest-indexes
"""

import os
import sys
import json
import argparse
import re
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from datetime import datetime
from dataclasses import dataclass, field, asdict


@dataclass
class Column:
    """A single column parsed from a CREATE TABLE body.

    Only the properties the parser extracts are modelled; anything else in
    the column definition is not represented.
    """
    name: str  # identifier without surrounding quotes/backticks
    data_type: str  # type text, e.g. "VARCHAR(255)"
    nullable: bool = field(default=True)  # False when NOT NULL was declared
    default: Optional[str] = field(default=None)  # text of the DEFAULT expression
    primary_key: bool = field(default=False)  # inline PRIMARY KEY
    unique: bool = field(default=False)  # inline UNIQUE
    references: Optional[str] = field(default=None)  # "table(column)" FK target


@dataclass
class Index:
    """One CREATE [UNIQUE] INDEX statement."""
    name: str  # index name
    table: str  # name of the indexed table
    columns: List[str]  # column entries as written inside the parentheses
    unique: bool = field(default=False)  # True for CREATE UNIQUE INDEX
    partial: Optional[str] = field(default=None)  # WHERE predicate of a partial index


@dataclass
class Table:
    """A parsed CREATE TABLE statement plus the indexes that target it."""
    name: str  # table name without quotes/backticks
    # column name -> Column, in declaration order
    columns: Dict[str, Column] = field(default_factory=dict)
    # indexes collected from separate CREATE INDEX statements
    indexes: List[Index] = field(default_factory=list)
    # primary-key column names (inline or table-level PRIMARY KEY)
    primary_key: List[str] = field(default_factory=list)
    # dicts with 'column', 'ref_table' and 'ref_column' keys
    foreign_keys: List[Dict] = field(default_factory=list)


@dataclass
class Issue:
    """A single finding produced by the schema analyzer."""
    severity: str  # 'error', 'warning' or 'info'
    category: str  # 'index', 'naming', 'type', 'constraint' or 'convention'
    table: str  # name of the table the finding concerns
    message: str  # human-readable description of the finding
    suggestion: Optional[str] = None  # remedy text or SQL, when one is offered


class SQLParser:
    """Parse the subset of SQL DDL that this tool understands.

    Handles ``CREATE TABLE`` statements (column definitions with inline
    ``PRIMARY KEY`` / ``UNIQUE`` / ``NOT NULL`` / ``DEFAULT`` / ``REFERENCES``,
    plus table-level ``PRIMARY KEY`` / ``FOREIGN KEY`` clauses, optionally
    named with ``CONSTRAINT``) and ``CREATE [UNIQUE] INDEX`` statements.
    Matching is regex-based and heuristic: anything it does not recognise is
    skipped silently.
    """

    # CREATE TABLE name (body) [table options];
    # The body ends at the last ")" before the terminating ";". The [^;()]*
    # allowance accepts trailing table options (e.g. MySQL "ENGINE=InnoDB");
    # without it the lazy body would run on into the following statement(s).
    CREATE_TABLE_PATTERN = re.compile(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`]?(\w+)["`]?\s*\((.*?)\)[^;()]*;',
        re.IGNORECASE | re.DOTALL
    )

    # CREATE [UNIQUE] INDEX [CONCURRENTLY] [IF NOT EXISTS] name ON table
    #     [USING method] (columns) [WHERE predicate];
    # Groups: 1=UNIQUE, 2=index name, 3=table, 4=column list, 5=predicate.
    # The column list may contain one level of nested parentheses, e.g. lower(email).
    CREATE_INDEX_PATTERN = re.compile(
        r'CREATE\s+(UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?["`]?(\w+)["`]?\s+'
        r'ON\s+["`]?(\w+)["`]?\s*(?:USING\s+\w+\s*)?\(((?:[^()]|\([^()]*\))+)\)'
        r'(?:\s+WHERE\s+(.+?))?;',
        re.IGNORECASE | re.DOTALL
    )

    COLUMN_PATTERN = re.compile(
        r'["`]?(\w+)["`]?\s+'  # Column name
        r'(\w+(?:\s*\([^)]+\))?)'  # Data type
        r'([^,]*)',  # Constraints (stops at the first comma; see _add_column)
        re.IGNORECASE
    )

    # Table-level FOREIGN KEY (single column) REFERENCES table(column)
    FK_PATTERN = re.compile(
        r'FOREIGN\s+KEY\s*\(["`]?(\w+)["`]?\)\s+'
        r'REFERENCES\s+["`]?(\w+)["`]?\s*\(["`]?(\w+)["`]?\)',
        re.IGNORECASE
    )

    # Inline "REFERENCES table(column)" on a column definition.
    REFERENCES_PATTERN = re.compile(
        r'REFERENCES\s+["`]?(\w+)["`]?\s*\(["`]?(\w+)["`]?\)',
        re.IGNORECASE
    )

    # DEFAULT value: a single-quoted literal ('' is an escaped quote) or one bare token.
    DEFAULT_PATTERN = re.compile(r"\bDEFAULT\s+('(?:[^']|'')*'|\S+)", re.IGNORECASE)

    # Column list of a PRIMARY KEY clause.
    PRIMARY_KEY_PATTERN = re.compile(r'PRIMARY\s+KEY\s*\(([^)]+)\)', re.IGNORECASE)

    # Body parts that are table-level constraints rather than columns. The
    # trailing \s / \b keep a column such as "constraint_name" from matching.
    TABLE_CONSTRAINT_PREFIX = re.compile(
        r'(?:CONSTRAINT\s|PRIMARY\s+KEY\b|FOREIGN\s+KEY\b)', re.IGNORECASE
    )

    # Inline column constraints; word boundaries avoid false hits such as
    # "REFERENCES unique_things(id)" being read as UNIQUE.
    NOT_NULL_PATTERN = re.compile(r'\bNOT\s+NULL\b', re.IGNORECASE)
    INLINE_PRIMARY_KEY_PATTERN = re.compile(r'\bPRIMARY\s+KEY\b', re.IGNORECASE)
    UNIQUE_PATTERN = re.compile(r'\bUNIQUE\b', re.IGNORECASE)

    # A single-quoted string (kept as-is) or a -- / /* */ comment (removed).
    _COMMENT_OR_STRING = re.compile(r"'(?:[^']|'')*'|--[^\n]*|/\*.*?\*/", re.DOTALL)

    def parse(self, sql: str) -> Dict[str, Table]:
        """Parse *sql* and return table definitions keyed by table name.

        Comments are stripped first, so commented-out DDL is ignored.
        CREATE INDEX statements for tables not defined in *sql* are dropped.
        """
        sql = self._strip_comments(sql)
        tables: Dict[str, Table] = {}

        for match in self.CREATE_TABLE_PATTERN.finditer(sql):
            table_name = match.group(1)
            tables[table_name] = self._parse_table_body(table_name, match.group(2))

        for match in self.CREATE_INDEX_PATTERN.finditer(sql):
            table_name = match.group(3)
            if table_name not in tables:
                continue
            where_clause = match.group(5)
            tables[table_name].indexes.append(Index(
                name=match.group(2),
                table=table_name,
                columns=self._split_identifiers(match.group(4)),
                unique=bool(match.group(1)),
                partial=where_clause.strip() if where_clause else None,
            ))

        return tables

    def _strip_comments(self, sql: str) -> str:
        """Replace ``--`` and ``/* */`` comments with a space, keeping string literals.

        A comment inside a table body would otherwise corrupt the comma split,
        and an apostrophe inside a comment would break quote tracking.
        """
        return self._COMMENT_OR_STRING.sub(
            lambda m: m.group(0) if m.group(0).startswith("'") else ' ',
            sql,
        )

    def _parse_table_body(self, table_name: str, body: str) -> Table:
        """Parse the text between a CREATE TABLE's outer parentheses."""
        table = Table(name=table_name)

        for raw_part in self._split_by_comma(body):
            part = raw_part.strip()
            if not part:
                continue
            if self.TABLE_CONSTRAINT_PREFIX.match(part):
                self._apply_table_constraint(table, part)
            else:
                self._add_column(table, part)

        return table

    def _apply_table_constraint(self, table: Table, part: str) -> None:
        """Record a table-level PRIMARY KEY or FOREIGN KEY clause; ignore other constraints."""
        pk_match = self.PRIMARY_KEY_PATTERN.search(part)
        if pk_match:
            table.primary_key = self._split_identifiers(pk_match.group(1))
            return

        fk_match = self.FK_PATTERN.search(part)
        if fk_match:
            table.foreign_keys.append({
                'column': fk_match.group(1),
                'ref_table': fk_match.group(2),
                'ref_column': fk_match.group(3),
            })

    def _add_column(self, table: Table, part: str) -> None:
        """Parse one column definition and add it (plus any inline PK/FK) to *table*."""
        col_match = self.COLUMN_PATTERN.match(part)
        if not col_match:
            return

        col_name = col_match.group(1)
        # Everything after the data type. It is sliced from ``part`` instead of
        # using group 3, which stops at the first comma, even one inside a
        # quoted DEFAULT or a CHECK list. Its case is preserved so that DEFAULT
        # literals and REFERENCES targets keep their original spelling.
        tail = part[col_match.end(2):]

        column = Column(
            name=col_name,
            data_type=col_match.group(2).upper(),
            nullable=not self.NOT_NULL_PATTERN.search(tail),
            primary_key=bool(self.INLINE_PRIMARY_KEY_PATTERN.search(tail)),
            unique=bool(self.UNIQUE_PATTERN.search(tail)),
        )

        default_match = self.DEFAULT_PATTERN.search(tail)
        if default_match:
            column.default = default_match.group(1)

        ref_match = self.REFERENCES_PATTERN.search(tail)
        if ref_match:
            ref_table, ref_column = ref_match.group(1), ref_match.group(2)
            column.references = f"{ref_table}({ref_column})"
            table.foreign_keys.append({
                'column': col_name,
                'ref_table': ref_table,
                'ref_column': ref_column,
            })

        if column.primary_key and col_name not in table.primary_key:
            table.primary_key.append(col_name)

        table.columns[col_name] = column

    def _split_identifiers(self, text: str) -> List[str]:
        """Split a comma-separated column list, stripping whitespace, quotes and backticks."""
        return [item.strip().strip('"`') for item in self._split_by_comma(text)]

    def _split_by_comma(self, s: str) -> List[str]:
        """Split *s* on commas that are outside parentheses and single-quoted strings.

        A trailing empty segment is not returned. The doubled-quote escape ('')
        toggles the quote state twice, so it needs no special handling.
        """
        parts: List[str] = []
        current: List[str] = []
        depth = 0
        in_quote = False

        for char in s:
            if char == "'":
                in_quote = not in_quote
            elif not in_quote:
                if char == '(':
                    depth += 1
                elif char == ')':
                    depth -= 1
                elif char == ',' and depth == 0:
                    parts.append(''.join(current))
                    current = []
                    continue
            current.append(char)

        if current:
            parts.append(''.join(current))

        return parts


class SchemaAnalyzer:
    """Analyze parsed tables for likely issues and optimizations.

    Every check is a name/type heuristic over the parsed ``Table`` objects.
    Findings are collected as ``Issue`` records with severity 'error',
    'warning' or 'info'.
    """

    # Column-name suffixes that suggest an (undeclared) foreign key
    FK_COLUMN_PATTERNS = ['_id', 'Id', '_ID']

    # Substrings of column names that are typically used in WHERE filters
    FILTER_COLUMN_PATTERNS = ['status', 'state', 'type', 'category', 'active', 'enabled', 'deleted']

    # Column names typically used for sorting/ordering.
    # NOTE: not referenced by any check in this class.
    SORT_COLUMN_PATTERNS = ['created_at', 'updated_at', 'date', 'timestamp', 'order', 'position']

    # Type-name fragments that denote binary floating point
    FLOAT_TYPE_MARKERS = ('FLOAT', 'DOUBLE', 'REAL')

    def __init__(self, tables: Dict[str, Table]):
        """Store the tables to analyze, keyed by table name."""
        self.tables = tables
        self.issues: List[Issue] = []

    def analyze(self) -> List[Issue]:
        """Run every check over every table and return the collected issues.

        Issues from a previous call are discarded first.
        """
        self.issues = []

        for table in self.tables.values():
            self._check_naming_conventions(table)
            self._check_primary_key(table)
            self._check_foreign_key_indexes(table)
            self._check_common_filter_columns(table)
            self._check_timestamp_columns(table)
            self._check_data_types(table)

        return self.issues

    @staticmethod
    def _indexed_columns(table: Table) -> Set[str]:
        """Return the bare column names covered by an index or the primary key.

        Index entries may carry quotes or ordering suffixes such as
        ``"created_at" DESC``. Only the leading identifier is kept, so that it
        compares equal to the column name.
        """
        indexed = set(table.primary_key)
        for index in table.indexes:
            for entry in index.columns:
                tokens = entry.split()
                if tokens:
                    indexed.add(tokens[0].strip('"`'))
        return indexed

    def _check_naming_conventions(self, table: Table):
        """Flag non-lowercase table/column names and (heuristically) singular table names."""
        if table.name != table.name.lower():
            self.issues.append(Issue(
                severity='warning',
                category='naming',
                table=table.name,
                message=f"Table name '{table.name}' should be lowercase",
                suggestion=f"Rename to '{table.name.lower()}'"
            ))

        # Crude plural check: the name should end in "s" (case-insensitive).
        if not table.name.lower().endswith('s'):
            self.issues.append(Issue(
                severity='info',
                category='naming',
                table=table.name,
                message=f"Table name '{table.name}' should typically be plural",
            ))

        for col_name in table.columns:
            if col_name != col_name.lower():
                self.issues.append(Issue(
                    severity='warning',
                    category='naming',
                    table=table.name,
                    message=f"Column '{col_name}' should use snake_case",
                    suggestion=f"Rename to '{self._to_snake_case(col_name)}'"
                ))

    def _check_primary_key(self, table: Table):
        """Report an error when the table declares no primary key."""
        if not table.primary_key:
            self.issues.append(Issue(
                severity='error',
                category='constraint',
                table=table.name,
                message=f"Table '{table.name}' has no primary key",
                suggestion="Add a primary key column (e.g., 'id SERIAL PRIMARY KEY')"
            ))

    def _check_foreign_key_indexes(self, table: Table):
        """Warn about unindexed declared FKs; note unindexed columns that look like FKs.

        Each column is reported at most once, even when its FK is declared both
        inline (REFERENCES) and as a table-level FOREIGN KEY clause.
        """
        indexed_columns = self._indexed_columns(table)
        # dict.fromkeys keeps first-seen order while removing duplicates
        declared_fk_columns = list(dict.fromkeys(fk['column'] for fk in table.foreign_keys))

        for fk_col in declared_fk_columns:
            if fk_col not in indexed_columns:
                self.issues.append(Issue(
                    severity='warning',
                    category='index',
                    table=table.name,
                    message=f"Foreign key column '{fk_col}' is not indexed",
                    suggestion=f"CREATE INDEX idx_{table.name}_{fk_col} ON {table.name}({fk_col});"
                ))

        fk_suffixes = tuple(self.FK_COLUMN_PATTERNS)
        for col_name in table.columns:
            if (col_name.endswith(fk_suffixes)
                    and col_name not in indexed_columns
                    and col_name not in declared_fk_columns):
                self.issues.append(Issue(
                    severity='info',
                    category='index',
                    table=table.name,
                    message=f"Column '{col_name}' looks like a foreign key but has no index",
                    suggestion=f"CREATE INDEX idx_{table.name}_{col_name} ON {table.name}({col_name});"
                ))

    def _check_common_filter_columns(self, table: Table):
        """Suggest indexes for unindexed columns whose names suggest WHERE-clause filtering."""
        indexed_columns = self._indexed_columns(table)

        for col_name in table.columns:
            col_lower = col_name.lower()
            if (any(pattern in col_lower for pattern in self.FILTER_COLUMN_PATTERNS)
                    and col_name not in indexed_columns):
                self.issues.append(Issue(
                    severity='info',
                    category='index',
                    table=table.name,
                    message=f"Column '{col_name}' is commonly used for filtering but has no index",
                    suggestion=f"CREATE INDEX idx_{table.name}_{col_name} ON {table.name}({col_name});"
                ))

    def _check_timestamp_columns(self, table: Table):
        """Suggest adding 'created_at' / 'updated_at' columns when they are absent.

        Column names are matched exactly (case-sensitive).
        """
        if 'created_at' not in table.columns:
            self.issues.append(Issue(
                severity='info',
                category='convention',
                table=table.name,
                message=f"Table '{table.name}' has no 'created_at' column",
                suggestion="Consider adding: created_at TIMESTAMP DEFAULT NOW()"
            ))

        if 'updated_at' not in table.columns:
            self.issues.append(Issue(
                severity='info',
                category='convention',
                table=table.name,
                message=f"Table '{table.name}' has no 'updated_at' column",
                suggestion="Consider adding: updated_at TIMESTAMP DEFAULT NOW()"
            ))

    def _check_data_types(self, table: Table):
        """Flag VARCHAR without length, floating-point money columns, and TEXT email/url columns."""
        for col_name, col in table.columns.items():
            dtype = col.data_type.upper()
            name_lower = col_name.lower()

            if 'VARCHAR' in dtype and '(' not in dtype:
                self.issues.append(Issue(
                    severity='warning',
                    category='type',
                    table=table.name,
                    message=f"Column '{col_name}' uses VARCHAR without length",
                    suggestion="Specify a maximum length, e.g., VARCHAR(255)"
                ))

            # Binary floating point cannot represent most decimal fractions exactly.
            if (any(marker in dtype for marker in self.FLOAT_TYPE_MARKERS)
                    and any(word in name_lower for word in ('price', 'amount', 'total'))):
                self.issues.append(Issue(
                    severity='warning',
                    category='type',
                    table=table.name,
                    message=f"Column '{col_name}' uses floating point for monetary value",
                    suggestion="Use DECIMAL or NUMERIC for monetary values"
                ))

            if dtype == 'TEXT' and ('email' in name_lower or 'url' in name_lower):
                self.issues.append(Issue(
                    severity='info',
                    category='type',
                    table=table.name,
                    message=f"Column '{col_name}' uses TEXT but might benefit from VARCHAR",
                    suggestion=f"Consider VARCHAR(255) for {col_name}"
                ))

    def _to_snake_case(self, name: str) -> str:
        """Convert a CamelCase/mixedCase name to snake_case."""
        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


class MigrationGenerator:
    """Generate UP and DOWN migration scripts from two parsed schemas.

    Every UP statement is paired with exactly one inverse. The DOWN script
    is those inverses in reverse order, so it undoes the UP script step by
    step: for example, a new index is dropped before the column it covers.
    Output follows schema-file order, so repeated runs produce identical SQL.
    """

    def __init__(self, old_tables: Dict[str, Table], new_tables: Dict[str, Table],
                 concurrent_indexes: bool = False):
        """Store the two schemas to diff.

        Args:
            old_tables: Tables of the current schema, keyed by table name.
            new_tables: Tables of the target schema, keyed by table name.
            concurrent_indexes: Emit ``CREATE INDEX CONCURRENTLY``. Off by
                default, because PostgreSQL rejects CONCURRENTLY inside a
                transaction block and migration scripts are usually wrapped
                in BEGIN/COMMIT.
        """
        self.old_tables = old_tables
        self.new_tables = new_tables
        self.concurrent_indexes = concurrent_indexes

    def generate(self) -> Tuple[str, str]:
        """Return ``(up_sql, down_sql)``; either is '-- No changes' when empty."""
        steps: List[Tuple[str, str]] = []

        # New tables: create them with their indexes; undo by dropping.
        for table_name, table in self.new_tables.items():
            if table_name not in self.old_tables:
                steps.append((
                    self._generate_create_table_with_indexes(table),
                    f"DROP TABLE IF EXISTS {table_name};",
                ))

        # Removed tables: drop them; undo by recreating them with their indexes.
        for table_name, table in self.old_tables.items():
            if table_name not in self.new_tables:
                steps.append((
                    f"DROP TABLE IF EXISTS {table_name};",
                    self._generate_create_table_with_indexes(table),
                ))

        # Tables present in both schemas, in new-schema order. Set order is
        # avoided because it varies between runs under string hash randomization.
        for table_name, new_table in self.new_tables.items():
            old_table = self.old_tables.get(table_name)
            if old_table is not None:
                steps.extend(self._diff_steps(old_table, new_table))

        up_statements = [up for up, _ in steps]
        down_statements = [down for _, down in reversed(steps)]

        up_sql = '\n\n'.join(up_statements) if up_statements else '-- No changes'
        down_sql = '\n\n'.join(down_statements) if down_statements else '-- No changes'

        return up_sql, down_sql

    @staticmethod
    def _column_definition(col_name: str, col) -> str:
        """Return ``name TYPE [NOT NULL] [DEFAULT expr]`` for a column."""
        parts = [col_name, col.data_type]
        if not col.nullable:
            parts.append("NOT NULL")
        if col.default:
            parts.append(f"DEFAULT {col.default}")
        return ' '.join(parts)

    def _add_column_definition(self, col_name: str, col) -> str:
        """Column definition for ALTER TABLE ... ADD COLUMN, keeping UNIQUE."""
        definition = self._column_definition(col_name, col)
        return definition + " UNIQUE" if col.unique else definition

    def _create_index_sql(self, idx) -> str:
        """Return the CREATE [UNIQUE] INDEX statement that recreates *idx*."""
        unique = "UNIQUE " if idx.unique else ""
        concurrently = "CONCURRENTLY " if self.concurrent_indexes else ""
        cols = ', '.join(idx.columns)
        where = f" WHERE {idx.partial}" if idx.partial else ""
        return f"CREATE {unique}INDEX {concurrently}{idx.name} ON {idx.table}({cols}){where};"

    def _generate_create_table(self, table: Table) -> str:
        """Return a CREATE TABLE statement for *table* (indexes not included).

        A single-column primary key is emitted inline on its column. A
        composite key, or one whose column was not parsed, is emitted as a
        table-level clause. A foreign key declared both inline and at table
        level is emitted once.
        """
        pk = table.primary_key
        inline_pk_column = pk[0] if len(pk) == 1 and pk[0] in table.columns else None

        col_defs = []
        for col_name, col in table.columns.items():
            col_def = "  " + self._column_definition(col_name, col)
            if col_name == inline_pk_column:
                col_def += " PRIMARY KEY"
            if col.unique:
                col_def += " UNIQUE"
            col_defs.append(col_def)

        if pk and inline_pk_column is None:
            col_defs.append(f"  PRIMARY KEY ({', '.join(pk)})")

        seen_fks = set()
        for fk in table.foreign_keys:
            fk_key = (fk['column'], fk['ref_table'], fk['ref_column'])
            if fk_key in seen_fks:
                continue
            seen_fks.add(fk_key)
            col_defs.append(
                f"  FOREIGN KEY ({fk['column']}) REFERENCES {fk['ref_table']}({fk['ref_column']})"
            )

        return '\n'.join([f"CREATE TABLE {table.name} (", ',\n'.join(col_defs), ");"])

    def _generate_create_table_with_indexes(self, table: Table) -> str:
        """CREATE TABLE followed by a CREATE INDEX for each of the table's indexes."""
        statements = [self._generate_create_table(table)]
        statements.extend(self._create_index_sql(idx) for idx in table.indexes)
        return '\n\n'.join(statements)

    def _diff_steps(self, old: Table, new: Table) -> List[Tuple[str, str]]:
        """Return ordered ``(up, down)`` statement pairs that turn *old* into *new*.

        UP order: drop removed/redefined indexes, drop removed columns, add
        new columns, alter shared columns, create new/redefined indexes.
        Reversing the DOWN halves therefore recreates columns before indexes.
        """
        steps: List[Tuple[str, str]] = []
        old_indexes = {idx.name: idx for idx in old.indexes}
        new_indexes = {idx.name: idx for idx in new.indexes}

        # Index dataclasses compare by value, so a same-named index with
        # different columns, uniqueness or predicate counts as redefined.
        for idx in old.indexes:
            if new_indexes.get(idx.name) != idx:
                steps.append((f"DROP INDEX IF EXISTS {idx.name};", self._create_index_sql(idx)))

        for col_name, col in old.columns.items():
            if col_name not in new.columns:
                steps.append((
                    f"ALTER TABLE {old.name} DROP COLUMN IF EXISTS {col_name};",
                    f"ALTER TABLE {old.name} ADD COLUMN {self._add_column_definition(col_name, col)};",
                ))

        for col_name, col in new.columns.items():
            if col_name not in old.columns:
                steps.append((
                    f"ALTER TABLE {new.name} ADD COLUMN {self._add_column_definition(col_name, col)};",
                    f"ALTER TABLE {new.name} DROP COLUMN IF EXISTS {col_name};",
                ))

        for col_name, new_col in new.columns.items():
            old_col = old.columns.get(col_name)
            if old_col is not None:
                steps.extend(self._alter_column_steps(new.name, col_name, old_col, new_col))

        for idx in new.indexes:
            if old_indexes.get(idx.name) != idx:
                steps.append((self._create_index_sql(idx), f"DROP INDEX IF EXISTS {idx.name};"))

        return steps

    def _alter_column_steps(self, table_name: str, col_name: str,
                            old_col, new_col) -> List[Tuple[str, str]]:
        """``(up, down)`` pairs for type, nullability and default changes of one column."""
        prefix = f"ALTER TABLE {table_name} ALTER COLUMN {col_name}"
        steps: List[Tuple[str, str]] = []

        if old_col.data_type != new_col.data_type:
            steps.append((
                f"{prefix} TYPE {new_col.data_type};",
                f"{prefix} TYPE {old_col.data_type};",
            ))

        if old_col.nullable != new_col.nullable:
            set_not_null = f"{prefix} SET NOT NULL;"
            drop_not_null = f"{prefix} DROP NOT NULL;"
            if new_col.nullable:
                steps.append((drop_not_null, set_not_null))
            else:
                steps.append((set_not_null, drop_not_null))

        if old_col.default != new_col.default:
            steps.append((
                self._set_default_sql(prefix, new_col.default),
                self._set_default_sql(prefix, old_col.default),
            ))

        return steps

    @staticmethod
    def _set_default_sql(prefix: str, default: Optional[str]) -> str:
        """SET DEFAULT *default*, or DROP DEFAULT when there is none."""
        if default:
            return f"{prefix} SET DEFAULT {default};"
        return f"{prefix} DROP DEFAULT;"

    def _compare_tables(self, old: Table, new: Table) -> Tuple[List[str], List[str]]:
        """Return ``(up, down)`` statement lists for one table; *down* is in rollback order."""
        steps = self._diff_steps(old, new)
        return [up for up, _ in steps], [down for _, down in reversed(steps)]


class DatabaseMigrationTool:
    """Front end that parses a schema file and runs one of the tool's modes.

    Modes are 'analyze' (report schema issues), 'compare' (diff against a
    second schema and produce UP/DOWN migration SQL) and 'suggest-indexes'.
    Each mode prints a human-readable report and returns a JSON-serializable
    dict.
    """

    def __init__(self, schema_path: str, compare_path: Optional[str] = None,
                 output_dir: Optional[str] = None, verbose: bool = False):
        """Store options; nothing is read from disk until ``run``.

        Args:
            schema_path: SQL schema to parse (the "old" schema in compare mode).
            compare_path: Second schema file; required for 'compare' mode.
            output_dir: Directory for generated .sql files, created on demand.
                Without it, compare mode prints the SQL instead.
            verbose: Also print the parsed-table count and 'info' findings.
        """
        self.schema_path = Path(schema_path)
        self.compare_path = Path(compare_path) if compare_path else None
        self.output_dir = Path(output_dir) if output_dir else None
        self.verbose = verbose
        self.parser = SQLParser()

    def run(self, mode: str = 'analyze') -> Dict:
        """Parse the schema file and dispatch to *mode*.

        Raises:
            FileNotFoundError: the schema file does not exist.
            ValueError: *mode* is not 'analyze', 'compare' or 'suggest-indexes'.
        """
        print("Database Migration Tool")
        print(f"Schema: {self.schema_path}")
        print("-" * 50)

        if not self.schema_path.exists():
            raise FileNotFoundError(f"Schema file not found: {self.schema_path}")

        # Read as UTF-8 explicitly instead of the platform's locale encoding.
        tables = self.parser.parse(self.schema_path.read_text(encoding='utf-8'))

        if self.verbose:
            print(f"Parsed {len(tables)} tables")

        if mode == 'analyze':
            return self._analyze(tables)
        if mode == 'compare':
            return self._compare(tables)
        if mode == 'suggest-indexes':
            return self._suggest_indexes(tables)
        raise ValueError(f"Unknown mode: {mode}")

    @staticmethod
    def _print_issues(heading: str, issues: List[Issue], suggestion_prefix: str) -> None:
        """Print *heading* and one entry per issue, with its suggestion indented below."""
        print(f"\n{heading}:")
        for issue in issues:
            print(f"  [{issue.table}] {issue.message}")
            if issue.suggestion:
                print(f"    {suggestion_prefix}{issue.suggestion}")

    def _analyze(self, tables: Dict[str, Table]) -> Dict:
        """Run the analyzer, print findings grouped by severity, and return counts and details.

        'info' findings are listed only in verbose mode, but they are always
        counted and always included in ``issues_detail``.
        """
        issues = SchemaAnalyzer(tables).analyze()

        errors = [i for i in issues if i.severity == 'error']
        warnings = [i for i in issues if i.severity == 'warning']
        infos = [i for i in issues if i.severity == 'info']

        print("\nAnalysis Results:")
        print(f"  Tables: {len(tables)}")
        print(f"  Errors: {len(errors)}")
        print(f"  Warnings: {len(warnings)}")
        print(f"  Suggestions: {len(infos)}")

        if errors:
            self._print_issues("ERRORS", errors, "Suggestion: ")
        if warnings:
            self._print_issues("WARNINGS", warnings, "Suggestion: ")
        if self.verbose and infos:
            self._print_issues("SUGGESTIONS", infos, "")

        return {
            'status': 'success',
            'tables_count': len(tables),
            'issues': {
                'errors': len(errors),
                'warnings': len(warnings),
                'suggestions': len(infos),
            },
            'issues_detail': [asdict(i) for i in issues],
        }

    @staticmethod
    def _wrap_in_transaction(sql: str) -> str:
        """Wrap *sql* in BEGIN/COMMIT, unless it uses CONCURRENTLY index operations.

        PostgreSQL refuses CREATE/DROP INDEX CONCURRENTLY inside a transaction
        block, so such scripts are emitted unwrapped, with an explanatory comment.
        """
        if re.search(r'\bCONCURRENTLY\b', sql, re.IGNORECASE):
            return ("-- Not wrapped in a transaction: contains CONCURRENTLY index operations\n\n"
                    f"{sql}\n")
        return f"BEGIN;\n\n{sql}\n\nCOMMIT;\n"

    def _compare(self, old_tables: Dict[str, Table]) -> Dict:
        """Diff *old_tables* against the compare file and emit UP/DOWN migrations.

        With an output directory, two timestamped files are written
        (``*_migration.sql`` and ``*_migration_rollback.sql``). Otherwise the
        SQL is printed.

        Raises:
            ValueError: no compare path was configured.
            FileNotFoundError: the compare file does not exist.
        """
        if not self.compare_path:
            raise ValueError("Compare path required for compare mode")

        if not self.compare_path.exists():
            raise FileNotFoundError(f"Compare file not found: {self.compare_path}")

        new_tables = self.parser.parse(self.compare_path.read_text(encoding='utf-8'))

        up_sql, down_sql = MigrationGenerator(old_tables, new_tables).generate()

        print("\nComparing schemas:")
        print(f"  Old: {self.schema_path}")
        print(f"  New: {self.compare_path}")

        # Sorted so that the report and the JSON output are deterministic.
        added_tables = sorted(set(new_tables) - set(old_tables))
        removed_tables = sorted(set(old_tables) - set(new_tables))
        # Table dataclasses compare by value (columns, indexes, keys).
        modified_tables = [
            name for name, table in new_tables.items()
            if name in old_tables and old_tables[name] != table
        ]

        print("\nChanges detected:")
        print(f"  Added tables: {len(added_tables)}")
        print(f"  Removed tables: {len(removed_tables)}")
        print(f"  Modified tables: {len(modified_tables)}")

        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            now = datetime.now()
            timestamp = now.strftime('%Y%m%d_%H%M%S')
            generated = now.isoformat()

            up_file = self.output_dir / f"{timestamp}_migration.sql"
            down_file = self.output_dir / f"{timestamp}_migration_rollback.sql"

            up_file.write_text(
                f"-- Migration: {self.schema_path} -> {self.compare_path}\n"
                f"-- Generated: {generated}\n\n"
                + self._wrap_in_transaction(up_sql),
                encoding='utf-8',
            )
            down_file.write_text(
                f"-- Rollback for migration {timestamp}\n"
                f"-- Generated: {generated}\n\n"
                + self._wrap_in_transaction(down_sql),
                encoding='utf-8',
            )

            print("\nGenerated files:")
            print(f"  Migration: {up_file}")
            print(f"  Rollback: {down_file}")
        else:
            print("\n--- UP MIGRATION ---")
            print(up_sql)
            print("\n--- DOWN MIGRATION ---")
            print(down_sql)

        return {
            'status': 'success',
            'added_tables': added_tables,
            'removed_tables': removed_tables,
            'modified_tables': modified_tables,
            'up_sql': up_sql,
            'down_sql': down_sql,
        }

    @staticmethod
    def _indexed_column_names(table: Table) -> Set[str]:
        """Bare names of columns covered by an index or the primary key.

        Index entries such as ``created_at DESC`` or ``"x" ASC`` are reduced
        to their leading identifier, so indexes created from earlier
        suggestions are recognised.
        """
        indexed = set(table.primary_key)
        for idx in table.indexes:
            for entry in idx.columns:
                tokens = entry.split()
                if tokens:
                    indexed.add(tokens[0].strip('"`'))
        return indexed

    @staticmethod
    def _index_suggestion(table_name: str, column: str, reason: str,
                          descending: bool = False) -> Dict:
        """Build one suggestion record with its CREATE INDEX statement."""
        key_expr = f"{column} DESC" if descending else column
        return {
            'table': table_name,
            'column': column,
            'reason': reason,
            'sql': f"CREATE INDEX idx_{table_name}_{column} ON {table_name}({key_expr});",
        }

    def _suggest_indexes(self, tables: Dict[str, Table]) -> Dict:
        """Suggest single-column indexes for unindexed FK, filter and timestamp columns.

        Each column is suggested at most once per table. A declared FK whose
        name also ends in ``_id`` used to produce two identically-named
        CREATE INDEX statements, and the second one fails when the script is
        applied.
        """
        suggestions = []

        for table_name, table in tables.items():
            # Columns already indexed, plus those suggested so far for this table.
            covered = self._indexed_column_names(table)

            for fk in table.foreign_keys:
                fk_col = fk['column']
                if fk_col not in covered:
                    suggestions.append(self._index_suggestion(table_name, fk_col, 'Foreign key'))
                    covered.add(fk_col)

            for col_name in table.columns:
                if col_name in covered:
                    continue

                col_lower = col_name.lower()
                if col_name.endswith('_id'):
                    suggestion = self._index_suggestion(table_name, col_name, 'Likely foreign key')
                elif col_lower in ('status', 'state', 'type', 'category'):
                    suggestion = self._index_suggestion(table_name, col_name, 'Common filter column')
                elif col_lower in ('created_at', 'updated_at'):
                    suggestion = self._index_suggestion(
                        table_name, col_name, 'Common sort column', descending=True)
                else:
                    continue

                suggestions.append(suggestion)
                covered.add(col_name)

        print(f"\nIndex Suggestions ({len(suggestions)} found):")
        for s in suggestions:
            print(f"\n  [{s['table']}.{s['column']}] {s['reason']}")
            print(f"    {s['sql']}")

        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            now = datetime.now()
            output_file = self.output_dir / f"{now.strftime('%Y%m%d_%H%M%S')}_add_indexes.sql"

            lines = [
                "-- Suggested indexes",
                f"-- Generated: {now.isoformat()}",
                "",
            ]
            for s in suggestions:
                lines.append(f"-- {s['table']}.{s['column']}: {s['reason']}")
                lines.append(s['sql'])
                lines.append("")

            output_file.write_text('\n'.join(lines), encoding='utf-8')
            print(f"\nWritten to: {output_file}")

        return {
            'status': 'success',
            'suggestions_count': len(suggestions),
            'suggestions': suggestions,
        }


def main():
    """CLI entry point: parse arguments, run the selected mode, and report failures.

    The mode is 'compare' if --compare is given, else 'suggest-indexes' if
    --suggest-indexes is given, else 'analyze'. Any exception is printed to
    stderr and the process exits with status 1. With --json, the result dict
    is printed to stdout as JSON and the tool's human-readable report is sent
    to stderr, so stdout stays machine-parseable.
    """
    import contextlib  # only needed here, to redirect the report for --json

    parser = argparse.ArgumentParser(
        description='Analyze SQL schemas and generate migrations',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  %(prog)s schema.sql --analyze
  %(prog)s old.sql --compare new.sql --output migrations/
  %(prog)s schema.sql --suggest-indexes --output migrations/
        '''
    )

    parser.add_argument(
        'schema',
        help='Path to SQL schema file'
    )
    parser.add_argument(
        '--analyze',
        action='store_true',
        help='Analyze schema for issues and optimizations'
    )
    parser.add_argument(
        '--compare',
        metavar='FILE',
        help='Compare with another schema file and generate migration'
    )
    parser.add_argument(
        '--suggest-indexes',
        action='store_true',
        help='Generate index suggestions'
    )
    parser.add_argument(
        '--output', '-o',
        help='Output directory for generated files'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose output'
    )
    parser.add_argument(
        '--json',
        action='store_true',
        help='Output results as JSON'
    )

    args = parser.parse_args()

    if args.compare:
        mode = 'compare'
    elif args.suggest_indexes:
        mode = 'suggest-indexes'
    else:
        mode = 'analyze'

    try:
        tool = DatabaseMigrationTool(
            schema_path=args.schema,
            compare_path=args.compare,
            output_dir=args.output,
            verbose=args.verbose,
        )

        # The tool prints its report with print(); for --json, send that
        # report to stderr so the only thing on stdout is the JSON document.
        report_stream = sys.stderr if args.json else sys.stdout
        with contextlib.redirect_stdout(report_stream):
            results = tool.run(mode=mode)

        if args.json:
            print(json.dumps(results, indent=2))

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
