import json
import re
from urllib.parse import quote_plus

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure

from .base_db import BaseDBHandler
from .mongo_constants import (
    PATTERNS, COLUMN_NAMES, CONNECTION_TEMPLATES,
    DEFAULT_DATABASES, RESULT_LIMITS, SPECIAL_METHODS
)
from .mongo_mappings import (
    DB_METHOD_MAPPING, SHOW_MAPPING, RS_METHOD_MAPPING,
    SH_METHOD_MAPPING, COLLECTION_METHOD_MAPPING, COMMAND_PROCESSORS
)

class MongoDBHandler(BaseDBHandler):
    def _safe_json_parse(self, param_str):
        if not param_str or not param_str.strip():
            return None

        param_clean = param_str.strip()
        try:
            return json.loads(param_clean)
        except json.JSONDecodeError:
            try:
                fixed_param = param_clean.replace("'", '"')
                return json.loads(fixed_param)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON parameter: {param_str}") from e

    def _format_simple_result(self, column_key, result_data):
        """Wrap ``result_data`` with the column names registered under ``column_key``.

        Args:
            column_key: key into COLUMN_NAMES (e.g. 'count', 'database_list').
            result_data: rows passed through unchanged as ``results``.

        Returns:
            ``make_success_response`` using COLUMN_NAMES[column_key] as the
            columns, or ``make_error_response(500, ...)`` when the key is not
            registered.
        """
        if column_key not in COLUMN_NAMES:
            return self.make_error_response(500, f"Unknown result format: {column_key}")
        columns = COLUMN_NAMES[column_key]
        return self.make_success_response(columns_names=columns, results=result_data)

    def _process_command_result(self, db, command_name, command_params=None, processor_key=None):
        """Run ``command_name`` on ``db`` and optionally tabulate the raw reply.

        When ``processor_key`` names an entry of COMMAND_PROCESSORS, that entry's
        ``processor`` callable turns the reply into rows and its ``result_keys``
        become the column names. Otherwise the raw command reply is returned
        unchanged.
        """
        raw_reply = self._execute_db_command(db, command_name, command_params)
        if not processor_key or processor_key not in COMMAND_PROCESSORS:
            return raw_reply

        spec = COMMAND_PROCESSORS[processor_key]
        return self.make_success_response(
            columns_names=spec['result_keys'],
            results=spec['processor'](raw_reply),
        )

    def _execute_db_command(self, db, command_name, command_params=None):
        if command_params:
            return db.command(command_name, **command_params)
        else:
            return db.command(command_name)

    def _parse_method_parameters(self, params_str):
        if not params_str.strip():
            return []

        params = self._split_function_parameters(params_str)
        parsed_params = []

        for param in params:
            param = param.strip().replace("'", '"')
            if param == '{}' or param == '':
                parsed_params.append({})
            else:
                try:
                    parsed_params.append(self._safe_json_parse(param))
                except ValueError:
                    try:
                        fixed_param = re.sub(r'(\$\w+):', r'"\1":', param)
                        parsed_params.append(self._safe_json_parse(fixed_param))
                    except ValueError as e:
                        raise ValueError(f"Invalid parameter: {param}") from e
        return parsed_params

    def _build_collection_command(self, collection_name, method_name, params):
        """Build a command with the COLLECTION_METHOD_MAPPING builder for ``method_name``.

        Returns:
            whatever the registered builder returns for
            ``(collection_name, params)``, or None when no builder is registered.
        """
        if method_name not in COLLECTION_METHOD_MAPPING:
            return None
        builder = COLLECTION_METHOD_MAPPING[method_name]
        return builder(collection_name, params)

    def connect(self, service, user, password, database=None):
        """Create a MongoClient for ``service`` and verify it with ``ping``.

        Args:
            service: either a complete ``mongodb://`` / ``mongodb+srv://`` URI,
                which may contain ``{USERNAME}`` / ``{PASSWORD}`` placeholders,
                or a host string formatted through ``CONNECTION_TEMPLATES``.
            user: username; percent-escaped with ``quote_plus`` before it is
                embedded in the URI.
            password: password; percent-escaped with ``quote_plus`` before it is
                embedded in the URI.
            database: optional database for the URI path. For plain hosts,
                ``DEFAULT_DATABASES['general']`` is used when this is omitted.

        Returns:
            ``make_connection_response`` with a connection type
            (primary/secondary/analytics/unknown) inferred from the
            ``readPreference`` in the URI. On failure, ``make_error_response``
            with 400 (placeholder URI without credentials) or 500 (any other
            failure).

        Side effects:
            Replaces ``self.mongodb_client`` with the new client and closes the
            client it replaces, so its connection pools are not leaked.
        """
        try:
            if service.startswith(('mongodb+srv://', 'mongodb://')):
                connection_string = service

                if '{USERNAME}' in connection_string and '{PASSWORD}' in connection_string:
                    if not user or not password:
                        return self.make_error_response(400, "Username and password are required for this connection string")
                    # Reserved characters (@ : / ? # %) in credentials must be
                    # percent-escaped, otherwise the driver mis-parses the URI.
                    connection_string = (
                        connection_string
                        .replace('{USERNAME}', quote_plus(user))
                        .replace('{PASSWORD}', quote_plus(password))
                    )
                if database:
                    match = re.search(PATTERNS['connection_db'], connection_string)
                    if match:
                        # Splice only the matched span; str.replace would also
                        # rewrite identical text elsewhere in the URI.
                        connection_string = (
                            connection_string[:match.start()]
                            + match.group(1) + database
                            + connection_string[match.end():]
                        )
                    else:
                        # Insert the path before the first '?' only, without
                        # doubling a trailing '/' on the host part.
                        base, sep, options = connection_string.partition('?')
                        connection_string = f"{base.rstrip('/')}/{database}{sep}{options}"
            else:
                target_db = database or DEFAULT_DATABASES['general']
                if user and password:
                    connection_string = CONNECTION_TEMPLATES['with_auth'].format(
                        user=quote_plus(user), password=quote_plus(password),
                        service=service, database=target_db,
                    )
                else:
                    connection_string = CONNECTION_TEMPLATES['without_auth'].format(
                        service=service, database=target_db,
                    )

            previous_client = getattr(self, 'mongodb_client', None)
            self.mongodb_client = MongoClient(
                connection_string,
                serverSelectionTimeoutMS=30000,
                connectTimeoutMS=30000,
                socketTimeoutMS=30000,
                retryWrites=False,
                w=1
            )
            if previous_client is not None:
                # Release the replaced client's pools and monitor threads.
                previous_client.close()
            self.mongodb_client.admin.command('ping')

            connection_type = "unknown"
            if "readPreference=primaryPreferred" in connection_string:
                connection_type = "primary"
            elif "readPreference=secondaryPreferred" in connection_string:
                connection_type = "secondary"
            elif "readPreference=secondary" in connection_string and "nodeType:analytics" in connection_string:
                connection_type = "analytics"
            return self.make_connection_response(f"Successfully connected to MongoDB ({connection_type})")
        except ConnectionFailure as e:
            return self.make_error_response(500, f"Internal Server Error, Failed to connect to MongoDB: Connection failed - {e}")
        except Exception as e:
            return self.make_error_response(500, f"Internal Server Error, Failed to connect to MongoDB: {e}")

    def execute_query(self, query, database=None):
        try:
            db = self.mongodb_client[database or 'test']

            query_stripped = query.strip()
            if not query_stripped.startswith('{'):
                if (query_stripped.startswith('db.') or
                    query_stripped.startswith('rs.') or
                    query_stripped.startswith('sh.') or
                    query_stripped.startswith('show ')):

                    query = self._parse_mongo_shell_syntax(query_stripped)
                    if isinstance(query, dict):
                        query_obj = query
                    else:
                        return query
                else:
                    return self.make_error_response(400, "Invalid query format. Use JSON format or MongoDB shell syntax (e.g., db.collection.find({}), show dbs, rs.status(), sh.status())")
            else:
                query_obj = self._safe_json_parse(query)
            return self._execute_mongodb_command(db, query_obj)

        except (json.JSONDecodeError, ValueError) as e:
            return self.make_error_response(400, f"Invalid JSON format: {e}")
        except Exception as e:
            return self.make_error_response(500, f"Internal Server Error, Please check the query and try again: {e}")

    def _execute_mongodb_command(self, db, query_obj):
        if 'find' in query_obj:
            collection = db[query_obj['find']]
            filter_query = query_obj.get('filter', {})
            results = list(collection.find(filter_query))

            for doc in results:
                if '_id' in doc:
                    doc['_id'] = str(doc['_id'])

            columns_names = list(results[0].keys()) if results else []
            return self.make_success_response(
                columns_names=columns_names,
                results=[[str(doc.get(col, '')) for col in columns_names] for doc in results]
            )
        elif 'insert' in query_obj:
            collection = db[query_obj['insert']]
            documents = query_obj['documents']
            result = collection.insert_many(documents)
            return self.make_success_response(
                results=f"{len(result.inserted_ids)} documents inserted"
            )
        elif 'update' in query_obj:
            collection = db[query_obj['update']]
            updates = query_obj['updates']
            total_modified = 0
            for update in updates:
                result = collection.update_many(update['q'], update['u'])
                total_modified += result.modified_count
            return self.make_success_response(
                results=f"{total_modified} documents updated"
            )
        elif 'delete' in query_obj:
            collection = db[query_obj['delete']]
            deletes = query_obj['deletes']
            total_deleted = 0
            for delete in deletes:
                result = collection.delete_many(delete['q'])
                total_deleted += result.deleted_count
            return self.make_success_response(
                results=f"{total_deleted} documents deleted"
            )

        # Simple command handlers using generic processor
        elif query_obj.get('buildInfo'):
            return self._process_command_result(db, 'buildInfo', processor_key='buildInfo')
        elif query_obj.get('serverStatus'):
            return self._process_command_result(db, 'serverStatus', processor_key='serverStatus')
        elif query_obj.get('currentOp'):
            return self._process_command_result(db, 'currentOp', processor_key='currentOp')
        elif query_obj.get('usersInfo'):
            result = db.command('usersInfo', query_obj['usersInfo'])
            users = result.get('users', [])
            return self.make_success_response(
                columns_names=["user", "database", "roles"],
                results=[[user.get('user', ''), user.get('db', ''), str(user.get('roles', []))] for user in users]
            )
        elif query_obj.get('connectionStatus'):
            return self._process_command_result(db, 'connectionStatus', processor_key='connectionStatus')
        elif 'getParameter' in query_obj:
            if len(query_obj) == 1 and query_obj['getParameter'] == 1:
                return self.make_success_response(
                    columns_names=["info"],
                    results=[["getParameter requires specific parameter name or use getParameter: '*' for all"]]
                )
            elif 'getParameter' in query_obj and len(query_obj) > 1:
                param_dict = {k: v for k, v in query_obj.items() if k != 'getParameter'}
                result = db.command('getParameter', **param_dict)
            else:
                result = db.command('getParameter', query_obj['getParameter'])

            params = []
            for key, value in result.items():
                if key not in ['ok', '$clusterTime', 'operationTime']:
                    params.append([key, str(value)])

            return self.make_success_response(
                columns_names=["parameter", "value"],
                results=params[:20]
            )
        elif 'count' in query_obj:
            collection = db[query_obj['count']]
            query_filter = query_obj.get('query', {})
            count = collection.count_documents(query_filter)
            return self._format_simple_result('count', [[count]])
        elif query_obj.get('listDatabases'):
            return self._format_simple_result('database_list',
                [[db_name] for db_name in self.mongodb_client.list_database_names()])
        elif query_obj.get('listCollections'):
            return self._format_simple_result('collection_list',
                [[col] for col in db.list_collection_names()])
        elif query_obj.get('rolesInfo'):
            result = db.command('rolesInfo', query_obj['rolesInfo'])
            roles = result.get('roles', [])
            return self.make_success_response(
                columns_names=["role", "database", "privileges"],
                results=[[role.get('role', ''), role.get('db', ''), str(role.get('privileges', []))] for role in roles]
            )
        elif query_obj.get('getLog'):
            result = db.command('getLog', query_obj['getLog'])
            return self.make_success_response(
                columns_names=["log_entries"],
                results=[[log] for log in result.get('log', [])[-10:]]
            )
        elif 'killOp' in query_obj:
            result = db.command(query_obj)
            return self.make_success_response(
                columns_names=["info", "ok"],
                results=[["Operation terminated" if result.get('ok') else "Failed to terminate operation", result.get('ok', 0)]]
            )
        elif 'create' in query_obj:
            collection_name = query_obj['create']
            options = {k: v for k, v in query_obj.items() if k != 'create'}
            result = db.create_collection(collection_name, **options)
            return self.make_success_response(
                columns_names=["result"],
                results=[["Collection created successfully"]]
            )
        elif 'drop' in query_obj:
            collection_name = query_obj['drop']
            db.drop_collection(collection_name)
            return self.make_success_response(
                columns_names=["result"],
                results=[["Collection dropped successfully"]]
            )
        elif 'dropDatabase' in query_obj:
            self.mongodb_client.drop_database(db.name)
            return self.make_success_response(
                columns_names=["result"],
                results=[["Database dropped successfully"]]
            )
        elif 'createIndexes' in query_obj:
            collection_name = query_obj['createIndexes']
            collection = db[collection_name]
            indexes = query_obj['indexes']
            for index_spec in indexes:
                collection.create_index(list(index_spec['key'].items()))
            return self.make_success_response(
                columns_names=["result"],
                results=[["Index created successfully"]]
            )
        elif 'dropIndexes' in query_obj:
            collection_name = query_obj['dropIndexes']
            collection = db[collection_name]
            index_spec = query_obj['index']

            if isinstance(index_spec, dict):
                index_list = list(index_spec.items())
                collection.drop_index(index_list)
            elif isinstance(index_spec, str):
                collection.drop_index(index_spec)
            else:
                collection.drop_index(index_spec)

            return self.make_success_response(
                columns_names=["result"],
                results=[["Index dropped successfully"]]
            )
        elif 'listIndexes' in query_obj:
            collection_name = query_obj['listIndexes']
            collection = db[collection_name]
            indexes = list(collection.list_indexes())
            return self.make_success_response(
                columns_names=["name", "key", "unique"],
                results=[[idx.get('name', ''), str(idx.get('key', {})), idx.get('unique', False)] for idx in indexes]
            )
        elif 'aggregate' in query_obj:
            collection_name = query_obj['aggregate']
            collection = db[collection_name]
            pipeline = query_obj.get('pipeline', [])

            if pipeline and '$changeStream' in str(pipeline[0]):
                return self.make_success_response(
                    columns_names=["status"],
                    results=[["Change stream initialized for collection: " + collection_name]]
                )
            try:
                cursor = collection.aggregate(pipeline)
                results = list(cursor)

                if results:
                    columns_names = list(results[0].keys()) if results[0] else []
                    for doc in results:
                        if '_id' in doc:
                            doc['_id'] = str(doc['_id'])

                    return self.make_success_response(
                        columns_names=columns_names,
                        results=[[str(doc.get(col, '')) for col in columns_names] for doc in results]
                    )
                else:
                    return self.make_success_response(
                        columns_names=["result"],
                        results=[["No results from aggregation"]]
                    )
            except Exception as e:
                return self.make_error_response(500, f"Aggregation error: {e}")
        elif query_obj.get('replSetGetConfig'):
            result = db.command('replSetGetConfig')
            config = result.get('config', {})
            return self.make_success_response(
                columns_names=["_id", "version", "members_count"],
                results=[[config.get('_id', ''), config.get('version', 0), len(config.get('members', []))]]
            )
        elif query_obj.get('replSetGetStatus'):
            result = db.command('replSetGetStatus')
            return self.make_success_response(
                columns_names=["name", "state", "stateStr", "health"],
                results=[[m.get('name', ''), m.get('state', 0), m.get('stateStr', ''), m.get('health', 0)] for m in result.get('members', [])]
            )
        elif any(cmd in query_obj for cmd in ['collMod', 'validate', 'collStats', 'dbStats', 'hostInfo']):
            result = db.command(query_obj if 'dbStats' in query_obj else query_obj)

            if 'collMod' in query_obj:
                return self.make_success_response(columns_names=["result"], results=[["Collection modified successfully"]])
            elif 'validate' in query_obj:
                return self.make_success_response(columns_names=["valid", "nrecords"], results=[[result.get('valid', False), result.get('nrecords', 0)]])
            elif 'collStats' in query_obj:
                return self.make_success_response(columns_names=["ns", "count", "size"], results=[[result.get('ns', ''), result.get('count', 0), result.get('size', 0)]])
            elif 'dbStats' in query_obj:
                return self.make_success_response(columns_names=["db", "collections", "objects"], results=[[result.get('db', ''), result.get('collections', 0), result.get('objects', 0)]])
            elif 'hostInfo' in query_obj:
                system = result.get('system', {})
                return self.make_success_response(columns_names=["hostname", "numCores"], results=[[system.get('hostname', ''), system.get('numCores', 0)]])
        elif 'connPoolStats' in query_obj:
            result = db.command('connPoolStats')
            pools = result.get('pools', {})
            return self.make_success_response(
                columns_names=["host", "inUse", "available", "created"],
                results=[[host, pool.get('inUse', 0), pool.get('available', 0), pool.get('created', 0)] for host, pool in pools.items()]
            )
        elif 'top' in query_obj:
            result = db.command('top')
            totals = result.get('totals', {})

            processed_results = []
            for coll, stats in list(totals.items())[:10]:
                if isinstance(stats, dict) and 'total' in stats:
                    total_info = stats.get('total', {})
                    if isinstance(total_info, dict):
                        time_val = total_info.get('time', 0)
                        count_val = total_info.get('count', 0)
                    else:
                        time_val = 0
                        count_val = 0
                else:
                    time_val = 0
                    count_val = 0
                processed_results.append([coll, time_val, count_val])

            return self.make_success_response(
                columns_names=["collection", "total_time", "count"],
                results=processed_results
            )

        else:
            return self.make_error_response(400, "Unsupported MongoDB query format. Use JSON format with operations like find, insert, update, delete or MongoDB shell syntax (e.g., db.collection.find({}), show dbs, rs.status(), sh.status())")

    def _parse_mongo_shell_syntax(self, query):
        """Route one line of MongoDB shell syntax to the matching sub-parser.

        All five PATTERNS (collection, db, rs, sh, show) are matched against the
        stripped query up front. The first that matches, in that order, selects
        the sub-parser. Any exception raised while parsing becomes a 500 response.

        Returns:
            the sub-parser's command dict, or a ``make_error_response`` result.
        """
        try:
            text = query.strip()
            found = {
                kind: re.match(PATTERNS[kind], text)
                for kind in ('collection', 'db', 'rs', 'sh', 'show')
            }

            collection_hit = found['collection']
            if collection_hit:
                return self._parse_collection_method(
                    collection_hit.group(1),
                    collection_hit.group(2),
                    collection_hit.group(3).strip(),
                )

            method_parsers = (
                ('db', self._parse_db_method),
                ('rs', self._parse_rs_method),
                ('sh', self._parse_sh_method),
            )
            for kind, parse in method_parsers:
                hit = found[kind]
                if hit:
                    return parse(hit.group(1), hit.group(2).strip())

            show_hit = found['show']
            if show_hit:
                return self._parse_show_command(show_hit.group(1))

            return self.make_error_response(400, "Invalid shell syntax. Expected format: db.collection.method(params), db.method(params), rs.method(params), sh.method(params), or show <command>")

        except Exception as e:
            return self.make_error_response(500, f"Error parsing shell syntax: {str(e)}")

    def _parse_collection_method(self, collection_name, method_name, params_str):
        params = self._parse_method_parameters(params_str)

        command = self._build_collection_command(collection_name, method_name, params)
        if command:
            return command
        elif method_name == 'drop':
            return {
                "drop": collection_name
            }
        elif method_name == 'createIndex':
            index_spec = params[0] if params else {}
            options = params[1] if len(params) > 1 else {}
            return {
                "createIndexes": collection_name,
                "indexes": [{"key": index_spec, "name": f"idx_{collection_name}_{hash(str(index_spec))}", **options}]
            }
        elif method_name == 'dropIndex':
            return {
                "dropIndexes": collection_name,
                "index": params[0] if params else {}
            }
        elif method_name == 'getIndexes':
            return {
                "listIndexes": collection_name
            }
        elif method_name == 'watch':
            return {
                "aggregate": collection_name,
                "pipeline": [{"$changeStream": {}}],
                "cursor": {}
            }
        elif method_name == 'aggregate':
            pipeline = params[0] if params else []
            cursor_options = params[1] if len(params) > 1 and isinstance(params[1], dict) else {}
            return {
                "aggregate": collection_name,
                "pipeline": pipeline,
                "cursor": cursor_options
            }
        else:
            return self.make_error_response(400, f"Unsupported collection method: {method_name}")

    def _parse_db_method(self, method_name, params_str):
        """Translate a ``db.<method>(...)`` shell call into a command document.

        Args:
            method_name: shell method name; must be a key of DB_METHOD_MAPPING.
            params_str: text between the parentheses.

        Returns:
            * runCommand/adminCommand: the parsed command object.
            * killOp: ``{"killOp": 1, "op": <int>}``.
            * createCollection: ``{"create": <name>, **options}``.
            * any other known method: the static DB_METHOD_MAPPING entry.
            Invalid input yields ``make_error_response(400, ...)``.
        """
        if method_name not in DB_METHOD_MAPPING:
            return self.make_error_response(400, f"Unsupported database method: {method_name}")

        if method_name in ('adminCommand', 'runCommand'):
            # NOTE(review): the parsed document runs against the database
            # chosen in execute_query, so adminCommand is not forced onto
            # "admin" here.
            param_clean = (params_str or '').strip()
            if not param_clean:
                return self.make_error_response(400, f"{method_name} requires a command object parameter")
            if not (param_clean.startswith('{') and param_clean.endswith('}')):
                return self.make_error_response(400, f"{method_name} parameter must be a JSON object")
            try:
                return self._loads_shell_object(param_clean)
            except json.JSONDecodeError as e:
                return self.make_error_response(400, f"Invalid JSON in {method_name}: {str(e)}")

        if method_name == 'killOp':
            try:
                param_clean = params_str.strip()
                if param_clean.startswith('{') and param_clean.endswith('}'):
                    try:
                        parsed_param = self._loads_shell_object(param_clean)
                        if 'op' not in parsed_param:
                            return self.make_error_response(400, "killOp JSON parameter must have 'op' field")
                        opid = int(parsed_param['op'])
                    except (json.JSONDecodeError, ValueError):
                        return self.make_error_response(400, f"Invalid JSON parameter for killOp: {param_clean}")
                elif param_clean[:1] in ('"', "'") and param_clean.endswith(param_clean[0]):
                    # A quoted operation id such as "12345" or '12345'.
                    try:
                        opid = int(param_clean.strip(param_clean[0]))
                    except ValueError:
                        return self.make_error_response(400, f"Operation ID must be numeric: {param_clean}")
                else:
                    try:
                        opid = int(param_clean)
                    except ValueError:
                        return self.make_error_response(400, f"Invalid operation ID format: {param_clean}")
                return {
                    "killOp": 1,
                    "op": opid
                }
            except Exception as e:
                return self.make_error_response(400, f"Error parsing killOp parameter: {str(e)}")

        if method_name == 'createCollection':
            try:
                params = self._split_function_parameters(params_str)
                if not params:
                    return self.make_error_response(400, "createCollection requires a collection name")

                collection_name = params[0].strip()
                for quote in ('"', "'"):
                    if collection_name.startswith(quote) and collection_name.endswith(quote):
                        collection_name = collection_name.strip(quote)
                        break

                options = {}
                if len(params) > 1:
                    try:
                        options = self._loads_shell_object(params[1].strip())
                    except json.JSONDecodeError:
                        return self.make_error_response(400, f"Invalid options parameter for createCollection: {params[1]}")

                return {
                    "create": collection_name,
                    **options
                }
            except Exception as e:
                return self.make_error_response(400, f"Error parsing createCollection parameter: {str(e)}")

        return DB_METHOD_MAPPING[method_name]

    @staticmethod
    def _loads_shell_object(text):
        """Decode a shell-style object literal as JSON.

        Single quotes become double quotes. If the text is still not valid
        JSON, bare keys are quoted and decoding is retried. Bare keys are
        identifiers, dotted paths and ``$``-operators that directly follow
        ``{`` or ``,``.

        Raises:
            json.JSONDecodeError: if the text cannot be decoded even after quoting.
        """
        candidate = text.replace("'", '"')
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass
        # Quote all key forms in one pass. The old two-step (\w+) then (\$\w+)
        # substitution turned "$gt:" into the invalid '$"gt":'.
        return json.loads(re.sub(r'([{,]\s*)([$A-Za-z_][\w$.]*)\s*:', r'\1"\2":', candidate))

    def _parse_show_command(self, command):
        """Map ``show <command>`` to its command document from SHOW_MAPPING.

        Unknown commands yield a 400 response that lists the supported ones.
        """
        if command not in SHOW_MAPPING:
            supported = ', '.join(SHOW_MAPPING.keys())
            return self.make_error_response(400, f"Unsupported show command: {command}. Supported: {supported}")
        return SHOW_MAPPING[command]

    def _parse_rs_method(self, method_name, params_str):
        """Map ``rs.<method>()`` to its static command from RS_METHOD_MAPPING.

        ``params_str`` is not used. Methods whose mapping entry is None need
        arguments and are rejected with a 400 that points to JSON syntax.
        """
        if method_name not in RS_METHOD_MAPPING:
            return self.make_error_response(400, f"Unsupported replica set method: {method_name}")

        command = RS_METHOD_MAPPING[method_name]
        if command is not None:
            return command
        return self.make_error_response(400, f"rs.{method_name} requires parameters - use JSON format instead")

    def _parse_sh_method(self, method_name, params_str):
        """Map ``sh.<method>(...)`` to a command document.

        Static commands come from SH_METHOD_MAPPING. A None mapping means the
        method needs arguments; only ``enableSharding`` is parsed here. Its
        single argument is the database name, with one pair of matching
        surrounding quotes removed. Other such methods get a 400 response.
        """
        if method_name not in SH_METHOD_MAPPING:
            return self.make_error_response(400, f"Unsupported sharding method: {method_name}")

        command = SH_METHOD_MAPPING[method_name]
        if command is not None:
            return command

        if method_name != 'enableSharding':
            return self.make_error_response(400, f"sh.{method_name} requires parameters - use JSON format instead")

        try:
            target = params_str.strip()
            for quote in ('"', "'"):
                if target.startswith(quote) and target.endswith(quote):
                    target = target.strip(quote)
                    break
            return {"enableSharding": target}
        except Exception as e:
            return self.make_error_response(400, f"Error parsing enableSharding parameter: {str(e)}")

    def _split_function_parameters(self, params_str):
        if not params_str.strip():
            return []

        params = []
        current_param = ""
        brace_count = 0
        bracket_count = 0
        in_quotes = False
        quote_char = None

        for char in params_str:
            if char in ['"', "'"] and not in_quotes:
                in_quotes = True
                quote_char = char
            elif char == quote_char and in_quotes:
                in_quotes = False
                quote_char = None
            elif not in_quotes:
                if char == '{':
                    brace_count += 1
                elif char == '}':
                    brace_count -= 1
                elif char == '[':
                    bracket_count += 1
                elif char == ']':
                    bracket_count -= 1
                elif char == ',' and brace_count == 0 and bracket_count == 0:
                    params.append(current_param.strip())
                    current_param = ""
                    continue
            current_param += char
        if current_param.strip():
            params.append(current_param.strip())
        return params