#!/usr/bin/env python3
"""
Test script for CQA MCP Combined Server
Tests both Jira and Flava Functions tools integration
"""

import os
import sys
import asyncio
import logging
from typing import Dict, Any

# Configure the root logger at INFO so the logger.info(...) progress messages
# emitted throughout this script are actually printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prepend the current working directory (NOT necessarily this file's directory)
# to the import path so the `products.mcp.*` imports inside the test functions
# resolve. NOTE(review): presumably the script is meant to be run from the
# project root; running it from elsewhere will make those imports fail.
sys.path.insert(0, '.')

def setup_test_environment():
    """Populate ``os.environ`` with placeholder Jira and Flava settings.

    Every key below is overwritten unconditionally, so any real values that
    were present in the environment are replaced for the rest of the run.
    """
    placeholder_settings = {
        'JIRA_URL': 'https://test.jira.example.com',
        'JIRA_PERSONAL_TOKEN': 'test-jira-token',
        'FLAVA_FUNCTION_BASE_URL': 'https://test.flava.example.com',
        'FLAVA_PRODUCT': 'faas',
        'FLAVA_ENVIRONMENT': 'test',
        'FLAVA_PROJECT': 'test-project',
    }
    for env_key, env_value in placeholder_settings.items():
        os.environ[env_key] = env_value
    logger.info("✅ Test environment variables set")

async def test_mounted_mcp_server():
    """Mount the Jira and Flava MCP servers on one FastMCP app and check its tools.

    Builds a ``FastMCP`` server, mounts the Jira server under the ``jira``
    prefix and the Flava Functions server under the ``flava`` prefix, lists
    the resulting tools, and checks that every expected prefixed tool name
    is exposed.

    Returns:
        True if every expected tool is present. False if any expected tool
        is missing, or if setup/listing raises (the traceback is logged).
    """
    try:
        from products.mcp.jira import get_jira_mcp_server
        from products.mcp.flava_function import get_flava_function_mcp_server
        from fastmcp import FastMCP

        logger.info("🔧 Creating main MCP server with mounted services...")

        # Create main server
        main_mcp = FastMCP("Test Main MCP Server")

        # Get individual servers
        jira_server = get_jira_mcp_server()
        flava_server = get_flava_function_mcp_server()

        logger.info(f"📋 Jira server: {jira_server}")
        logger.info(f"📋 Flava server: {flava_server}")

        # Mount servers with prefixes (using new FastMCP API)
        main_mcp.mount(jira_server, prefix="jira")
        logger.info("✅ Mounted Jira MCP server at 'jira' namespace")

        main_mcp.mount(flava_server, prefix="flava")
        logger.info("✅ Mounted Flava Functions MCP server at 'flava' namespace")

        logger.info("🔧 Successfully created main MCP server with mounted services")

        tools = await main_mcp.get_tools()
        # A dict result is keyed by tool name. For any other iterable, take
        # each entry's ``name`` so the comparisons below are against strings
        # rather than Tool objects (which would never match).
        if isinstance(tools, dict):
            tool_names = list(tools)
        else:
            tool_names = [getattr(tool, "name", tool) for tool in tools]

        logger.info("📋 Available tools:")
        for i, tool_name in enumerate(tool_names, 1):
            logger.info(f"   {i}. {tool_name}")

        # FastMCP mount uses "prefix_toolname" format, not "prefix.toolname"
        expected_mounted_jira_tools = [
            'jira_jira_list_projects', 'jira_jira_search_issues', 'jira_jira_get_issue',
            'jira_jira_get_comments', 'jira_jira_create_issue',
        ]
        expected_mounted_flava_tools = ['flava_flava_get_functions', 'flava_flava_get_function_details']
        expected_tools = expected_mounted_jira_tools + expected_mounted_flava_tools

        available_tools = set(tool_names)
        mounted_missing = [tool for tool in expected_tools if tool not in available_tools]

        if not mounted_missing:
            logger.info("✅ All expected mounted tools are available")
            return True

        # Any missing expected tool is a failure. A matching tool count alone
        # proves nothing about which tools are exposed, so it is reported
        # for diagnostics only.
        logger.error(f"❌ Missing mounted tools: {mounted_missing}")
        logger.info(f"📋 Available tools: {sorted(available_tools, key=str)}")
        logger.info(f"📋 Expected {len(expected_tools)} tools, found {len(available_tools)}")
        return False

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"❌ Error testing mounted MCP server: {e}")
        return False

def test_module_info():
    """Log the metadata reported by the Jira and Flava Functions MCP modules.

    Returns:
        True when both info helpers import and run. False (after logging
        the error message) if anything raises.
    """
    try:
        from products.mcp.jira import get_mcp_info as fetch_jira_info
        from products.mcp.flava_function import get_flava_function_mcp_info as fetch_flava_info

        logger.info("📊 Module Information:")
        # Each helper is called immediately before its own log line, so the
        # call order is unchanged: Jira first, then Flava Functions.
        logger.info(f"   Jira MCP: {fetch_jira_info()}")
        logger.info(f"   Flava Functions MCP: {fetch_flava_info()}")
    except Exception as err:
        logger.error(f"❌ Error getting module info: {str(err)}")
        return False
    return True

async def main():
    """Run all checks, log a PASS/FAIL summary, and return an exit code.

    Returns:
        0 when every check passes, 1 otherwise.
    """
    logger.info("🚀 Starting CQA MCP Combined Server Tests")

    # Placeholder credentials/URLs must be in place before any check runs.
    setup_test_environment()

    # Checks run in this order; the async one is awaited inline.
    module_info_ok = test_module_info()
    mounted_ok = await test_mounted_mcp_server()
    outcomes = [
        ("Module Information", module_info_ok),
        ("Mounted MCP Server", mounted_ok),
    ]

    logger.info("\n📊 Test Results:")
    for label, succeeded in outcomes:
        verdict = "✅ PASS" if succeeded else "❌ FAIL"
        logger.info(f"   {label}: {verdict}")

    passed = sum(1 for _, succeeded in outcomes if succeeded)
    total = len(outcomes)
    logger.info(f"\n🎯 Overall: {passed}/{total} tests passed")

    if passed != total:
        logger.error("💥 Some tests failed. Check the logs above.")
        return 1

    logger.info("🎉 All tests passed! MCP server is ready.")
    return 0

if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)