import logging
import os
import re
import shutil
import socket
import subprocess
from datetime import datetime
from typing import Optional

from yj.athenz.client import ZTSClient

# https://wiki.workers-hub.com/pages/viewpage.action?pageId=542487541
# flava-{product-abbr}.{environment}.{flava-project}
SERVICE_PROVIDER_PATTERN = r"flava-[\w-]+\.(\w+)\.([\w-]+)"
DEFAULT_PROXY_URL = "http://lypwg01-001.yahoo-net.jp:8080"
DEFAULT_ZTS_HOST = "apj.zts.athenz.yahoo.co.jp"
DEFAULT_ZTS_PORT = 4443
DEFAULT_ZTS_ENDPOINT = "https://apj.zts.athenz.yahoo.co.jp:4443/zts/v1"
DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS = 24

# Directory that holds the locally renewed ("latest") certificates.
# NOTE(review): this is a predictable location under /tmp and the directory is
# chmod'ed to 0o755 below -- confirm that is acceptable for the deployment.
_LATEST_CERT_DIR = "/tmp/athenz"

# Setup logging
logger = logging.getLogger(__name__)


class AthenzClient:
    """Obtain Athenz access tokens via ZTS using an X.509 service certificate.

    On construction the client locates the private key and certificate,
    tries to make sure an up-to-date certificate exists under
    ``/tmp/athenz`` (copying the mounted one and/or renewing it with the
    ``zts-svccert`` command), and then builds a ``ZTSClient`` from the
    resulting certificate and the original private key.
    """

    def __init__(
        self,
        account_name: Optional[str],
        provider_domain: str,
        key_path: Optional[str] = None,
        cert_path: Optional[str] = None,
        use_proxy: bool = True,
    ) -> None:
        """
        Initialize Athenz client with flexible path and proxy configuration.

        Args:
            account_name: Account name for authentication
            provider_domain: Provider domain (e.g., flava-faas.stage.flava-qa)
            key_path: Optional custom path to private key file
            cert_path: Optional custom path to certificate file
            use_proxy: Whether to use proxy for connections (default: True)

        Raises:
            ValueError: If a path has to be derived but ``account_name`` is
                None or ``provider_domain`` does not match
                ``SERVICE_PROVIDER_PATTERN``.
            RuntimeError: If the ``ZTSClient`` cannot be created.
        """
        self.provider_domain = provider_domain
        self.account_name = account_name

        # Extract environment and project from provider_domain once
        match = re.search(SERVICE_PROVIDER_PATTERN, provider_domain)
        if match:
            self.environment = match.group(1)
            self.project_name = match.group(2)
        else:
            self.environment = None
            self.project_name = None

        # If paths are not provided, derive them from provider domain
        if key_path is None or cert_path is None:
            if account_name is None:
                raise ValueError(
                    "account_name is required when key_path or cert_path are not provided"
                )
            if not (self.environment and self.project_name):
                raise ValueError(
                    "Provider domain does not match the expected pattern and no custom paths provided"
                )
            secret_mount_path = f"/{self.environment}/{self.project_name}/"
            if key_path is None:
                key_path = os.path.join(secret_mount_path, f"{account_name}.pem")
            if cert_path is None:
                cert_path = os.path.join(secret_mount_path, f"{account_name}.crt")

        # Set proxy configuration (an empty PROXY_URL disables the proxy)
        proxy_url = os.getenv("PROXY_URL", DEFAULT_PROXY_URL) if use_proxy else None

        # Set ZTS configuration
        zts_host = os.getenv("ZTS_HOST", DEFAULT_ZTS_HOST)
        zts_port = DEFAULT_ZTS_PORT

        # Store configuration for certificate renewal
        self.cert_path = cert_path
        self.key_path = key_path

        # Check and renew certificate if needed (after paths are determined).
        # On success self.cert_path points at the renewed certificate; on
        # failure it is restored to the path stored just above.
        self._check_and_renew()

        client_kwargs = {
            "host": zts_host,
            "port": zts_port,
            "token_refresh_offset": 3 * 60 * 60,
        }
        if proxy_url:
            client_kwargs["proxy_url"] = proxy_url

        try:
            # Use self.cert_path (not the local cert_path) so that a
            # certificate renewed by _check_and_renew() is actually used.
            self.zts_client = ZTSClient.with_cert(
                self.cert_path, self.key_path, **client_kwargs
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize ZTSClient with cert_path={self.cert_path}, "
                f"key_path={self.key_path}, proxy_url={proxy_url}, host={zts_host}, "
                f"port={zts_port}: {e}"
            ) from e

    def get_access_token(self):
        """Return the access token string for ``self.provider_domain``.

        Raises:
            RuntimeError: If the underlying ZTS client call fails.
        """
        try:
            access_token = self.zts_client.get_access_token(self.provider_domain)
            return access_token.access_token
        except Exception as e:
            raise RuntimeError(
                f"Failed to get access token for provider_domain={self.provider_domain}: {e}"
            ) from e

    def _check_and_renew(self) -> None:
        """Check certificate expiry and renew if needed.

        Flow: check latest cert -> copy from the original (ConfigMap) path ->
        generate a new cert. On success ``self.cert_path`` points at the
        latest certificate. Any failure is logged and NOT raised; in that
        case ``self.cert_path`` is restored to the original path so the
        caller falls back to the certificate it was configured with.
        ``self.key_path`` is never changed.
        """
        # Original cert_path from __init__ (ConfigMap); kept for fallback.
        origin_cert_path = self.cert_path
        try:
            logger.info("Starting certificate check and renewal process")

            # Ensure athenz directory exists (use /tmp for better permissions)
            athenz_dir = _LATEST_CERT_DIR
            logger.debug("Creating athenz directory: %s", athenz_dir)
            os.makedirs(athenz_dir, exist_ok=True)
            os.chmod(athenz_dir, 0o755)
            logger.debug("Athenz directory created and permissions set")

            # Calculate latest certificate path
            latest_cert_path = self._get_latest_cert_path()
            logger.debug("Latest cert path: %s", latest_cert_path)
            logger.debug("Origin cert path: %s", origin_cert_path)

            # Check permissions on origin path early
            origin_exists = os.path.exists(origin_cert_path)
            if origin_exists and not os.access(origin_cert_path, os.R_OK):
                logger.error(
                    "Cannot read origin certificate file: %s", origin_cert_path
                )
                raise RuntimeError(
                    f"No read permission for origin certificate: {origin_cert_path}"
                )

            if os.path.exists(latest_cert_path):
                logger.debug("Latest certificate exists, checking expiry...")
                # The helpers below operate on self.cert_path.
                self.cert_path = latest_cert_path
                if self._needs_renewal():
                    logger.info("Certificate needs renewal, attempting renewal...")
                    self._renew_and_validate("renewal")
                    logger.info("Certificate renewal completed successfully")
                else:
                    logger.debug("Latest certificate is still valid")
            elif origin_exists:
                logger.info(
                    "Latest certificate not found, copying from ConfigMap: %s -> %s",
                    origin_cert_path,
                    latest_cert_path,
                )
                try:
                    shutil.copy2(origin_cert_path, latest_cert_path)
                    os.chmod(latest_cert_path, 0o600)
                    logger.debug("Certificate copied and permissions set successfully")
                except PermissionError as e:
                    logger.error("Permission denied while copying certificate: %s", e)
                    raise RuntimeError(
                        f"Failed to copy certificate due to permissions: {e}"
                    ) from e

                self.cert_path = latest_cert_path
                # Check if copied certificate needs renewal
                if self._needs_renewal():
                    logger.info(
                        "Copied certificate needs renewal, attempting renewal..."
                    )
                    self._renew_and_validate("renewal")
                    logger.info("Certificate renewal completed successfully")
                else:
                    logger.debug("Copied certificate is still valid")
            else:
                logger.info(
                    "No existing certificate found, generating new certificate..."
                )
                self.cert_path = latest_cert_path
                self._renew_and_validate("generation")
                logger.info("New certificate generated successfully")

            # Final update: cert_path points to latest location,
            # key_path remains unchanged (ConfigMap).
            self.cert_path = latest_cert_path
        except Exception as e:
            # Deliberately not re-raised: fall back to the certificate the
            # client was configured with instead of a half-renewed one.
            self.cert_path = origin_cert_path
            logger.error("Certificate check and renewal failed: %s", e)

    def _renew_and_validate(self, action: str) -> None:
        """Renew the certificate at ``self.cert_path`` and validate the result.

        Args:
            action: Word used in the error message ("renewal"/"generation").

        Raises:
            RuntimeError: If renewal fails or the result does not validate.
        """
        self._renew_cert()
        if not self._validate_cert():
            raise RuntimeError(f"Certificate validation failed after {action}")

    def _needs_renewal(self) -> bool:
        """Check if certificate needs renewal based on expiry time.

        Runs ``openssl x509 -checkend`` on ``self.cert_path`` with the
        threshold from ``CERT_RENEWAL_THRESHOLD_HOURS`` (default
        ``DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS``).

        Returns:
            True when openssl exits non-zero or the check itself fails,
            False when openssl exits with 0.
        """
        raw_threshold = os.getenv("CERT_RENEWAL_THRESHOLD_HOURS")
        try:
            threshold_hours = (
                DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS
                if raw_threshold is None
                else int(raw_threshold)
            )
        except ValueError:
            # A malformed setting should not force a renewal on every start.
            logger.warning(
                "Invalid CERT_RENEWAL_THRESHOLD_HOURS=%r, using default %s",
                raw_threshold,
                DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS,
            )
            threshold_hours = DEFAULT_CERT_RENEWAL_THRESHOLD_HOURS

        try:
            threshold_seconds = threshold_hours * 3600
            cmd = [
                "openssl", "x509",
                "-in", self.cert_path,
                "-checkend", str(threshold_seconds),
                "-noout",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)

            # Return code 0 means certificate is still valid
            needs_renewal = result.returncode != 0
            if needs_renewal:
                logger.info("Certificate expires within %s hours", threshold_hours)
            return needs_renewal
        except Exception as e:
            logger.error("Failed to check certificate expiry: %s", e)
            return True  # Assume renewal needed if check fails

    def _renew_cert(self) -> None:
        """Renew certificate using zts-svccert.

        Writes the new certificate to ``self.cert_path`` using the private
        key at ``self.key_path``.

        Raises:
            RuntimeError: If ``zts-svccert`` is not on PATH, the domain or
                service cannot be derived, or the command fails.
        """
        try:
            # Check if zts-svccert is available
            if not shutil.which("zts-svccert"):
                raise RuntimeError(
                    "zts-svccert command not found. Make sure athenz_tools_jp is installed."
                )

            # Without these the command line would contain None values.
            if not (self.environment and self.project_name and self.account_name):
                raise RuntimeError(
                    "Cannot derive Athenz domain/service from "
                    f"provider_domain={self.provider_domain}, "
                    f"account_name={self.account_name}"
                )

            domain = f"flava-iam.{self.environment}.{self.project_name}"

            # Get configuration
            zts_endpoint = os.getenv("ZTS_ENDPOINT", DEFAULT_ZTS_ENDPOINT)

            # Use account_name for both service and key_version
            service = self.account_name
            key_version = self.account_name

            # Get POD_UID from environment variable with hostname as fallback
            pod_uid = os.getenv("POD_UID", socket.gethostname())

            logger.info(
                "Renewing certificate for domain=%s, service=%s, pod_uid=%s",
                domain,
                service,
                pod_uid,
            )

            cmd = [
                "zts-svccert",
                "-zts", zts_endpoint,
                "-domain", domain,
                "-service", service,
                "-private-key", self.key_path,
                "-key-version", key_version,
                "-cert-file", self.cert_path,
                "-hdr", "Yahoo-Principal-Auth",
                "-dns-domain", "athenz.yahoo.co.jp",
                "-provider", "sys.auth.zts",
                "-instance", pod_uid,
            ]

            logger.debug("Executing: %s", " ".join(cmd))
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)

            logger.info("Certificate renewal completed successfully")
            logger.debug("zts-svccert output: %s", result.stdout)
        except subprocess.CalledProcessError as e:
            logger.error("zts-svccert failed: %s", e)
            logger.error("Error output: %s", e.stderr)
            raise RuntimeError(f"Certificate renewal failed: {e}") from e
        except Exception as e:
            logger.error("Certificate renewal error: %s", e)
            raise RuntimeError(f"Certificate renewal failed: {e}") from e

    def _validate_cert(self) -> bool:
        """Validate that certificate is properly formatted and readable.

        Returns:
            True if ``self.cert_path`` exists and ``openssl x509 -text``
            prints a "Certificate:" section for it, False otherwise
            (including when the check itself raises).
        """
        try:
            if not os.path.exists(self.cert_path):
                logger.error("Certificate file not found after renewal")
                return False

            cmd = ["openssl", "x509", "-in", self.cert_path, "-noout", "-text"]
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)

            if "Certificate:" in result.stdout:
                logger.debug("Certificate validation successful")
                return True

            logger.error("Certificate validation failed - invalid format")
            return False
        except Exception as e:
            logger.error("Certificate validation error: %s", e)
            return False

    def _get_latest_cert_path(self) -> str:
        """Calculate latest certificate path.

        Returns:
            ``/tmp/athenz/{project}.{env}.{account_name}.cert.pem`` when the
            provider domain was parsed in ``__init__``; otherwise the
            fallback ``/tmp/athenz/{account_name}.cert.pem``.
        """
        # Use pre-extracted environment and project from __init__
        if self.environment and self.project_name:
            latest_cert_filename = (
                f"{self.project_name}.{self.environment}.{self.account_name}.cert.pem"
            )
            latest_cert_path = os.path.join(_LATEST_CERT_DIR, latest_cert_filename)
            logger.debug("Generated latest cert path: %s", latest_cert_path)
            return latest_cert_path

        logger.error(
            "Failed to generate latest certificate path: "
            "Could not parse provider_domain pattern: %s",
            self.provider_domain,
        )
        # Fallback to simple naming
        fallback_path = os.path.join(
            _LATEST_CERT_DIR, f"{self.account_name}.cert.pem"
        )
        logger.warning("Using fallback cert path: %s", fallback_path)
        return fallback_path