Overview

This DAG automatically refreshes Docker ECR (Elastic Container Registry) authentication tokens in Apache Airflow. ECR tokens expire every 12 hours, so this DAG runs twice daily to ensure continuous access to your Docker registry.

Previously you could update the metadata database directly through a Session object. Airflow 3 forbids that direct access, so the only supported approach is to create an API user and update the connection through the REST API.

What This DAG Does

The DAG performs three main tasks:

  1. Extract ECR Token: Uses AWS boto3 to get a fresh authorization token from ECR
  2. Update Docker Connection: Updates Airflow’s docker_default connection with the new token using JWT authentication
  3. Test Connection: Validates that the updated connection works properly
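
The authorization token ECR returns in step 1 is a base64-encoded user:password pair (the username is always AWS). Decoding it, using a fabricated token for illustration, is straightforward:

```python
import base64

# Fabricated token for illustration only; a real token comes from
# ecr_client.get_authorization_token().
raw_token = base64.b64encode(b"AWS:example-ecr-password").decode()

# Decode and split into the username/password pair Docker login expects.
decoded = base64.b64decode(raw_token).decode()
username, password = decoded.split(":", 1)

print(username)  # AWS
```

The split uses a maxsplit of 1 so a password containing ":" characters survives intact.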

Prerequisites

AWS Setup

  • AWS account with ECR access
  • IAM user/role with ECR permissions:

    {
      "Version": "2012-10-17",
      "Statement": [
        {
          "Effect": "Allow",
          "Action": ["ecr:GetAuthorizationToken"],
          "Resource": "*"
        }
      ]
    }

Airflow Environment

  • Airflow 3.x
  • Docker provider package installed
  • AWS credentials configured (via IAM role, environment variables, or AWS credentials file)

Step-by-Step Setup

Configure Airflow Variables

Set the following Airflow Variables in the Admin UI or via CLI:

# Via Airflow CLI
airflow variables set ecr_aws_account "123456789012"  # Your AWS account ID
airflow variables set ecr_aws_region_name "us-east-1"  # Your ECR region

Or via Airflow UI:

  • Go to Admin → Variables
  • Add ecr_aws_account with your AWS account ID
  • Add ecr_aws_region_name with your ECR region

Create the Airflow API Connection

This is required because Airflow 3 no longer allows DAG code to access the metadata database through the session object; changes to connections must go through the REST API.

Create a connection for the Airflow API:

Via Airflow UI:

  1. Go to Admin → Connections
  2. Click the “+” button to add a new connection
  3. Fill in the details:
    • Connection Id: airflow-api
    • Connection Type: HTTP
    • Host: localhost (or your Airflow webserver host)
    • Schema: http (or https if using SSL)
    • Port: 8080 (or your Airflow webserver port)
    • Login: Your Airflow username
    • Password: Your Airflow password

Via CLI:

airflow connections add airflow-api \
    --conn-type http \
    --conn-host localhost \
    --conn-schema http \
    --conn-port 8080 \
    --conn-login your_username \
    --conn-password your_password
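
The DAG uses this connection to obtain a JWT from the /auth/token endpoint and then PATCHes /api/v2/connections with a bulk-update payload. A sketch of the payload it builds (the password value here is a placeholder; the real one comes from the decoded ECR token):

```python
import json

connection_id = "docker_default"
new_password = "example-ecr-password"  # placeholder; the DAG uses the ECR token

# Bulk-update payload for PATCH /api/v2/connections.
payload = {
    "actions": [
        {
            "action": "update",
            "entities": [
                {
                    "connection_id": connection_id,
                    "conn_type": "docker",
                    "password": new_password,
                }
            ],
            "action_on_non_existence": "fail",
        }
    ]
}

print(json.dumps(payload, indent=2))
```

Setting action_on_non_existence to "fail" makes the update error out loudly if docker_default was never created, rather than silently doing nothing.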

Docker Default Connection

Create/update the Docker connection that will be updated:

Via Airflow UI:

  1. Go to Admin → Connections
  2. Click the “+” button
  3. Fill in:
    • Connection Id: docker_default
    • Connection Type: Docker
    • Host: Your ECR registry URL (e.g., 123456789012.dkr.ecr.us-east-1.amazonaws.com)
    • Login: AWS (this will be updated by the DAG)
    • Password: (leave empty, will be updated by the DAG)

Via CLI:

airflow connections add docker_default \
    --conn-type docker \
    --conn-host 123456789012.dkr.ecr.us-east-1.amazonaws.com \
    --conn-login AWS
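
The ECR registry hostname follows a fixed pattern built from the account ID and region, so the --conn-host value can be derived rather than copied. A small helper, not part of the DAG, shown for illustration:

```python
def ecr_registry_host(account_id: str, region: str) -> str:
    """Build the ECR registry hostname from an AWS account ID and region."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com"

print(ecr_registry_host("123456789012", "us-east-1"))
# 123456789012.dkr.ecr.us-east-1.amazonaws.com
```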

Deploy the DAG

  1. Copy the DAG code to your Airflow DAGs folder
  2. The DAG will appear in the Airflow UI as refresh_docker_token_v4

Configure Queues and Pools

The DAG assigns its tasks to a queue and a pool, both named systemqueue. Create the pool (with the CeleryExecutor, a worker must also listen on the systemqueue queue):

Via Airflow UI:

  1. Go to Admin → Pools
  2. Create a pool named systemqueue with appropriate slots (e.g., 5)

Via CLI:

airflow pools set systemqueue 5 "System maintenance tasks"

DAG Configuration

Schedule

  • Cron: 55 5,17 * * * (runs at 5:55 AM and 5:55 PM daily)
  • Timezone: UTC (adjust as needed)
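
Because ECR tokens expire 12 hours after issuance, consecutive runs must be no more than 12 hours apart. A quick sanity check of the 5:55 / 17:55 schedule:

```python
from datetime import datetime, timedelta

# Consecutive run times produced by the cron expression "55 5,17 * * *".
runs = [
    datetime(2024, 1, 1, 5, 55),
    datetime(2024, 1, 1, 17, 55),
    datetime(2024, 1, 2, 5, 55),
]

# Each gap must not exceed the 12-hour token lifetime.
gaps = [later - earlier for earlier, later in zip(runs, runs[1:])]
assert all(gap <= timedelta(hours=12) for gap in gaps)
```

The gaps are exactly 12 hours, so each run refreshes the token right as the previous one expires; shifting the cron minutes earlier would add a safety margin.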

Key Settings

  • Max Active Runs: 1 (prevents overlapping executions)
  • Catchup: False (doesn’t backfill missed runs)
  • Retries: 2 with 1-minute delay
  • Tags: ["airflow", "docker", "ecr"]

Troubleshooting

  • Verify Docker daemon is running
  • Check that /var/run/docker.sock is accessible
  • Ensure the ECR registry URL is correct
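
For the socket check, a small helper can verify that the socket path used by the DAG's base_url exists and is a Unix socket (a sketch; /var/run/docker.sock is the Docker default and may differ on your hosts):

```python
import os
import stat


def docker_socket_available(path: str = "/var/run/docker.sock") -> bool:
    """Return True if the Docker daemon socket exists and is a Unix socket."""
    try:
        return stat.S_ISSOCK(os.stat(path).st_mode)
    except (FileNotFoundError, PermissionError):
        return False


print(docker_socket_available())
```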
"""
DAG to refresh Docker ECR authentication token
Updates the docker_default connection with fresh ECR credentials using JWT authentication

Note: do not keep a personal ~/.docker/config.json on the worker host; it can conflict with the credentials this DAG manages.

"""

import base64
import logging
from datetime import datetime, timedelta
from typing import Any, Dict

import boto3
import requests
from airflow.decorators import dag, task
from airflow.hooks.base import BaseHook
from airflow.models.variable import Variable
from airflow.providers.docker.hooks.docker import DockerHook

logger = logging.getLogger("ecr_docker_token_refresh")

ecr_aws_account = Variable.get("ecr_aws_account")
ecr_aws_region_name = Variable.get("ecr_aws_region_name")

default_args = {
    "retry_delay": timedelta(minutes=1),
    "depends_on_past": False,
    "retries": 2,
    "email_on_failure": False,
    "email_on_retry": False,
    "queue": "systemqueue",
    "pool": "systemqueue",
}

connection_id = "docker_default"
airflow_api_connection_id = "airflow-api"


def get_jwt_token(endpoint_url: str, username: str, password: str) -> str:
    """Get JWT token from Airflow API"""
    auth_url = f"{endpoint_url}/auth/token"
    payload = {"username": username, "password": password}
    headers = {"Content-Type": "application/json"}
    logger.info(f"Requesting JWT token from {auth_url}")

    response = requests.post(auth_url, json=payload, headers=headers, timeout=30)
    response.raise_for_status()

    token_data = response.json()
    access_token = token_data.get("access_token")

    if not access_token:
        raise ValueError("No access_token found in response")

    logger.info("Successfully obtained JWT token")
    return access_token


def update_connection_password_with_jwt(
    endpoint_url: str, jwt_token: str, password: str
) -> bool:
    """Update connection password using JWT token with v2 bulk API"""
    url = f"{endpoint_url}/api/v2/connections"

    # First, get the current connection to preserve other fields
    get_url = f"{endpoint_url}/api/v2/connections/{connection_id}"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jwt_token}",
    }

    logger.info(f"Getting current connection {connection_id}")

    try:
        get_response = requests.get(get_url, headers=headers, timeout=30)
        get_response.raise_for_status()
        current_connection = get_response.json()

        logger.info("Current connection retrieved successfully")

        # Prepare bulk update payload using v2 API
        payload = {
            "actions": [
                {
                    "action": "update",
                    "entities": [
                        {
                            "connection_id": connection_id,
                            "conn_type": current_connection.get("conn_type", "docker"),
                            "password": password,  # This is what we're updating
                        }
                    ],
                    "action_on_non_existence": "fail",
                }
            ]
        }

        logger.info(f"Updating connection {connection_id} at {url}")

        response = requests.patch(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()

        response_data = response.json()
        logger.info(f"Bulk update response: {response_data}")

        # Check if update was successful
        update_results = response_data.get("update", {})
        success_count = len(update_results.get("success", []))
        error_count = len(update_results.get("errors", []))

        if success_count > 0 and error_count == 0:
            logger.info("Connection password updated successfully")
            return True
        else:
            logger.error(
                f"Update failed - Success: {success_count}, Errors: {error_count}"
            )
            if error_count > 0:
                logger.error(f"Errors: {update_results.get('errors', [])}")
            return False

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to update connection: {e}")
        if hasattr(e, "response") and e.response is not None:
            logger.error(f"Response status: {e.response.status_code}")
            logger.error(f"Response text: {e.response.text}")
        raise


@dag(
    default_args=default_args,
    schedule="55 5,17 * * *",
    start_date=datetime(2024, 1, 1),
    max_active_runs=1,
    catchup=False,
    tags=["airflow", "docker", "ecr"],
    dag_id="refresh_docker_token_v4",
    description="Refresh Docker ECR token using JWT authentication",
)
def refresh_docker_token():
    @task(priority_weight=5, pool="systemqueue")
    def extract_ecr_token() -> Dict[str, Any]:
        """Extract ECR authorization token using boto3"""
        logger.info("Starting ECR token extraction")

        try:
            logger.info(f"Connecting to ECR in region {ecr_aws_region_name}")
            ecr_client = boto3.client("ecr", region_name=ecr_aws_region_name)

            logger.info(f"Requesting authorization token for account {ecr_aws_account}")
            response = ecr_client.get_authorization_token(registryIds=[ecr_aws_account])

            auth_data = response["authorizationData"][0]
            token = auth_data["authorizationToken"]
            registry_url = auth_data["proxyEndpoint"]
            expires_at = auth_data["expiresAt"]

            logger.info("Successfully retrieved token")
            logger.info(f"Registry URL: {registry_url}")
            logger.info(f"Token expires at: {expires_at}")

            decoded_token = base64.b64decode(token).decode()
            username, password = decoded_token.split(":", 1)

            logger.info(f"Decoded username: {username}")

            return {
                "registry_url": registry_url,
                "username": username,
                "password": password,
                "expires_at": expires_at.isoformat(),
                "raw_token": token,
            }

        except Exception as e:
            logger.error(f"Failed to extract ECR token: {str(e)}")
            raise

    @task(priority_weight=5, pool="systemqueue")
    def update_docker_connection(token_data: Dict[str, Any]) -> str:
        """Update Docker connection using JWT authentication"""
        logger.info("Starting Docker connection update using JWT authentication")
        logger.info("Token data received from previous task")

        try:
            # Get Airflow API connection details
            logger.info(
                f"Retrieving Airflow API connection: {airflow_api_connection_id}"
            )
            api_connection = BaseHook.get_connection(airflow_api_connection_id)

            endpoint_url = f"{api_connection.schema}://{api_connection.host}"
            if api_connection.port:
                endpoint_url += f":{api_connection.port}"

            username = api_connection.login
            password = api_connection.password

            logger.info(f"Using endpoint: {username} @ {endpoint_url}")
            jwt_token = get_jwt_token(endpoint_url, username, password)

            success = update_connection_password_with_jwt(
                endpoint_url, jwt_token, token_data["password"]
            )

            if success:
                return "SUCCESS: Docker connection updated successfully using JWT authentication"
            else:
                raise Exception("Failed to update connection")

        except Exception as e:
            logger.error(f"Failed to update Docker connection: {str(e)}")
            raise

    @task(priority_weight=3, pool="systemqueue")
    def test_docker_connection() -> str:
        """Test the updated Docker connection"""
        logger.info("Testing DockerHook...")

        try:
            # Log the connection details to aid debugging
            connection = BaseHook.get_connection("docker_default")
            logger.info(
                f"Testing connection host: {connection.host}, login: {connection.login}"
            )

            docker_hook = DockerHook(
                docker_conn_id="docker_default", base_url="unix://var/run/docker.sock"
            )
            logger.info("DockerHook created successfully")

            # Try to get docker client (this will test the connection more thoroughly)
            docker_client = docker_hook.get_conn()
            logger.info(f"Docker client connection established: {docker_client.version()}")
            return "SUCCESS: Docker connection tested and working with DockerHook"

        except Exception as client_error:
            logger.error(f"Docker client test failed: {client_error}")

            # Log connection properties on failure to aid debugging
            try:
                connection = BaseHook.get_connection("docker_default")
                logger.error(
                    f"Connection host: {connection.host}, login: {connection.login}"
                )
            except Exception as conn_error:
                logger.error(f"Could not retrieve connection properties: {conn_error}")
            return f"FAILED: Docker connection test failed: {client_error}"

    # Task flow
    token_data = extract_ecr_token()
    update_result = update_docker_connection(token_data)
    test_result = test_docker_connection()

    token_data >> update_result >> test_result


refresh_docker_token_dag = refresh_docker_token()