"""
MySQL MCP Tools for CQA Test Application
FastMCP-based MySQL integration for database query operations
"""

import os
import logging
from typing import List, Dict, Any
import mysql.connector
from mysql.connector import Error
from fastmcp import FastMCP

logger = logging.getLogger(__name__)

# MySQL Configuration from environment variables (read once, at import time)
MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
# `or` instead of a getenv default so that a variable that is set but empty
# (e.g. a bare `MYSQL_PORT=` line in an env file) also falls back to 3306.
_MYSQL_PORT_RAW = os.getenv("MYSQL_PORT") or "3306"
try:
    MYSQL_PORT = int(_MYSQL_PORT_RAW)
except ValueError as exc:
    # Fail at import with a message that names the offending variable instead
    # of int()'s bare "invalid literal for int() with base 10".
    raise ValueError(f"MYSQL_PORT must be an integer, got {_MYSQL_PORT_RAW!r}") from exc
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
# Empty string means "no default database"; tools may supply one per call.
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "")

# Create FastMCP instance for MySQL tools; every @mysql_mcp.tool below registers on it.
mysql_mcp = FastMCP("CQA MySQL MCP Server")


def get_mysql_connection():
    """Create and return a MySQL database connection

    Connects to MYSQL_HOST:MYSQL_PORT as MYSQL_USER. MYSQL_DATABASE becomes the
    default database when it is non-empty; otherwise no default database is
    selected.

    Returns:
        An open connection. The caller is responsible for closing it.

    Raises:
        mysql.connector.Error: If connecting fails, or if the new connection
            does not report itself as connected.
    """
    try:
        connection = mysql.connector.connect(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            database=MYSQL_DATABASE if MYSQL_DATABASE else None,
        )
    except Error as e:
        logger.error(f"Error connecting to MySQL: {e}")
        raise

    if not connection.is_connected():
        # Raise instead of falling through to an implicit `return None`. That
        # way callers' `except Error` handlers see the failure, rather than an
        # AttributeError on their first `connection.cursor()` call.
        try:
            connection.close()
        except Error:
            pass  # best effort: the connection is already unusable
        error = Error(f"Connection to MySQL at {MYSQL_HOST}:{MYSQL_PORT} is not active")
        logger.error(f"Error connecting to MySQL: {error}")
        raise error

    logger.info(f"Successfully connected to MySQL at {MYSQL_HOST}:{MYSQL_PORT}")
    return connection


@mysql_mcp.tool
def mysql_execute_query(query: str, database: str = None) -> Dict[str, Any]:
    """Execute a read-only query (SELECT, SHOW or DESCRIBE) and return results

    The query is screened before anything is sent to the server:
    - It must start with SELECT, SHOW or DESCRIBE.
    - It must be a single statement; trailing ``;`` terminators are allowed.
    - It must not contain ``INTO OUTFILE`` / ``INTO DUMPFILE``.

    The screen is purely textual, so it also rejects those tokens inside
    string literals. It is a guard rail, not a security boundary: connecting
    with a MySQL account that has only read privileges is the reliable
    protection.

    Args:
        query: SQL SELECT, SHOW or DESCRIBE query to execute
        database: Optional database name to use (overrides env variable)

    Returns:
        Dictionary containing:
        - columns: List of column names
        - rows: List of row data (each row is a list)
        - row_count: Number of rows returned
        - error: Present only when the query was rejected or MySQL raised an error
    """
    import re

    def _failure(message: str) -> Dict[str, Any]:
        # Same keys as a successful result, so callers can always read them.
        return {"error": message, "columns": [], "rows": [], "row_count": 0}

    connection = None
    cursor = None

    try:
        # Validate query is a read-only statement for safety
        query_upper = query.strip().upper()
        if not query_upper.startswith(("SELECT", "SHOW", "DESCRIBE")):
            return _failure("Only SELECT, SHOW, and DESCRIBE queries are allowed for safety")
        # The prefix check alone can be bypassed with stacked statements such
        # as "SELECT 1; DROP TABLE t". If the connection allows multiple
        # statements, the server may run all of them even though only the
        # first result is read here. Only trailing ';' terminators are accepted.
        if ";" in query_upper.rstrip("; \t\r\n"):
            return _failure("Multiple statements are not allowed; submit one query at a time")
        # SELECT ... INTO OUTFILE/DUMPFILE writes a file on the database server.
        if re.search(r"\bINTO\s+(?:OUTFILE|DUMPFILE)\b", query_upper):
            return _failure("SELECT ... INTO OUTFILE/DUMPFILE is not allowed")

        connection = mysql.connector.connect(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            # Same rule as get_mysql_connection(): an empty name means no default database.
            database=database or MYSQL_DATABASE or None,
        )

        cursor = connection.cursor()
        cursor.execute(query)

        # Statements without a result set have no description, hence no columns.
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        # Convert row tuples to lists for JSON serialization
        rows_data = [list(row) for row in cursor.fetchall()]

        logger.info(f"Query executed successfully: {len(rows_data)} rows returned")

        return {
            "columns": columns,
            "rows": rows_data,
            "row_count": len(rows_data)
        }

    except Error as e:
        error_msg = f"MySQL Error: {str(e)}"
        logger.error(error_msg)
        return _failure(error_msg)
    finally:
        if cursor:
            cursor.close()
        if connection and connection.is_connected():
            connection.close()


@mysql_mcp.tool
def mysql_list_databases() -> Dict[str, Any]:
    """Return the database names reported by ``SHOW DATABASES``.

    The connection comes from get_mysql_connection(), so it uses the
    module-level MYSQL_* settings.

    Returns:
        ``{"databases": [...], "count": n}`` on success. If a MySQL error is
        raised, ``{"error": "MySQL Error: ...", "databases": [], "count": 0}``.
    """
    conn = None
    cur = None
    try:
        conn = get_mysql_connection()
        cur = conn.cursor()
        cur.execute("SHOW DATABASES")
        # Each result row holds a single column: the database name.
        names = [record[0] for record in cur.fetchall()]
        total = len(names)
        logger.info(f"Retrieved {total} databases")
        return {"databases": names, "count": total}
    except Error as exc:
        message = f"MySQL Error: {exc}"
        logger.error(message)
        return {"error": message, "databases": [], "count": 0}
    finally:
        # Release the cursor first, then the connection if it is still open.
        if cur:
            cur.close()
        if conn and conn.is_connected():
            conn.close()


@mysql_mcp.tool
def mysql_list_tables(database: str = None) -> Dict[str, Any]:
    """Return the names of the tables in a single database.

    Args:
        database: Database to inspect. When omitted or empty, the
            MYSQL_DATABASE env variable is used instead.

    Returns:
        ``{"database": name, "tables": [...], "count": n}`` on success.
        If no database can be determined, or MySQL raises an error, the
        result is ``{"error": message, "tables": [], "count": 0}``.
    """
    target_db = database or MYSQL_DATABASE
    if not target_db:
        # Nothing to connect to: report the misconfiguration instead of querying.
        return {
            "error": "No database specified. Provide database parameter or set MYSQL_DATABASE env variable",
            "tables": [],
            "count": 0
        }

    conn = None
    cur = None
    try:
        conn = mysql.connector.connect(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            database=target_db,
        )
        cur = conn.cursor()
        cur.execute("SHOW TABLES")
        # Each result row holds a single column: the table name.
        table_names = [record[0] for record in cur.fetchall()]
        total = len(table_names)
        logger.info(f"Retrieved {total} tables from database '{target_db}'")
        return {"database": target_db, "tables": table_names, "count": total}
    except Error as exc:
        message = f"MySQL Error: {exc}"
        logger.error(message)
        return {"error": message, "tables": [], "count": 0}
    finally:
        if cur:
            cur.close()
        if conn and conn.is_connected():
            conn.close()


@mysql_mcp.tool
def mysql_describe_table(table_name: str, database: str = None) -> Dict[str, Any]:
    """Describe table structure (columns, types, etc.)

    ``table_name`` has to be spliced into the ``DESCRIBE`` SQL text, so it is
    first validated as a MySQL identifier. Accepted forms are:
    - a plain name (ASCII letters, digits, ``$``, ``_`` and characters
      U+0080..U+FFFF);
    - a backtick-quoted name, with embedded backticks doubled;
    - either of the above qualified as ``schema.table``.

    Anything else is rejected without contacting the server.

    Args:
        table_name: Name of the table to describe
        database: Database name (uses env variable if not provided)

    Returns:
        On success, a dictionary with:
        - database: the database used
        - table: the table name
        - columns: one dict per column (field, type, null, key, default, extra)
        - column_count: number of columns

        On failure, a dictionary with ``error`` and an empty ``columns`` list.
    """
    import re

    # One identifier part: backtick-quoted (inner backticks doubled) or unquoted.
    identifier_part = r"(?:`(?:[^`]|``)+`|[0-9A-Za-z$_\u0080-\uffff]+)"

    connection = None
    cursor = None

    try:
        db_name = database if database else MYSQL_DATABASE
        if not db_name:
            return {
                "error": "No database specified. Provide database parameter or set MYSQL_DATABASE env variable",
                "columns": []
            }
        # Without this check, input such as "t; DROP TABLE users" would be
        # spliced verbatim into the SQL text sent to the server.
        if not re.fullmatch(rf"{identifier_part}(?:\.{identifier_part})?", table_name or ""):
            return {
                "error": f"Invalid table name: {table_name!r}",
                "columns": []
            }

        connection = mysql.connector.connect(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            database=db_name,
        )

        cursor = connection.cursor()
        cursor.execute(f"DESCRIBE {table_name}")

        # DESCRIBE rows are positional: Field, Type, Null, Key, Default, Extra.
        columns_info = [
            {
                "field": row[0],
                "type": row[1],
                "null": row[2],
                "key": row[3],
                "default": row[4],
                "extra": row[5]
            }
            for row in cursor.fetchall()
        ]

        logger.info(f"Described table '{table_name}' in database '{db_name}'")

        return {
            "database": db_name,
            "table": table_name,
            "columns": columns_info,
            "column_count": len(columns_info)
        }

    except Error as e:
        error_msg = f"MySQL Error: {str(e)}"
        logger.error(error_msg)
        return {
            "error": error_msg,
            "columns": []
        }
    finally:
        if cursor:
            cursor.close()
        if connection and connection.is_connected():
            connection.close()


@mysql_mcp.tool
def mysql_execute_write(query: str, database: str = None) -> Dict[str, Any]:
    """Execute a write query (INSERT, UPDATE, DELETE, CREATE, etc.)

    The query runs with autocommit disabled. It is committed when execution
    succeeds and rolled back on a MySQL error. No filtering is applied to
    ``query``.

    Note: MySQL implicitly commits most DDL (CREATE/ALTER/DROP), so a rollback
    does not undo such statements.

    Args:
        query: SQL write query to execute
        database: Optional database name to use (overrides env variable)

    Returns:
        Dictionary containing:
        - success: Boolean indicating success
        - affected_rows: Number of rows affected
        - message: Success or error message
    """
    connection = None
    cursor = None

    try:
        connection = mysql.connector.connect(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            # Same rule as get_mysql_connection(): an empty name means no default database.
            database=database or MYSQL_DATABASE or None,
            autocommit=False  # Use transactions
        )

        cursor = connection.cursor()
        cursor.execute(query)
        connection.commit()

        affected_rows = cursor.rowcount

        logger.info(f"Write query executed successfully: {affected_rows} rows affected")

        return {
            "success": True,
            "affected_rows": affected_rows,
            "message": f"Query executed successfully. {affected_rows} rows affected."
        }

    except Error as e:
        error_msg = f"MySQL Error: {str(e)}"
        if connection is not None:
            # Best-effort rollback. If the original error was a lost connection,
            # rollback() raises as well, and that must not replace the error
            # result returned below.
            try:
                connection.rollback()
            except Error as rollback_error:
                logger.warning(f"Rollback after failed write also failed: {rollback_error}")
        logger.error(error_msg)
        return {
            "success": False,
            "affected_rows": 0,
            "message": error_msg
        }
    finally:
        if cursor:
            cursor.close()
        if connection and connection.is_connected():
            connection.close()


def get_mysql_mcp_server():
    """Expose the module-level FastMCP server that carries the MySQL tools.

    Returns:
        The ``mysql_mcp`` FastMCP instance on which this module's
        ``@mysql_mcp.tool`` functions are registered.
    """
    server = mysql_mcp
    logger.info("Getting MySQL MCP server instance")
    return server
