import os
import re
import subprocess
import shutil
import logging
import socket
import threading
import time
from datetime import datetime, timezone
from typing import Optional
from urllib.parse import urlparse

from yj.athenz.client import ZTSClient
from common.cert_metrics import set_cert_expiry_metric

# Provider domains look like: [flava-]{product-abbr}.{environment}.{flava-project}
# e.g. "flava-faas.stage.flava-qa" -> product="faas", env="stage", project="flava-qa"
SERVICE_PROVIDER_PATTERN = (
    r"^(?:flava-)?(?P<product>[\w-]+)\.(?P<env>[\w-]+)\.(?P<project>[\w-]+)$"
)

# Network defaults; AthenzClient lets PROXY_URL, ZTS_ENDPOINT, ZTS_HOST and
# ZTS_PORT environment variables override them.
DEFAULT_PROXY_URL = "http://lypwg01-001.yahoo-net.jp:8080"
DEFAULT_ZTS_HOST = "apj.zts.athenz.yahoo.co.jp"
DEFAULT_ZTS_PORT = 4443
DEFAULT_ZTS_ENDPOINT = f"https://{DEFAULT_ZTS_HOST}:{DEFAULT_ZTS_PORT}/zts/v1"

# Certificate lifecycle defaults; overridable through the
# CERT_RENEWAL_THRESHOLD_HOURS and CERT_CHECK_INTERVAL_SECONDS variables.
DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS = 7 * 24  # renew when < 1 week remains
DEFAULT_CERT_CHECK_INTERVAL_SECONDS = 5 * 60  # re-check at most every 5 minutes

logger = logging.getLogger(__name__)


class AthenzClient:
    """Issue Athenz access tokens backed by a self-renewing service certificate.

    At construction the client validates ``provider_domain``, checks the
    private key, prepares a working certificate under ``/tmp/athenz`` (copied
    from ``cert_path`` when that file exists, otherwise generated with
    ``zts-svccert``) and builds a ``ZTSClient`` from it. Every
    ``get_access_token`` call re-checks the certificate at most once per
    ``CERT_CHECK_INTERVAL_SECONDS`` and rebuilds the ``ZTSClient`` after a
    renewal.
    """

    # Directory holding the working ("latest") certificate copy.
    _ATHENZ_DIR = "/tmp/athenz"
    # Upper bounds for external commands. These run while ``_cert_lock`` is
    # held during periodic checks, so a hung process would otherwise block
    # every token request waiting on that lock indefinitely.
    # NOTE(review): values chosen conservatively; tune if needed.
    _OPENSSL_TIMEOUT_SECONDS = 30
    _RENEW_TIMEOUT_SECONDS = 120

    def __init__(
        self,
        account_name,
        provider_domain,
        key_path: Optional[str] = None,
        cert_path: Optional[str] = None,
        use_proxy: bool = True,
    ):
        """
        Initialize Athenz client with flexible path and proxy configuration.

        Args:
            account_name: Account name for authentication. It is also used as
                the Athenz service name and key version when renewing, and as
                the key/cert file basename when paths are derived.
            provider_domain: Provider domain in the form
                ``[flava-]{product-abbr}.{environment}.{flava-project}``
                (e.g., flava-faas.stage.flava-qa or servicemap.stage.flava-api-test).
            key_path: Optional custom path to private key file. Defaults to
                ``/{environment}/{project}/{account_name}.key.pem``.
            cert_path: Optional path to the origin certificate (e.g., a
                ConfigMap mount); it may not exist at startup. When
                ``key_path`` is omitted and this is None, it defaults to
                ``/{environment}/{project}/{account_name}.cert.pem``.
            use_proxy: Whether to use ``PROXY_URL`` (or the default proxy)
                for ZTS connections (default: True).

        Raises:
            ValueError: ``provider_domain`` is malformed, ``account_name`` is
                empty while ``key_path`` is omitted, or
                ``CERT_CHECK_INTERVAL_SECONDS`` is not an integer.
            FileNotFoundError, PermissionError: the private key is missing or
                unreadable.
            RuntimeError: certificate preparation or ZTSClient setup failed.
        """
        self.provider_domain = provider_domain
        self.account_name = account_name

        # Validate provider_domain format and split it into its components.
        match = re.fullmatch(SERVICE_PROVIDER_PATTERN, provider_domain)
        if not match:
            raise ValueError(
                f"Invalid provider_domain format: {provider_domain}. "
                f"Expected: [flava-]{{product-abbr}}.{{environment}}.{{flava-project}}"
            )

        self.product_abbr = match.group("product")
        self.environment = match.group("env")
        self.project_name = match.group("project")

        # Derive key_path (and, if unset, cert_path) from the secret mount.
        if key_path is None:
            if not account_name:
                raise ValueError(
                    "account_name is required when key_path is not provided"
                )
            secret_mount_path = os.path.join("/", self.environment, self.project_name)
            key_path = os.path.join(secret_mount_path, f"{account_name}.key.pem")
            if cert_path is None:
                cert_path = os.path.join(secret_mount_path, f"{account_name}.cert.pem")

        # Private key must exist and be readable at startup
        if not os.path.exists(key_path):
            raise FileNotFoundError(f"Private key not found: {key_path}")
        if not os.access(key_path, os.R_OK):
            raise PermissionError(f"No read permission for private key: {key_path}")

        self.key_path = key_path
        # Origin cert path (e.g., ConfigMap). It may not exist at startup.
        # _check_and_renew() replaces it with the /tmp/athenz working path.
        self.cert_path = cert_path

        # Proxy & ZTS configuration
        self.proxy_url = (
            os.getenv("PROXY_URL", DEFAULT_PROXY_URL) if use_proxy else None
        )
        self._init_zts_config()

        # Configure certificate check/renewal interval
        self._cert_check_interval = int(
            os.getenv(
                "CERT_CHECK_INTERVAL_SECONDS",
                str(DEFAULT_CERT_CHECK_INTERVAL_SECONDS),
            )
        )
        self._last_cert_check_ts = 0.0
        self._cert_lock = threading.Lock()
        # True while a renewed certificate has not yet been loaded into a
        # fresh ZTSClient (i.e. a rebuild failed and must be retried).
        self._zts_client_stale = False

        # 1) Prepare certificate once at pod startup
        renewed = self._check_and_renew()
        self._last_cert_check_ts = time.time()
        logger.debug("Initial certificate preparation completed (renewed=%s)", renewed)

        # 2) Initialize ZTSClient with the prepared certificate
        self._init_zts_client()

    # ------------------------------------------------------------------
    # ZTS configuration & client initialization
    # ------------------------------------------------------------------
    def _init_zts_config(self) -> None:
        """Set ``zts_host``, ``zts_port`` and ``zts_endpoint`` from the environment.

        ``ZTS_ENDPOINT`` takes precedence: host and port are parsed from it
        and the endpoint is used verbatim. Otherwise ``ZTS_HOST``/``ZTS_PORT``
        (or the defaults) are used and the endpoint is built from them.
        """
        zts_endpoint_env = os.getenv("ZTS_ENDPOINT")
        if zts_endpoint_env:
            parsed = urlparse(zts_endpoint_env)
            # NOTE(review): an endpoint without an explicit port falls back to
            # DEFAULT_ZTS_PORT rather than the scheme default (443); confirm
            # this is intended, since zts-svccert receives the raw endpoint.
            self.zts_host = parsed.hostname or DEFAULT_ZTS_HOST
            self.zts_port = parsed.port or DEFAULT_ZTS_PORT
            self.zts_endpoint = zts_endpoint_env
            return

        zts_port_env = os.getenv("ZTS_PORT")
        self.zts_host = os.getenv("ZTS_HOST") or DEFAULT_ZTS_HOST
        self.zts_port = int(zts_port_env) if zts_port_env else DEFAULT_ZTS_PORT
        if self.zts_host == DEFAULT_ZTS_HOST and self.zts_port == DEFAULT_ZTS_PORT:
            self.zts_endpoint = DEFAULT_ZTS_ENDPOINT
        else:
            self.zts_endpoint = f"https://{self.zts_host}:{self.zts_port}/zts/v1"

    def _init_zts_client(self) -> None:
        """(Re)build ``self.zts_client`` from the current cert/key and ZTS config.

        ``self.zts_client`` is only replaced when construction succeeds.

        Raises:
            RuntimeError: ``ZTSClient.with_cert`` failed (original error chained).
        """
        try:
            zts_client_kwargs = dict(
                host=self.zts_host,
                port=self.zts_port,
                # 3 hours; interpretation is up to ZTSClient (yj.athenz).
                token_refresh_offset=3 * 60 * 60,
            )
            if self.proxy_url:
                zts_client_kwargs["proxy_url"] = self.proxy_url

            self.zts_client = ZTSClient.with_cert(
                self.cert_path,
                self.key_path,
                **zts_client_kwargs,
            )
        except Exception as e:
            logger.error(
                "Failed to initialize ZTSClient "
                "(cert_path=%s, key_path=%s, proxy_url=%s, host=%s, port=%s): %s",
                self.cert_path,
                self.key_path,
                self.proxy_url,
                self.zts_host,
                self.zts_port,
                e,
            )
            raise RuntimeError(
                "Failed to initialize ZTSClient "
                f"(cert_path={self.cert_path}, key_path={self.key_path}, "
                f"proxy_url={self.proxy_url}, host={self.zts_host}, port={self.zts_port}): {e}"
            ) from e

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_access_token(self):
        """Return an access token string for the configured provider domain.

        Runs the throttled certificate check first; exceptions from that
        check propagate unchanged.

        Raises:
            RuntimeError: the ZTS token request failed (original error chained).
        """
        # Periodically check and renew certificate when issuing tokens
        self._ensure_cert_ready()

        logger.debug(
            "[AthenzClient] get_access_token() called for provider_domain=%s",
            self.provider_domain,
        )
        try:
            access_token = self.zts_client.get_access_token(self.provider_domain)
            logger.debug(
                "[AthenzClient] access_token fetched successfully for provider_domain=%s",
                self.provider_domain,
            )
            return access_token.access_token
        except Exception as e:
            logger.error(
                "[AthenzClient] Failed to get access token for provider_domain=%s: %s",
                self.provider_domain,
                e,
            )
            raise RuntimeError(
                f"Failed to get access token for provider_domain={self.provider_domain}: {e}"
            ) from e

    # ------------------------------------------------------------------
    # Certificate check/renewal logic
    # ------------------------------------------------------------------
    def _ensure_cert_ready(self) -> None:
        """
        Periodically check and renew the certificate when access tokens are requested.

        This method is called on each access token request but throttles the
        actual check/renewal using CERT_CHECK_INTERVAL_SECONDS. The ZTSClient
        is rebuilt after a renewal; if that rebuild fails it is retried on
        later periodic checks until it succeeds.
        """
        now = time.time()
        # Fast path (no lock): skip if the check interval has not elapsed.
        if now - self._last_cert_check_ts < self._cert_check_interval:
            return

        with self._cert_lock:
            # Double-check after acquiring the lock (another thread may have updated it)
            now = time.time()
            if now - self._last_cert_check_ts < self._cert_check_interval:
                return

            logger.debug(
                "Running periodic certificate check (interval=%s, last_check=%s)",
                self._cert_check_interval,
                self._last_cert_check_ts,
            )

            # If this raises, the timestamp is not advanced, so the next
            # request retries the full check.
            renewed = self._check_and_renew()
            self._last_cert_check_ts = now

            if renewed:
                logger.info("Certificate renewed/generated. Reinitializing ZTSClient.")
                # Mark stale before rebuilding so a failed rebuild is retried
                # on the next check instead of being forgotten.
                self._zts_client_stale = True
            elif self._zts_client_stale:
                logger.info("Retrying ZTSClient reinitialization after a previous failure.")

            if self._zts_client_stale:
                self._init_zts_client()
                self._zts_client_stale = False

    def _check_and_renew(self) -> bool:
        """
        Make sure a usable certificate exists at the latest path and renew it if needed.

        Order of preference: reuse the latest copy in /tmp/athenz, otherwise
        copy the origin certificate, otherwise generate a new one with
        zts-svccert. On success ``self.cert_path`` points at the latest path
        and the expiry metric is updated.

        NOTE(review): after the first run ``self.cert_path`` already equals
        the latest path, so later runs treat that path as the "origin" too.

        Returns:
            True if certificate was newly generated or renewed, False otherwise.

        Raises:
            Any failure is logged and re-raised (PermissionError,
            FileNotFoundError, RuntimeError, OSError, ...).
        """
        try:
            logger.info("Starting certificate check and renewal process")

            athenz_dir = self._ATHENZ_DIR
            logger.debug("Ensuring athenz directory exists: %s", athenz_dir)
            os.makedirs(athenz_dir, exist_ok=True)
            os.chmod(athenz_dir, 0o700)  # owner-only access to the working dir

            latest_cert_path = self._get_latest_cert_path()
            origin_cert_path = self.cert_path  # May be None at startup

            logger.debug("Latest cert path: %s", latest_cert_path)
            logger.debug("Origin cert path: %s", origin_cert_path)

            # Validate origin cert if provided
            origin_available = origin_cert_path is not None and os.path.exists(
                origin_cert_path
            )
            if origin_available and not os.access(origin_cert_path, os.R_OK):
                raise PermissionError(
                    f"No read permission for origin certificate: {origin_cert_path}"
                )

            # Private key must exist and be readable
            if not os.path.exists(self.key_path):
                raise FileNotFoundError(f"Private key not found at {self.key_path}")
            if not os.access(self.key_path, os.R_OK):
                raise PermissionError(
                    f"No read permission for private key: {self.key_path}"
                )

            if os.path.exists(latest_cert_path):
                logger.debug("Latest certificate exists, checking expiry...")
                self.cert_path = latest_cert_path
                renewed = self._renew_if_expiring(
                    "Certificate needs renewal, attempting renewal...",
                    "Latest certificate is still valid",
                )
            elif origin_available:
                logger.info(
                    "Latest certificate not found, copying from origin: %s -> %s",
                    origin_cert_path,
                    latest_cert_path,
                )
                self._copy_origin_cert(origin_cert_path, latest_cert_path)
                self.cert_path = latest_cert_path
                renewed = self._renew_if_expiring(
                    "Copied certificate needs renewal, attempting renewal...",
                    "Copied certificate is still valid",
                )
            else:
                logger.info(
                    "No existing certificate found, generating new certificate..."
                )
                self.cert_path = latest_cert_path
                self._renew_and_validate(
                    "Certificate validation failed after generation"
                )
                renewed = True

            self.cert_path = latest_cert_path
            logger.info("Certificate check and renewal process completed successfully")

            # Update certificate expiry metric after every check
            self._update_cert_expiry_metric()

            return renewed

        except Exception as e:
            logger.error("Certificate check and renewal failed: %s", e)
            # Propagate the failure so that initialization/request can fail fast
            raise

    def _renew_if_expiring(self, renewal_message: str, valid_message: str) -> bool:
        """Renew ``self.cert_path`` if it is close to expiry; return True if renewed."""
        if not self._needs_renewal():
            logger.debug(valid_message)
            return False
        logger.info(renewal_message)
        self._renew_and_validate("Certificate validation failed after renewal")
        return True

    def _renew_and_validate(self, failure_message: str) -> None:
        """Renew ``self.cert_path`` with zts-svccert and check the result parses.

        Raises:
            RuntimeError: renewal failed, or the written file is not a valid
                X.509 certificate (``failure_message`` is the error text).
        """
        self._renew_cert()
        if not self._validate_cert():
            raise RuntimeError(failure_message)

    @staticmethod
    def _copy_origin_cert(origin_cert_path: str, latest_cert_path: str) -> None:
        """Copy the origin certificate to the latest path with mode 0600.

        Raises:
            RuntimeError: the copy or chmod hit a PermissionError.
        """
        try:
            shutil.copy2(origin_cert_path, latest_cert_path)
            os.chmod(latest_cert_path, 0o600)
            logger.debug("Certificate copied and permissions set")
        except PermissionError as e:
            raise RuntimeError(
                f"Failed to copy certificate due to permissions: {e}"
            ) from e

    def _update_cert_expiry_metric(self) -> None:
        """
        Publish the seconds until ``self.cert_path`` expires (0 on any failure).
        """
        labels = (
            self.account_name or "unknown",
            self.environment or "unknown",
            self.project_name or "unknown",
            self.product_abbr or "unknown",  # flava_product
        )

        if not self.cert_path or not os.path.exists(self.cert_path):
            logger.debug("Certificate does not exist, setting expiry metric to 0")
            set_cert_expiry_metric(*labels, 0)
            return

        try:
            cmd = ["openssl", "x509", "-in", self.cert_path, "-noout", "-enddate"]
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self._OPENSSL_TIMEOUT_SECONDS,
            )
            if result.returncode == 0:
                remaining_seconds = self._seconds_until_not_after(result.stdout)
            else:
                logger.warning(
                    "openssl -enddate failed (rc=%s): %s",
                    result.returncode,
                    result.stderr.strip(),
                )
                remaining_seconds = 0

            set_cert_expiry_metric(*labels, remaining_seconds)
            logger.debug(
                "Updated cert expiry metric: %s seconds remaining", remaining_seconds
            )

        except Exception as e:
            logger.error("Failed to update cert expiry metric: %s", e)
            set_cert_expiry_metric(*labels, 0)

    @staticmethod
    def _seconds_until_not_after(
        enddate_output: str, now: Optional[datetime] = None
    ) -> int:
        """Return whole seconds until the ``notAfter=`` date in openssl output.

        Args:
            enddate_output: stdout of ``openssl x509 -noout -enddate``, e.g.
                ``"notAfter=Jul  3 07:52:54 2024 GMT"``.
            now: timezone-aware reference time; defaults to the current UTC time.

        Returns:
            Seconds remaining, clamped to 0 when already expired, or 0 when the
            output is not a ``notAfter=`` line or cannot be parsed.
        """
        line = enddate_output.strip()
        prefix = "notAfter="
        if not line.startswith(prefix):
            return 0
        not_after = line[len(prefix) :].strip()
        try:
            expires_at = datetime.strptime(
                not_after, "%b %d %H:%M:%S %Y GMT"
            ).replace(tzinfo=timezone.utc)
        except ValueError as e:
            logger.warning(
                "Failed to parse cert notAfter date: %s, %s",
                not_after,
                e,
            )
            return 0
        if now is None:
            # Aware "now"; datetime.utcnow() is deprecated since Python 3.12.
            now = datetime.now(timezone.utc)
        return max(0, int((expires_at - now).total_seconds()))

    @staticmethod
    def _renewal_threshold_hours() -> int:
        """Read CERT_RENEWAL_THRESHOLD_HOURS, falling back to the default if invalid."""
        raw = os.getenv(
            "CERT_RENEWAL_THRESHOLD_HOURS",
            str(DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS),
        )
        try:
            return int(raw)
        except ValueError:
            # Without this fallback an invalid value forced a renewal on
            # every periodic check.
            logger.warning(
                "Invalid CERT_RENEWAL_THRESHOLD_HOURS=%r; using default %s hours",
                raw,
                DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS,
            )
            return DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS

    def _needs_renewal(self) -> bool:
        """
        Check if certificate needs renewal based on expiry time.

        Uses ``openssl x509 -checkend`` with CERT_RENEWAL_THRESHOLD_HOURS; a
        non-zero exit status means renewal is needed. Returns True when the
        certificate is missing or openssl cannot be run.
        """
        if not self.cert_path or not os.path.exists(self.cert_path):
            logger.info("Certificate does not exist, renewal required")
            return True

        threshold_hours = self._renewal_threshold_hours()
        threshold_seconds = threshold_hours * 3600

        cmd = [
            "openssl",
            "x509",
            "-in",
            self.cert_path,
            "-checkend",
            str(threshold_seconds),
            "-noout",
        ]
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self._OPENSSL_TIMEOUT_SECONDS,
            )
        except Exception as e:
            logger.error("Failed to check certificate expiry: %s", e)
            # On failure, conservatively request renewal
            return True

        needs_renewal = result.returncode != 0
        if needs_renewal:
            logger.info("Certificate expires within %s hours", threshold_hours)
        return needs_renewal

    def _renew_cert(self) -> None:
        """Write a new certificate to ``self.cert_path`` using zts-svccert.

        Raises:
            RuntimeError: zts-svccert is not on PATH, exits non-zero, times
                out, or cannot be started.
        """
        if not shutil.which("zts-svccert"):
            raise RuntimeError(
                "zts-svccert command not found. Make sure athenz_tools_jp is installed."
            )

        domain = f"flava-iam.{self.environment}.{self.project_name}"
        service = self.account_name
        key_version = self.account_name
        pod_uid = os.getenv("POD_UID", socket.gethostname())

        logger.info(
            "Renewing certificate for domain=%s, service=%s, pod_uid=%s",
            domain,
            service,
            pod_uid,
        )

        cmd = [
            "zts-svccert",
            "-zts",
            self.zts_endpoint,
            "-domain",
            domain,
            "-service",
            service,
            "-private-key",
            self.key_path,
            "-key-version",
            key_version,
            "-cert-file",
            self.cert_path,
            "-hdr",
            "Yahoo-Principal-Auth",
            "-dns-domain",
            "zts.athenz.cloud",
            "-provider",
            "sys.auth.zts",
            "-instance",
            pod_uid,
        ]

        logger.debug("Executing: %s", " ".join(cmd))
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=True,
                timeout=self._RENEW_TIMEOUT_SECONDS,
            )
            logger.info("Certificate renewal completed successfully")
            logger.debug("zts-svccert output: %s", result.stdout)
        except subprocess.CalledProcessError as e:
            logger.error("zts-svccert failed: %s", e)
            logger.error("Error output: %s", e.stderr)
            raise RuntimeError(f"Certificate renewal failed: {e}") from e
        except Exception as e:
            # Includes subprocess.TimeoutExpired and OSError from exec.
            logger.error("Certificate renewal error: %s", e)
            raise RuntimeError(f"Certificate renewal failed: {e}") from e

    def _validate_cert(self) -> bool:
        """Return True if ``self.cert_path`` exists and openssl parses it as X.509."""
        if not self.cert_path or not os.path.exists(self.cert_path):
            logger.error("Certificate file not found at %s", self.cert_path)
            return False

        try:
            cmd = ["openssl", "x509", "-in", self.cert_path, "-noout", "-text"]
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=True,
                timeout=self._OPENSSL_TIMEOUT_SECONDS,
            )

            if "Certificate:" in result.stdout:
                logger.debug("Certificate validation successful")
                return True

            logger.error("Certificate validation failed - invalid format")
            return False

        except Exception as e:
            logger.error("Certificate validation error: %s", e)
            return False

    def _get_latest_cert_path(self) -> str:
        """Calculate latest certificate path: /tmp/athenz/{project}.{env}.{account}.cert.pem."""
        filename = (
            f"{self.project_name}.{self.environment}.{self.account_name}.cert.pem"
        )
        latest_cert_path = os.path.join(self._ATHENZ_DIR, filename)
        logger.debug("Generated latest cert path: %s", latest_cert_path)
        return latest_cert_path
