# Copyright 2012 OpenStack Foundation
# Copyright 2013 IBM Corp.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import subprocess
import allure

from oslo_log import log

from common.utils import net_utils
from configs import config
from lib.common.utils.linux import vqa_test_node_client
from lib.common.utils import test_utils
from lib.common import api_version_utils
from lib import exceptions as lib_exc
import tests.base

# Module-level alias for the loaded test configuration (configs.config.CONF).
CONF = config.CONF

LOG = log.getLogger(__name__)

# Default value for the *_max_microversion class attributes of ScenarioTest.
# Its interpretation is delegated to api_version_utils (presumably "no upper
# bound" — confirm against api_version_utils).
LATEST_MICROVERSION = "latest"


class ScenarioTest(tests.base.BaseTestCase):
    """Base class for scenario tests. Uses own clients.

    Subclasses may restrict the API microversion ranges they support by
    overriding the ``*_min_microversion`` / ``*_max_microversion`` class
    attributes. ``skip_checks`` compares those ranges with the configured
    ones, and ``resource_setup`` selects the request microversion per service.
    """

    credentials = ["primary", "admin"]

    # Per-service microversion ranges requested by the test class. ``None``
    # and LATEST_MICROVERSION are the defaults; their interpretation is left
    # to api_version_utils.
    compute_min_microversion = None
    compute_max_microversion = LATEST_MICROVERSION
    volume_min_microversion = None
    volume_max_microversion = LATEST_MICROVERSION
    placement_min_microversion = None
    placement_max_microversion = LATEST_MICROVERSION

    @classmethod
    def skip_checks(cls):
        """Skip the class if its microversion ranges are not supported.

        Compute, volume and placement ranges declared on the class are each
        checked against the matching range in CONF via
        ``api_version_utils.check_skip_with_microversion``.
        """
        super().skip_checks()
        api_version_utils.check_skip_with_microversion(
            cls.compute_min_microversion,
            cls.compute_max_microversion,
            CONF.compute.min_microversion,
            CONF.compute.max_microversion,
        )
        api_version_utils.check_skip_with_microversion(
            cls.volume_min_microversion,
            cls.volume_max_microversion,
            CONF.volume.min_microversion,
            CONF.volume.max_microversion,
        )
        api_version_utils.check_skip_with_microversion(
            cls.placement_min_microversion,
            cls.placement_max_microversion,
            CONF.placement.min_microversion,
            CONF.placement.max_microversion,
        )

    @classmethod
    def resource_setup(cls):
        """Pick request microversions per service and install the fixture."""
        super().resource_setup()
        cls.compute_request_microversion = (
            api_version_utils.select_request_microversion(
                cls.compute_min_microversion, CONF.compute.min_microversion
            )
        )
        cls.volume_request_microversion = (
            api_version_utils.select_request_microversion(
                cls.volume_min_microversion, CONF.volume.min_microversion
            )
        )
        cls.placement_request_microversion = (
            api_version_utils.select_request_microversion(
                cls.placement_min_microversion, CONF.placement.min_microversion
            )
        )

        cls.setup_api_microversion_fixture(
            compute_microversion=cls.compute_request_microversion,
            volume_microversion=cls.volume_request_microversion,
            placement_microversion=cls.placement_request_microversion,
        )

    @classmethod
    def setup_credentials(cls):
        """Request default network resources, then set up credentials."""
        # Setting network=True, subnet=True creates a default network; router
        # and DHCP are requested as well.
        cls.set_network_resources(network=True, subnet=True, router=True, dhcp=True)
        super().setup_credentials()

    def setup_compute_client(self):
        """Bind compute (nova) clients from ``os_primary`` onto ``self``.

        ``setup_clients`` calls this with the class object as ``self``, so
        the clients become class attributes in that case.
        """
        self.compute_images_client = self.os_primary.compute_images_client
        self.keypairs_client = self.os_primary.keypairs_client
        self.servers_client = self.os_primary.servers_client
        self.interface_client = self.os_primary.interfaces_client
        self.flavors_client = self.os_primary.flavors_client

    def setup_network_client(self):
        """Bind Neutron network clients from ``os_primary`` onto ``self``.

        ``setup_clients`` calls this with the class object as ``self``, so
        the clients become class attributes in that case.
        """
        self.networks_client = self.os_primary.networks_client
        self.ports_client = self.os_primary.ports_client
        self.routers_client = self.os_primary.routers_client
        self.subnets_client = self.os_primary.subnets_client
        self.floating_ips_client = self.os_primary.floating_ips_client
        self.security_groups_client = self.os_primary.security_groups_client
        self.security_group_rules_client = self.os_primary.security_group_rules_client

    @classmethod
    def setup_clients(cls):
        """Set up the service clients used by the tests.

        :raises lib_exc.InvalidConfiguration: if glance is available but the
            image v2 API is disabled in configuration.
        """
        super().setup_clients()
        if CONF.service_available.glance:
            if CONF.image_feature_enabled.api_v2:
                cls.image_client = cls.os_primary.image_client_v2
            else:
                raise lib_exc.InvalidConfiguration(
                    "api_v2 must be True in [image-feature-enabled]."
                )

        # These helpers are plain methods; passing ``cls`` as ``self`` makes
        # them store the clients as class-level attributes.
        cls.setup_compute_client(cls)
        cls.setup_network_client(cls)
        if CONF.service_available.cinder:
            cls.volumes_client = cls.os_primary.volumes_client_latest
            cls.snapshots_client = cls.os_primary.snapshots_client_latest
            cls.backups_client = cls.os_primary.backups_client_latest

    @allure.step("Ping IP address")
    def ping_ip_address(
        self, ip_address, should_succeed=True, ping_timeout=None, mtu=None
    ):
        """Ping ``ip_address`` repeatedly until the expected outcome is seen.

        :param ip_address: address to ping
        :param should_succeed: True if the address is expected to answer,
            False if it is expected to be unreachable
        :param ping_timeout: overall time budget in seconds; defaults to
            ``CONF.validation.ping_timeout``
        :param mtu: when set, pings with the don't-fragment flag and a payload
            size computed by ``net_utils.get_ping_payload_size(mtu, 4)``
            (IPv4 assumed)
        :return: truthy if the expected outcome was observed within the
            timeout, as reported by ``test_utils.call_until_true``
        """
        timeout = ping_timeout or CONF.validation.ping_timeout
        cmd = ["ping", "-c1"]

        if mtu:
            cmd += [
                # don't fragment
                "-M",
                "do",
                # ping receives just the size of ICMP payload
                "-s",
                str(net_utils.get_ping_payload_size(mtu, 4)),
            ]
        cmd.append(ip_address)

        def ping():
            # Only the exit status matters; discard output instead of
            # buffering it in pipes.
            proc = subprocess.run(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
            )
            return (proc.returncode == 0) == should_succeed

        caller = test_utils.find_test_caller()
        LOG.debug(
            "%(caller)s begins to ping %(ip)s in %(timeout)s sec and the"
            " expected result is %(should_succeed)s",
            {
                "caller": caller,
                "ip": ip_address,
                "timeout": timeout,
                "should_succeed": "reachable" if should_succeed else "unreachable",
            },
        )
        # Presumably retries ``ping`` every 1 s until it returns True or the
        # timeout elapses (semantics of test_utils.call_until_true).
        result = test_utils.call_until_true(ping, timeout, 1)
        LOG.debug(
            "%(caller)s finishes ping %(ip)s in %(timeout)s sec and the "
            "ping result is %(result)s",
            {
                "caller": caller,
                "ip": ip_address,
                "timeout": timeout,
                "result": "expected" if result else "unexpected",
            },
        )
        return result

    def check_vm_connectivity(
        self,
        remote_client,
        ip_address,
        username=None,
        should_connect=True,
        extra_msg="",
        mtu=None,
    ):
        """Check server connectivity

        :param remote_client: RemoteClient object
        :param ip_address: server to test against
        :param username: server's ssh username; defaults to
            ``CONF.validation.image_ssh_user``
        :param should_connect: True/False indicates positive/negative test
            positive - attempt ping and ssh
            negative - attempt ping and fail if succeed
        :param extra_msg: Message to help with debugging if ``ping_ip_address``
            fails
        :param mtu: network MTU to use for connectivity validation

        :raises: AssertionError if the result of the ping check does not
            match the value of the should_connect param; any exception from
            the SSH attempt is logged and re-raised (positive tests only)
        """

        if username is None:
            username = CONF.validation.image_ssh_user
        LOG.debug(
            "checking network connections to IP %s with user: %s", ip_address, username
        )
        if should_connect:
            msg = f"Timed out waiting for {ip_address} to become reachable"
        else:
            msg = f"ip address {ip_address} is reachable"
        if extra_msg:
            msg = f"{extra_msg}\n{msg}"
        assert self.ping_ip_address(
            ip_address, should_succeed=should_connect, mtu=mtu
        ), msg

        if should_connect:
            # no need to check ssh for negative connectivity
            try:
                remote_client.execute_ssh_connection(ip_address, username)
            except Exception:
                if not extra_msg:
                    extra_msg = f"Failed to ssh to {ip_address}"
                LOG.exception(extra_msg)
                raise

    def get_remote_client(
        self,
        is_vpc_connected=False,
        net_conn_test_client_ip=None,
        client_hostname=None,
        client_ip=None,
    ):
        """Get a SSH client to a remote server

        :param is_vpc_connected: Whether the remote server is connected to VPC network
        :param net_conn_test_client_ip: ip address of client VM; ignored
            (replaced by ``client_ip``) when ``is_vpc_connected`` is True
        :param client_hostname: hostname of the client, used as the test node
            URL when ``is_vpc_connected`` is True
        :param client_ip: ip address of the client, used when
            ``is_vpc_connected`` is True
        :return: a ``VQATestNodeClient`` object
        """
        if is_vpc_connected:
            test_node_url = client_hostname
            net_conn_test_client_ip = client_ip
        else:
            test_node_url = CONF.validation.test_node_url_on_non_vpc

        return vqa_test_node_client.VQATestNodeClient(
            test_node_url, net_conn_test_client_ip
        )

    def check_network_connectivity(
        self,
        protocol,
        remote_client,
        remote_server_ip,
        username=None,
        should_connect=True,
        extra_msg="",
        mtu=None,
    ):
        """Validate connectivity to ``remote_server_ip`` over ``protocol``.

        For 'tcp' and 'udp' the address is first pinged, expecting the same
        outcome as ``should_connect``. Then the matching
        ``remote_client.execute_<protocol>_connection`` call is made, and its
        outcome is compared with ``should_connect``.

        :param protocol: The protocol to use ('ssh', 'icmp', 'tcp', 'udp').
        :param remote_client: The client used to test the network connection.
        :param remote_server_ip: The target IP address for the connection.
        :param username: Optional username for SSH connection.
        :param should_connect: True for a positive test, False for a negative one.
        :param extra_msg: Optional message used instead of the default failure text.
        :param mtu: Network MTU for the preliminary ping (TCP/UDP only).
        :raises ValueError: If ``protocol`` is not supported.
        :raises AssertionError: If the preliminary ping (TCP/UDP) does not
            match ``should_connect``.
        :raises Exception: If the connection attempt fails although
            ``should_connect`` is True (the original error is re-raised), or
            succeeds although ``should_connect`` is False.
        """
        protocol_methods = {
            "ssh": lambda: remote_client.execute_ssh_connection(
                remote_server_ip, username
            ),
            "icmp": lambda: remote_client.execute_icmp_connection(remote_server_ip),
            "tcp": lambda: remote_client.execute_tcp_connection(remote_server_ip),
            "udp": lambda: remote_client.execute_udp_connection(remote_server_ip),
        }

        # Validate up front so an unknown protocol is always an error, even
        # in negative tests where connection failures are otherwise expected.
        if protocol not in protocol_methods:
            raise ValueError(f"Unsupported protocol: {protocol}")

        if protocol in ("tcp", "udp"):
            if should_connect:
                msg = (
                    f"Timed out waiting for {remote_server_ip} to become "
                    f"reachable with {protocol} protocol"
                )
            else:
                msg = f"IP address {remote_server_ip} is reachable"
            if extra_msg:
                msg = f"{extra_msg}\n{msg}"
            assert self.ping_ip_address(
                remote_server_ip, should_succeed=should_connect, mtu=mtu
            ), msg

        try:
            protocol_methods[protocol]()
        except Exception as e:
            if should_connect:
                extra_msg = extra_msg or (
                    f"Failed to connect to {remote_server_ip} using the "
                    f"{protocol.upper()} protocol: {e}"
                )
                LOG.exception(extra_msg)
                raise
            # Negative test: failing to connect is the expected outcome.
            return

        if not should_connect:
            # Raised outside the ``try`` so the broad ``except`` above cannot
            # swallow it and let an unexpected connection pass unnoticed.
            msg = extra_msg or (
                f"Unexpected {protocol.upper()} connectivity to {remote_server_ip}"
            )
            LOG.error(msg)
            raise Exception(msg)


class NetworkScenarioTest(ScenarioTest):
    """Base class for network scenario tests.

    Intended as the base for scenario tests that exercise networking through
    the Neutron API rather than nova-network.

    Neutron is mandatory: subclassed tests are skipped when Neutron is not
    enabled in the configuration.
    """

    @classmethod
    def skip_checks(cls):
        """Run the inherited checks, then skip unless Neutron is available."""
        super().skip_checks()
        neutron_enabled = CONF.service_available.neutron
        if neutron_enabled:
            return
        raise cls.skipException("Neutron not available")
