Ep.05 FastAPI Dependency Injection, Middleware & Logging Guide

FastAPI handles cross-cutting concerns with two tools: dependency injection for per-route logic and middleware for global request/response processing. With Depends(), you can inject reusable components such as database sessions or authentication checks directly into endpoints. For global operations such as structured JSON logging, performance monitoring, and CORS management, Starlette's BaseHTTPMiddleware (which FastAPI builds on) lets you intercept every request, for example to attach a unique request ID. Background tasks complement both by deferring slow work until after the response has been sent.

🎓 What You’ll Learn

By the end of this tutorial, you’ll understand:

  • Advanced dependency injection patterns in FastAPI
  • Creating and using custom middleware
  • Implementing structured logging with context
  • Background tasks for async processing
  • Request/response lifecycle in FastAPI
  • Performance monitoring and request tracking

📖 Understanding Dependency Injection

What is Dependency Injection?

Simple Definition: Instead of creating objects yourself, you declare what you need and FastAPI provides it.

Real-world analogy:

  • Without DI: Going to the kitchen, opening the fridge, getting the ingredients, and cooking
  • With DI: Saying “I need pasta” and someone brings you a cooked plate
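
The contrast can be sketched in plain Python, with no FastAPI involved. The names Fridge, FakeFridge, cook_without_di, and cook_with_di are hypothetical and exist only for this illustration.

```python
class Fridge:
    """Stand-in for an expensive resource such as a database connection."""

    def get(self, item: str) -> str:
        return f"fresh {item}"


class FakeFridge(Fridge):
    """Test double that can replace the real resource."""

    def get(self, item: str) -> str:
        return f"canned {item}"


def cook_without_di() -> str:
    # The function builds its own dependency, so callers cannot swap it out.
    fridge = Fridge()
    return f"pasta made with {fridge.get('tomatoes')}"


def cook_with_di(fridge: Fridge) -> str:
    # The dependency is passed in; a test can hand over a fake instead.
    return f"pasta made with {fridge.get('tomatoes')}"


print(cook_with_di(Fridge()))      # pasta made with fresh tomatoes
print(cook_with_di(FakeFridge()))  # pasta made with canned tomatoes
```

FastAPI's Depends() plays the role of the caller here: it builds the argument and passes it to your endpoint.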

Why Use Dependency Injection?

Benefit                | Example
Reusability            | Same dependency in multiple endpoints
Testability            | Easy to mock dependencies in tests
Separation of Concerns | Business logic separate from infrastructure
Code Organization      | Clear dependencies, no hidden coupling
Async Support          | Works seamlessly with async/await

🛠️ Step-by-Step Implementation

Step 1: Create Advanced Dependencies

Create app/dependencies.py:

"""
Application dependencies

Reusable dependencies for FastAPI endpoints
"""

from fastapi import Header, HTTPException, status, Query, Request
from typing import Optional, Annotated
from datetime import datetime
import time

from app.core.config import settings


# ============================================
# SIMPLE DEPENDENCIES
# ============================================

async def get_current_timestamp() -> datetime:
    """
    Get current timestamp
    
    Simple dependency that returns current time
    """
    return datetime.now()


async def verify_token(x_token: Annotated[Optional[str], Header()] = None) -> str:
    """
    Verify API token from header
    
    Args:
        x_token: Token from X-Token header
    
    Returns:
        Token if valid
    
    Raises:
        HTTPException: If token is invalid or missing
    """
    if x_token is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="X-Token header missing"
        )
    
    # In production, verify against database or JWT
    valid_tokens = ["secret-token-123", "admin-token-456"]
    
    if x_token not in valid_tokens:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token"
        )
    
    return x_token


async def verify_api_key(api_key: Annotated[Optional[str], Query()] = None) -> str:
    """
    Verify API key from query parameter
    
    Args:
        api_key: API key from query parameter
    
    Returns:
        API key if valid
    
    Raises:
        HTTPException: If API key is invalid
    """
    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required"
        )
    
    valid_keys = ["key-12345", "key-67890"]
    
    if api_key not in valid_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    
    return api_key


# ============================================
# DEPENDENCY CLASSES
# ============================================

class CommonQueryParams:
    """
    Common query parameters for list endpoints
    
    Reusable pagination and filtering parameters
    """
    
    def __init__(
        self,
        skip: int = Query(0, ge=0, description="Number of records to skip"),
        limit: int = Query(10, ge=1, le=100, description="Max records to return"),
        sort_by: Optional[str] = Query(None, description="Field to sort by"),
        sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort order")  # use regex= on FastAPI < 0.100
    ):
        self.skip = skip
        self.limit = limit
        self.sort_by = sort_by
        self.sort_order = sort_order


class RateLimiter:
    """
    Simple rate limiter dependency
    
    Limits requests per time window
    """
    
    def __init__(self, max_requests: int = 10, window_seconds: int = 60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests = {}  # {client_id: [timestamp, ...]} (in-memory, per process)
    
    async def __call__(self, request: Request):
        """
        Check rate limit for client
        
        Args:
            request: FastAPI request object
        
        Raises:
            HTTPException: If rate limit exceeded
        """
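        # Use client IP as identifier (in production, use user ID)
        # request.client can be None, e.g. behind some test or proxy setups
        client_id = request.client.host if request.client else "unknown"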
        current_time = time.time()
        
        # Clean old requests outside window
        if client_id in self.requests:
            self.requests[client_id] = [
                req_time for req_time in self.requests[client_id]
                if current_time - req_time < self.window_seconds
            ]
        else:
            self.requests[client_id] = []
        
        # Check if limit exceeded
        if len(self.requests[client_id]) >= self.max_requests:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded. Max {self.max_requests} requests per {self.window_seconds} seconds."
            )
        
        # Add current request
        self.requests[client_id].append(current_time)
        
        return True


# ============================================
# NESTED DEPENDENCIES
# ============================================

async def get_user_agent(user_agent: Annotated[Optional[str], Header()] = None) -> Optional[str]:
    """Get user agent from header"""
    return user_agent


async def get_request_id(x_request_id: Annotated[Optional[str], Header()] = None) -> str:
    """
    Get or generate request ID
    
    Used for request tracking across services
    """
    import uuid
    return x_request_id or str(uuid.uuid4())


class RequestContext:
    """
    Request context with multiple dependencies
    
    Aggregates multiple dependencies into a single object
    """
    
    def __init__(
        self,
        request: Request,
        request_id: Optional[str] = None,
        user_agent: Optional[str] = None,
        timestamp: Optional[datetime] = None
    ):
        self.request = request
        self.request_id = request_id or "unknown"
        self.user_agent = user_agent or "unknown"
        self.timestamp = timestamp or datetime.now()
        self.client_ip = request.client.host if request.client else "unknown"
        self.method = request.method
        self.url = str(request.url)


async def get_request_context(
    request: Request,
    request_id: Annotated[Optional[str], Header(alias="X-Request-ID")] = None,
    user_agent: Annotated[Optional[str], Header(alias="User-Agent")] = None
) -> RequestContext:
    """
    Build request context from the request and its headers
    
    Reads the same headers as get_request_id and get_user_agent and
    combines them into a single object
    """
    import uuid
    return RequestContext(
        request=request,
        request_id=request_id or str(uuid.uuid4()),
        user_agent=user_agent,
        timestamp=datetime.now()
    )


# ============================================
# DEPENDENCY WITH CLEANUP (Generator Pattern)
# ============================================

async def get_db_session():
    """
    Database session dependency (example)
    
    This pattern is used for resources that need cleanup
    In real app, this would create and close database sessions
    """
    # Setup: Create session
    db = {"connected": True, "session_id": "session-123"}
    print("📂 Database session created")
    
    try:
        # Yield the resource
        yield db
    finally:
        # Cleanup: Close session
        print("📂 Database session closed")
        db["connected"] = False


# ============================================
# SUB-DEPENDENCIES
# ============================================

def get_settings():
    """Get application settings"""
    from app.core.config import get_settings
    return get_settings()


async def get_admin_user(token: str = Header(alias="X-Admin-Token")):
    """
    Verify admin token
    
    Loads settings through get_settings(); the expected token is
    hard-coded here for demonstration only
    """
    settings = get_settings()
    
    # In production, compare against a secret from settings or a database
    if token != "admin-secret-token":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required"
        )
    
    return {"role": "admin", "token": token}


# ============================================
# CACHED DEPENDENCIES
# ============================================

from functools import lru_cache

@lru_cache()
def get_expensive_resource():
    """
    Expensive resource that should be cached
    
    @lru_cache ensures this is only computed once per process
    """
    print("💰 Computing expensive resource (this should only happen once)")
    return {
        "data": "expensive computation result",
        "computed_at": datetime.now()
    }
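
The caching behavior can be checked in isolation with a small sketch. The load_resource function and the call_count counter are hypothetical stand-ins for get_expensive_resource above.

```python
from functools import lru_cache

call_count = 0


@lru_cache()
def load_resource() -> dict:
    """Hypothetical expensive loader; the counter shows how often it runs."""
    global call_count
    call_count += 1
    return {"data": "expensive computation result"}


first = load_resource()
second = load_resource()
print(first is second)  # True: the same cached object is returned
print(call_count)       # 1

load_resource.cache_clear()  # e.g. between tests, or to force a reload
load_resource()
print(call_count)       # 2
```

Two caveats follow from this: the cache lives in one process, so each worker computes the value once, and every caller receives the same mutable dict, so it should be treated as read-only.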

🔍 Key Dependency Patterns Explained:

  1. Simple Function: Returns a value
  2. Class with __init__: Parameters from request
  3. Class with __call__: Stateful dependencies (rate limiter)
  4. Generator with yield: Resources needing cleanup (database)
  5. Cached with @lru_cache: Expensive computations
  6. Nested Dependencies: Dependencies that depend on other dependencies

Step 2: Create Logging Configuration

Create app/utils/logger.py:

"""
Structured logging configuration

Provides context-aware logging throughout the application
"""

import logging
import sys
from typing import Any, Dict
from datetime import datetime, timezone
import json
from pythonjsonlogger import jsonlogger


# ============================================
# CUSTOM LOG FORMATTER
# ============================================

class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """
    Custom JSON log formatter
    
    Adds standard fields to all log records
    """
    
    def add_fields(self, log_record: Dict, record: logging.LogRecord, message_dict: Dict):
        """Add custom fields to log record"""
        super().add_fields(log_record, record, message_dict)
        
        # Add timezone-aware UTC timestamp (datetime.utcnow() is deprecated since Python 3.12)
        log_record['timestamp'] = datetime.now(timezone.utc).isoformat()
        
        # Add log level
        log_record['level'] = record.levelname
        
        # Add logger name
        log_record['logger'] = record.name
        
        # Add module/function info
        log_record['module'] = record.module
        log_record['function'] = record.funcName
        log_record['line'] = record.lineno


# ============================================
# LOGGER SETUP
# ============================================

def setup_logger(name: str = "fastapi_ai_backend", level: int = logging.INFO) -> logging.Logger:
    """
    Setup structured logger
    
    Args:
        name: Logger name
        level: Logging level
    
    Returns:
        Configured logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    
    # Remove existing handlers
    logger.handlers = []
    
    # Console handler with JSON formatting
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    
    # JSON formatter for structured logs
    formatter = CustomJsonFormatter(
        '%(timestamp)s %(level)s %(name)s %(message)s'
    )
    console_handler.setFormatter(formatter)
    
    logger.addHandler(console_handler)
    
    # Prevent propagation to root logger
    logger.propagate = False
    
    return logger


# ============================================
# CONTEXT LOGGER
# ============================================

class ContextLogger:
    """
    Logger with request context
    
    Automatically includes request ID, user info, etc. in all logs
    """
    
    def __init__(self, logger: logging.Logger, context: Dict[str, Any] = None):
        self.logger = logger
        self.context = context or {}
    
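    def _log_with_context(self, level: int, message: str, **kwargs):
        """Log message with context"""
        # Merge context with kwargs; call-site values win over stored context
        log_data = {**self.context, **kwargs}
        
        # Pass the fields via `extra` so the JSON formatter emits them as
        # top-level keys instead of a JSON string nested inside "message".
        # Keys must not collide with LogRecord attributes such as "message"
        # or "module".
        self.logger.log(level, message, extra=log_data)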
    
    def debug(self, message: str, **kwargs):
        """Log debug message"""
        self._log_with_context(logging.DEBUG, message, **kwargs)
    
    def info(self, message: str, **kwargs):
        """Log info message"""
        self._log_with_context(logging.INFO, message, **kwargs)
    
    def warning(self, message: str, **kwargs):
        """Log warning message"""
        self._log_with_context(logging.WARNING, message, **kwargs)
    
    def error(self, message: str, **kwargs):
        """Log error message"""
        self._log_with_context(logging.ERROR, message, **kwargs)
    
    def critical(self, message: str, **kwargs):
        """Log critical message"""
        self._log_with_context(logging.CRITICAL, message, **kwargs)
    
    def with_context(self, **additional_context) -> 'ContextLogger':
        """Create new logger with additional context"""
        new_context = {**self.context, **additional_context}
        return ContextLogger(self.logger, new_context)


# ============================================
# GLOBAL LOGGER INSTANCE
# ============================================

# Create global logger
logger = setup_logger()

# Example usage:
# logger.info("Application started")
# 
# With context:
# context_logger = ContextLogger(logger, {"request_id": "123", "user_id": "456"})
# context_logger.info("User logged in")
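
An alternative to passing a ContextLogger around is to keep the request ID in a context variable and let a logging filter attach it to every record. This is a standard-library sketch, not part of the tutorial's logger module; request_id_var, RequestIdFilter, and the context_demo logger name are hypothetical.

```python
import contextvars
import io
import logging

# Holds the current request ID; each asyncio task sees its own value.
request_id_var = contextvars.ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
    """Copy the current request ID onto every record passing through."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.request_id = request_id_var.get()
        return True


stream = io.StringIO()  # captured here so the output is easy to inspect
handler = logging.StreamHandler(stream)
handler.setFormatter(
    logging.Formatter("%(levelname)s %(request_id)s %(message)s")
)
handler.addFilter(RequestIdFilter())

demo_logger = logging.getLogger("context_demo")
demo_logger.setLevel(logging.INFO)
demo_logger.addHandler(handler)
demo_logger.propagate = False

token = request_id_var.set("req-42")  # a middleware would do this per request
demo_logger.info("User logged in")
request_id_var.reset(token)
demo_logger.info("Outside any request")

print(stream.getvalue())
# INFO req-42 User logged in
# INFO - Outside any request
```

With this approach, code deep inside a request can call the logger normally and still get the request ID on every line.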

Install JSON logging library:

pip install python-json-logger
pip freeze > requirements.txt

Step 3: Create Custom Middleware

Create app/middleware/logging_middleware.py:

"""
Logging middleware

Logs all requests and responses with timing information
"""

import time
import uuid
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from app.utils.logger import logger


class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to log all requests and responses
    
    Adds request ID and timing information
    """
    
    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and log details
        
        Args:
            request: Incoming request
            call_next: Next middleware/endpoint
        
        Returns:
            Response from endpoint
        """
        # Generate request ID if not present
        request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
        
        # Start timer
        start_time = time.time()
        
        # Log incoming request
        logger.info(
            "Incoming request",
            extra={
                "request_id": request_id,
                "method": request.method,
                "url": str(request.url),
                "client_ip": request.client.host if request.client else None,
                "user_agent": request.headers.get("user-agent"),
            }
        )
        
        # Add request ID to request state (accessible in endpoints)
        request.state.request_id = request_id
        
        # Process request
        try:
            response = await call_next(request)
        except Exception as e:
            # Log error
            logger.error(
                "Request failed",
                extra={
                    "request_id": request_id,
                    "error": str(e),
                    "error_type": type(e).__name__
                }
            )
            raise
        
        # Calculate duration
        duration = time.time() - start_time
        
        # Log response
        logger.info(
            "Request completed",
            extra={
                "request_id": request_id,
                "status_code": response.status_code,
                "duration_ms": round(duration * 1000, 2)
            }
        )
        
        # Add headers to response
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(round(duration, 4))
        
        return response

Create app/middleware/cors_middleware.py:

"""
Custom CORS middleware

More control over CORS than the default middleware
"""

from typing import List
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response


class CustomCORSMiddleware(BaseHTTPMiddleware):
    """
    Custom CORS middleware with additional features
    """
    
    def __init__(
        self,
        app,
        allow_origins: List[str] = None,
        allow_methods: List[str] = None,
        allow_headers: List[str] = None,
        expose_headers: List[str] = None
    ):
        super().__init__(app)
        self.allow_origins = allow_origins or ["*"]
        self.allow_methods = allow_methods or ["*"]
        self.allow_headers = allow_headers or ["*"]
        self.expose_headers = expose_headers or []
    
    async def dispatch(self, request: Request, call_next):
        """Process request and add CORS headers"""
        
        # Handle preflight requests
        if request.method == "OPTIONS":
            response = Response()
        else:
            response = await call_next(request)
        
        # Add CORS headers only when the request's origin is allowed
        origin = request.headers.get("origin")
        allow_any = "*" in self.allow_origins
        
        if origin and (allow_any or origin in self.allow_origins):
            response.headers["Access-Control-Allow-Origin"] = origin
            # The response depends on the Origin header, so tell caches
            response.headers.add_vary_header("Origin")
            
            # Never combine credentials with a wildcard policy: echoing any
            # origin back with credentials enabled would let every site
            # read authenticated responses
            if not allow_any:
                response.headers["Access-Control-Allow-Credentials"] = "true"
            
            response.headers["Access-Control-Allow-Methods"] = ", ".join(self.allow_methods)
            response.headers["Access-Control-Allow-Headers"] = ", ".join(self.allow_headers)
            
            if self.expose_headers:
                response.headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
        
        return response

Create app/middleware/performance_middleware.py:

"""
Performance monitoring middleware

Tracks slow requests and performance metrics
"""

import time
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.logger import logger


class PerformanceMiddleware(BaseHTTPMiddleware):
    """
    Monitor request performance
    
    Logs warnings for slow requests
    """
    
    def __init__(self, app, slow_request_threshold: float = 1.0):
        """
        Initialize performance middleware
        
        Args:
            app: ASGI application
            slow_request_threshold: Threshold in seconds for slow request warning
        """
        super().__init__(app)
        self.slow_request_threshold = slow_request_threshold
    
    async def dispatch(self, request: Request, call_next):
        """Monitor request performance"""
        start_time = time.time()
        
        response = await call_next(request)
        
        duration = time.time() - start_time
        
        # Log slow requests
        if duration > self.slow_request_threshold:
            logger.warning(
                "Slow request detected",
                extra={
                    "method": request.method,
                    "url": str(request.url),
                    "duration_seconds": round(duration, 2),
                    "threshold_seconds": self.slow_request_threshold
                }
            )
        
        return response

Step 4: Create Background Tasks Example

Create app/utils/background_tasks.py:

"""
Background task utilities

Functions that run asynchronously after response is sent
"""

import asyncio
from typing import Dict, Any
from datetime import datetime

from app.utils.logger import logger


async def send_welcome_email(email: str, username: str):
    """
    Send welcome email (simulated)
    
    In production, this would integrate with email service
    """
    logger.info(
        "Sending welcome email",
        extra={"email": email, "username": username}
    )
    
    # Simulate email sending delay
    await asyncio.sleep(2)
    
    logger.info(
        "Welcome email sent",
        extra={"email": email, "username": username}
    )


async def process_user_analytics(user_id: int, event: str, metadata: Dict[str, Any]):
    """
    Process user analytics event
    
    This runs in background without blocking the response
    """
    logger.info(
        "Processing analytics event",
        extra={
            "user_id": user_id,
            "event": event,
            "metadata": metadata
        }
    )
    
    # Simulate analytics processing
    await asyncio.sleep(1)
    
    logger.info(
        "Analytics event processed",
        extra={"user_id": user_id, "event": event}
    )


async def cleanup_old_data(days: int = 30):
    """
    Cleanup old data (example background task)
    """
    logger.info(f"Starting data cleanup (older than {days} days)")
    
    # Simulate cleanup process
    await asyncio.sleep(3)
    
    logger.info("Data cleanup completed")


def log_user_activity(user_id: int, action: str, details: Dict[str, Any] = None):
    """
    Log user activity (synchronous background task)
    
    This can be a non-async function too
    """
    logger.info(
        "User activity logged",
        extra={
            "user_id": user_id,
            "action": action,
            "details": details or {},
            "timestamp": datetime.now().isoformat()
        }
    )

Step 5: Create Endpoints Demonstrating Dependencies

Create app/api/v1/endpoints/dependencies_demo.py:

"""
Endpoints demonstrating dependency injection patterns
"""

from fastapi import APIRouter, Depends, BackgroundTasks, Request
from typing import Annotated
from datetime import datetime

from app.dependencies import (
    get_current_timestamp,
    verify_token,
    verify_api_key,
    CommonQueryParams,
    RateLimiter,
    get_request_context,
    RequestContext,
    get_db_session,
    get_expensive_resource
)
from app.models.common import MessageResponse
from app.utils.background_tasks import (
    send_welcome_email,
    process_user_analytics,
    log_user_activity
)

router = APIRouter(prefix="/dependencies", tags=["Dependencies Demo"])


# ============================================
# SIMPLE DEPENDENCY EXAMPLES
# ============================================

@router.get("/timestamp")
async def get_timestamp(
    timestamp: Annotated[datetime, Depends(get_current_timestamp)]
):
    """
    Example: Simple dependency that returns a value
    """
    return {
        "message": "Current timestamp from dependency",
        "timestamp": timestamp
    }


@router.get("/protected")
async def protected_endpoint(
    token: Annotated[str, Depends(verify_token)]
):
    """
    Example: Authentication dependency
    
    Try with header: X-Token: secret-token-123
    """
    return {
        "message": "Access granted",
        "token": token
    }


@router.get("/api-key-protected")
async def api_key_protected(
    api_key: Annotated[str, Depends(verify_api_key)]
):
    """
    Example: API key authentication
    
    Try with query parameter: ?api_key=key-12345
    """
    return {
        "message": "API access granted",
        "api_key": api_key
    }


# ============================================
# CLASS-BASED DEPENDENCY
# ============================================

@router.get("/items")
async def list_items(
    params: Annotated[CommonQueryParams, Depends()]
):
    """
    Example: Class-based dependency for common parameters
    
    Try: ?skip=0&limit=10&sort_by=name&sort_order=desc
    """
    return {
        "message": "Listing items with parameters",
        "skip": params.skip,
        "limit": params.limit,
        "sort_by": params.sort_by,
        "sort_order": params.sort_order,
        "items": [f"item_{i}" for i in range(params.skip, params.skip + params.limit)]
    }


# ============================================
# RATE LIMITER DEPENDENCY
# ============================================

# Create rate limiter instance (5 requests per 60 seconds)
rate_limiter = RateLimiter(max_requests=5, window_seconds=60)

@router.get("/rate-limited")
async def rate_limited_endpoint(
    rate_limit_check: Annotated[bool, Depends(rate_limiter)]
):
    """
    Example: Rate-limited endpoint
    
    Try making more than 5 requests in 60 seconds
    """
    return {
        "message": "Request allowed",
        "note": "Max 5 requests per minute"
    }


# ============================================
# REQUEST CONTEXT DEPENDENCY
# ============================================

@router.get("/context")
async def get_context_info(
    context: Annotated[RequestContext, Depends(get_request_context)]
):
    """
    Example: Request context with multiple dependencies
    
    Aggregates request information
    """
    return {
        "request_id": context.request_id,
        "user_agent": context.user_agent,
        "client_ip": context.client_ip,
        "method": context.method,
        "url": context.url,
        "timestamp": context.timestamp
    }


# ============================================
# GENERATOR DEPENDENCY (WITH CLEANUP)
# ============================================

@router.get("/database")
async def use_database(
    db: Annotated[dict, Depends(get_db_session)]
):
    """
    Example: Dependency with cleanup (generator pattern)
    
    Watch the logs to see session creation and cleanup
    """
    return {
        "message": "Database operation completed",
        "session_id": db.get("session_id"),
        "connected": db.get("connected")
    }


# ============================================
# CACHED DEPENDENCY
# ============================================

@router.get("/expensive")
async def use_expensive_resource(
    resource: Annotated[dict, Depends(get_expensive_resource)]
):
    """
    Example: Cached expensive dependency
    
    First call computes, subsequent calls use cache
    Watch logs - computation message appears only once
    """
    return {
        "message": "Using expensive resource",
        "data": resource["data"],
        "computed_at": resource["computed_at"]
    }


# ============================================
# BACKGROUND TASKS
# ============================================

@router.post("/signup")
async def signup_user(
    email: str,
    username: str,
    background_tasks: BackgroundTasks
):
    """
    Example: Background task (send email after response)
    
    The email "sending" happens after the response is sent
    """
    # Add background task
    background_tasks.add_task(send_welcome_email, email, username)
    background_tasks.add_task(
        process_user_analytics,
        user_id=123,
        event="user_signup",
        metadata={"email": email, "username": username}
    )
    
    # Response is sent immediately
    return {
        "message": "Signup successful",
        "email": email,
        "username": username,
        "note": "Welcome email will be sent in background"
    }


@router.post("/activity/{user_id}")
async def log_activity(
    user_id: int,
    action: str,
    background_tasks: BackgroundTasks
):
    """
    Example: Synchronous background task
    """
    # Add synchronous task
    background_tasks.add_task(
        log_user_activity,
        user_id=user_id,
        action=action,
        details={"timestamp": datetime.now()}
    )
    
    return {
        "message": "Activity logged",
        "user_id": user_id,
        "action": action
    }


# ============================================
# MULTIPLE DEPENDENCIES
# ============================================

@router.get("/complex")
async def complex_endpoint(
    token: Annotated[str, Depends(verify_token)],
    params: Annotated[CommonQueryParams, Depends()],
    context: Annotated[RequestContext, Depends(get_request_context)],
    db: Annotated[dict, Depends(get_db_session)]
):
    """
    Example: Multiple dependencies in one endpoint
    
    Combines authentication, parameters, context, and database
    """
    return {
        "message": "Complex endpoint with multiple dependencies",
        "authenticated": True,
        "token": token,
        "pagination": {
            "skip": params.skip,
            "limit": params.limit
        },
        "request_id": context.request_id,
        "database_session": db.get("session_id")
    }


# ============================================
# REQUEST STATE ACCESS
# ============================================

@router.get("/request-state")
async def access_request_state(request: Request):
    """
    Example: Accessing request state set by middleware
    
    The logging middleware sets request_id in request.state
    """
    request_id = getattr(request.state, "request_id", "not-set")
    
    return {
        "message": "Accessing request state",
        "request_id": request_id,
        "note": "Request ID was set by logging middleware"
    }
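
The window logic behind the /rate-limited endpoint can be exercised without a server by extracting it into a framework-free class with an injectable clock. SlidingWindow is a hypothetical rewrite of RateLimiter for illustration, not the class from app/dependencies.py.

```python
import time
from collections import defaultdict
from typing import Callable, DefaultDict, List


class SlidingWindow:
    """Framework-free version of the window check, with an injectable clock."""

    def __init__(
        self,
        max_requests: int,
        window_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clock = clock
        self.hits: DefaultDict[str, List[float]] = defaultdict(list)

    def allow(self, client_id: str) -> bool:
        now = self.clock()
        # Keep only timestamps that are still inside the window.
        recent = [t for t in self.hits[client_id] if now - t < self.window_seconds]
        allowed = len(recent) < self.max_requests
        if allowed:
            recent.append(now)
        self.hits[client_id] = recent
        return allowed


fake_now = [0.0]  # mutable cell so the test can move time forward
limiter = SlidingWindow(max_requests=2, window_seconds=60, clock=lambda: fake_now[0])

results = [limiter.allow("10.0.0.1") for _ in range(3)]  # third call is rejected
fake_now[0] = 61.0                                        # the window has passed
results.append(limiter.allow("10.0.0.1"))

print(results)  # [True, True, False, True]
```

Because the counters live in process memory, each worker process enforces its own limit; a deployment with several workers would typically need a shared store for the counts.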

Step 6: Update Main Application with Middleware

Update app/main.py:

"""
FastAPI AI Backend - Main Application

Updated with middleware and advanced features
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError

from app.core.config import settings
from app.api.v1.api import api_router
from app.core.exceptions import AppException
from app.core.error_handlers import (
    app_exception_handler,
    validation_exception_handler,
    generic_exception_handler
)
from app.middleware.logging_middleware import LoggingMiddleware
from app.middleware.performance_middleware import PerformanceMiddleware
from app.utils.logger import logger


def create_application() -> FastAPI:
    """
    Application factory pattern
    
    Creates and configures the FastAPI application
    """
    
    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="""
        A production-ready FastAPI backend for AI applications.
        
        ## Features
        * Advanced dependency injection patterns
        * Custom middleware for logging and monitoring
        * Structured logging with context
        * Background task processing
        * Request/response lifecycle management
        * Rate limiting
        * Performance monitoring
        * Advanced data validation with Pydantic
        * Comprehensive error handling
        * User management with CRUD operations
        * RESTful API design
        
        ## Middleware Stack (outermost first)
        1. CORS handling
        2. Request logging (with request ID tracking)
        3. Performance monitoring (slow request detection)
        
        ## Coming Soon
        * AI model integration (Ollama)
        * Authentication & Authorization (JWT)
        * Database integration (SQLAlchemy)
        * WebSocket support for streaming
        * Caching layer (Redis)
        """,
        debug=settings.DEBUG,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json"
    )
    
    # ============================================
    # MIDDLEWARE (ORDER MATTERS!)
    # ============================================
    # add_middleware() places each new middleware at the FRONT of the stack,
    # so the LAST one registered is the OUTERMOST (first to see the request).
    # Register from innermost to outermost.
    
    # 3. CORS (innermost; Starlette middleware re-exported by FastAPI)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Request-ID", "X-Process-Time"]
    )
    
    # 2. Request logging
    app.add_middleware(LoggingMiddleware)
    
    # 1. Performance monitoring (outermost, so it is registered last)
    app.add_middleware(
        PerformanceMiddleware,
        slow_request_threshold=1.0  # Log if request takes > 1 second
    )
    
    # ============================================
    # EXCEPTION HANDLERS
    # ============================================
    
    app.add_exception_handler(AppException, app_exception_handler)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    app.add_exception_handler(ValidationError, validation_exception_handler)
    app.add_exception_handler(Exception, generic_exception_handler)
    
    # ============================================
    # ROUTERS
    # ============================================
    
    app.include_router(
        api_router,
        prefix=settings.API_V1_PREFIX
    )
    
    return app


# Create the application instance
app = create_application()


# ============================================
# LIFECYCLE EVENTS
# ============================================

@app.on_event("startup")
async def startup_event():
    """Code to run when the application starts"""
    logger.info(
        "Application startup",
        extra={
            "app_name": settings.APP_NAME,
            "version": settings.APP_VERSION,
            "environment": settings.ENVIRONMENT,
            "debug": settings.DEBUG
        }
    )
    
    print(f"🚀 Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    print(f"📝 Environment: {settings.ENVIRONMENT}")
    print(f"🔧 Debug mode: {settings.DEBUG}")
    print(f"📚 API Docs: http://{settings.HOST}:{settings.PORT}/docs")
    print("✅ Middleware configured: Logging, Performance, CORS")
    print("✅ Error handling configured")


@app.on_event("shutdown")
async def shutdown_event():
    """Code to run when the application shuts down"""
    logger.info(
        "Application shutdown",
        extra={
            "app_name": settings.APP_NAME
        }
    )
    print(f"👋 Shutting down {settings.APP_NAME}")

🔍 Middleware Order Explained:

Request Flow (Outer → Inner):
    ↓
1. PerformanceMiddleware (measures total time)
    ↓
2. LoggingMiddleware (logs request/response)
    ↓
3. CORSMiddleware (handles CORS)
    ↓
4. Your endpoint
    ↓
Response Flow (Inner → Outer):
    ↓
4. Endpoint returns response
    ↓
3. CORS headers added
    ↓
2. Response logged
    ↓
1. Performance measured
    ↓
Client receives response
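
One detail that is easy to get backwards: Starlette's add_middleware() inserts each new middleware at the front of the stack, so the middleware registered last ends up outermost. To get the flow shown above, CORSMiddleware would be registered first and PerformanceMiddleware last. The toy model below mimics that behavior in plain Python so you can see the resulting order. MiniApp and wrap are hypothetical names for this sketch, not FastAPI or Starlette classes.

```python
from typing import Callable, List

Handler = Callable[[str], str]


def wrap(name: str, inner: Handler, trace: List[str]) -> Handler:
    """Wrap `inner` so `name` runs before and after it, like a middleware."""
    def handler(request: str) -> str:
        trace.append(f"{name}: request")
        response = inner(request)
        trace.append(f"{name}: response")
        return response
    return handler


class MiniApp:
    """Toy model of a middleware stack (not Starlette's real class)."""

    def __init__(self) -> None:
        self.trace: List[str] = []
        self.user_middleware: List[str] = []

    def add_middleware(self, name: str) -> None:
        # Newest middleware goes to the FRONT of the list
        self.user_middleware.insert(0, name)

    def endpoint(self, request: str) -> str:
        self.trace.append("endpoint")
        return f"handled {request}"

    def handle(self, request: str) -> str:
        handler: Handler = self.endpoint
        # Wrap from the back of the list so the front ends up outermost
        for name in reversed(self.user_middleware):
            handler = wrap(name, handler, self.trace)
        return handler(request)


app = MiniApp()
app.add_middleware("CORS")         # registered first -> innermost
app.add_middleware("Logging")
app.add_middleware("Performance")  # registered last  -> outermost
result = app.handle("GET /health")
print("\n".join(app.trace))
```

The printed trace starts and ends with Performance, which is exactly what you want from a timer that should measure everything inside it.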

Step 7: Update API Router

Update app/api/v1/api.py:

"""
API v1 router aggregator

Combines all v1 endpoint routers
"""

from fastapi import APIRouter
from app.api.v1.endpoints import users, health, users_advanced, dependencies_demo

# Create main v1 router
api_router = APIRouter()

# Include all endpoint routers
api_router.include_router(health.router)
api_router.include_router(users.router)
api_router.include_router(users_advanced.router)
api_router.include_router(dependencies_demo.router)  # New!

# Future routers:
# api_router.include_router(ai.router)
# api_router.include_router(chat.router)

Step 8: Create Comprehensive Test Script

Create test_dependencies_middleware.py:

"""
Test script for dependencies and middleware

Tests dependency injection patterns and middleware functionality
"""

import requests
import json
import time

BASE_URL = "http://127.0.0.1:8000/api/v1"


def print_test(title: str, response: requests.Response):
    """Helper to print test results"""
    print(f"\n{'='*70}")
    print(f"{title}")
    print(f"{'='*70}")
    print(f"Status: {response.status_code}")
    print(f"Headers:")
    for key in ['X-Request-ID', 'X-Process-Time', 'Access-Control-Allow-Origin']:
        if key in response.headers:
            print(f"  {key}: {response.headers[key]}")
    try:
        print(f"Body:\n{json.dumps(response.json(), indent=2, default=str)}")
    except ValueError:
        # Body is not valid JSON (e.g. an HTML error page); show raw text
        print(f"Body: {response.text}")


def test_timestamp_dependency():
    """Test simple dependency"""
    response = requests.get(f"{BASE_URL}/dependencies/timestamp")
    print_test("📅 TIMESTAMP DEPENDENCY", response)


def test_protected_endpoint():
    """Test authentication dependency"""
    # Without token
    response = requests.get(f"{BASE_URL}/dependencies/protected")
    print_test("🔒 PROTECTED (NO TOKEN)", response)
    
    # With valid token
    headers = {"X-Token": "secret-token-123"}
    response = requests.get(f"{BASE_URL}/dependencies/protected", headers=headers)
    print_test("🔓 PROTECTED (VALID TOKEN)", response)
    
    # With invalid token
    headers = {"X-Token": "invalid-token"}
    response = requests.get(f"{BASE_URL}/dependencies/protected", headers=headers)
    print_test("❌ PROTECTED (INVALID TOKEN)", response)


def test_api_key_dependency():
    """Test API key dependency"""
    response = requests.get(
        f"{BASE_URL}/dependencies/api-key-protected",
        params={"api_key": "key-12345"}
    )
    print_test("🔑 API KEY DEPENDENCY", response)


def test_common_query_params():
    """Test class-based dependency"""
    response = requests.get(
        f"{BASE_URL}/dependencies/items",
        params={
            "skip": 5,
            "limit": 3,
            "sort_by": "name",
            "sort_order": "desc"
        }
    )
    print_test("📊 COMMON QUERY PARAMS", response)


def test_rate_limiter():
    """Test rate limiter dependency"""
    print(f"\n{'='*70}")
    print("⏱️  RATE LIMITER TEST (5 requests/minute)")
    print(f"{'='*70}")
    
    for i in range(7):
        response = requests.get(f"{BASE_URL}/dependencies/rate-limited")
        status = "✅" if response.status_code == 200 else "❌"
        print(f"Request {i+1}: {status} Status {response.status_code}")
        if response.status_code == 429:
            print(f"  Rate limit message: {response.json()['detail']}")
        time.sleep(0.5)


def test_request_context():
    """Test request context dependency"""
    headers = {
        "X-Request-ID": "custom-request-id-123",
        "User-Agent": "Test-Client/1.0"
    }
    response = requests.get(
        f"{BASE_URL}/dependencies/context",
        headers=headers
    )
    print_test("🎯 REQUEST CONTEXT", response)


def test_database_dependency():
    """Test generator dependency with cleanup"""
    print(f"\n{'='*70}")
    print("💾 DATABASE DEPENDENCY (watch server logs for cleanup)")
    print(f"{'='*70}")
    response = requests.get(f"{BASE_URL}/dependencies/database")
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}")
    print("⚠️  Check server logs for 'Database session created/closed' messages")


def test_cached_dependency():
    """Test cached expensive dependency"""
    print(f"\n{'='*70}")
    print("💰 CACHED DEPENDENCY (watch server logs)")
    print(f"{'='*70}")
    
    for i in range(3):
        response = requests.get(f"{BASE_URL}/dependencies/expensive")
        print(f"\nCall {i+1}:")
        print(f"  Status: {response.status_code}")
        print(f"  Data: {response.json()['data']}")
        
        if i == 0:
            print("  ⚠️  First call: should see 'Computing expensive resource' in logs")
        else:
            print("  ✅ Subsequent call: should use cache (no log message)")


def test_background_tasks():
    """Test background tasks"""
    response = requests.post(
        f"{BASE_URL}/dependencies/signup",
        params={
            "email": "test@example.com",
            "username": "testuser"
        }
    )
    print_test("📧 BACKGROUND TASKS (check logs)", response)
    print("⚠️  Check server logs for background task execution (takes ~3 seconds)")


def test_activity_logging():
    """Test activity logging background task"""
    response = requests.post(
        f"{BASE_URL}/dependencies/activity/123",
        params={"action": "login"}
    )
    print_test("📝 ACTIVITY LOGGING", response)


def test_multiple_dependencies():
    """Test endpoint with multiple dependencies"""
    headers = {
        "X-Token": "secret-token-123",
        "X-Request-ID": "multi-dep-test-123"
    }
    params = {
        "skip": 0,
        "limit": 5,
        "sort_by": "name"
    }
    response = requests.get(
        f"{BASE_URL}/dependencies/complex",
        headers=headers,
        params=params
    )
    print_test("🎭 MULTIPLE DEPENDENCIES", response)


def test_request_state():
    """Test accessing request state from middleware"""
    response = requests.get(f"{BASE_URL}/dependencies/request-state")
    print_test("🔍 REQUEST STATE (from middleware)", response)


def test_middleware_headers():
    """Test middleware-added headers"""
    print(f"\n{'='*70}")
    print("🔧 MIDDLEWARE HEADERS TEST")
    print(f"{'='*70}")
    
    response = requests.get(f"{BASE_URL}/health")
    
    print("\nHeaders added by middleware:")
    print(f"  X-Request-ID: {response.headers.get('X-Request-ID', 'Not found')}")
    print(f"  X-Process-Time: {response.headers.get('X-Process-Time', 'Not found')} seconds")
    print(f"  Access-Control-Allow-Origin: {response.headers.get('Access-Control-Allow-Origin', 'Not found')}")


def test_slow_request():
    """Test performance monitoring (slow request detection)"""
    print(f"\n{'='*70}")
    print("🐌 SLOW REQUEST TEST (check server logs)")
    print(f"{'='*70}")
    print("Note: This endpoint doesn't exist, but watch for slow request warning in logs")
    print("      (if it takes > 1 second)")
    
    # This might be slow due to 404 handling
    response = requests.get(f"{BASE_URL}/slow-endpoint-test")
    print(f"Status: {response.status_code}")
    print(f"Process time: {response.headers.get('X-Process-Time')} seconds")


def run_all_tests():
    """Run all tests"""
    print("\n" + "🧪"*35)
    print("DEPENDENCIES & MIDDLEWARE TEST SUITE")
    print("🧪"*35)
    
    print("\n" + "="*70)
    print("SIMPLE DEPENDENCIES")
    print("="*70)
    test_timestamp_dependency()
    
    print("\n" + "="*70)
    print("AUTHENTICATION DEPENDENCIES")
    print("="*70)
    test_protected_endpoint()
    test_api_key_dependency()
    
    print("\n" + "="*70)
    print("CLASS-BASED DEPENDENCIES")
    print("="*70)
    test_common_query_params()
    
    print("\n" + "="*70)
    print("ADVANCED DEPENDENCIES")
    print("="*70)
    test_request_context()
    test_database_dependency()
    test_cached_dependency()
    
    print("\n" + "="*70)
    print("BACKGROUND TASKS")
    print("="*70)
    test_background_tasks()
    test_activity_logging()
    
    print("\n" + "="*70)
    print("COMPLEX SCENARIOS")
    print("="*70)
    test_multiple_dependencies()
    test_request_state()
    
    print("\n" + "="*70)
    print("MIDDLEWARE TESTS")
    print("="*70)
    test_middleware_headers()
    test_slow_request()
    
    print("\n" + "="*70)
    print("RATE LIMITING")
    print("="*70)
    test_rate_limiter()
    
    print("\n" + "✅"*35)
    print("ALL TESTS COMPLETED!")
    print("✅"*35)
    print("\n⚠️  Remember to check server logs for:")
    print("  - Request logging with IDs")
    print("  - Background task execution")
    print("  - Database session cleanup")
    print("  - Slow request warnings")
    print("  - Cached dependency computation\n")


if __name__ == "__main__":
    print("""
    ╔════════════════════════════════════════════════════════╗
    ║  Dependencies & Middleware Test Suite                 ║
    ║                                                        ║
    ║  Make sure your server is running:                    ║
    ║  python main.py                                       ║
    ║                                                        ║
    ║  Watch the SERVER LOGS while running these tests!     ║
    ╚════════════════════════════════════════════════════════╝
    """)
    
    try:
        response = requests.get(f"{BASE_URL}/health")
        if response.status_code == 200:
            run_all_tests()
        else:
            print("❌ Server returned unexpected status code")
    except requests.exceptions.ConnectionError:
        print("❌ ERROR: Cannot connect to server!")
        print("   Please start the server with: python main.py")
    except Exception as e:
        print(f"❌ ERROR: {e}")

📦 Installation Requirements

Don’t forget to install the JSON logging library:

pip install python-json-logger
pip freeze > requirements.txt

🎯 Implementation Checklist

New Files to Create:

app/
├── dependencies.py                          ✅ Central dependencies
├── middleware/
│   ├── __init__.py                         ✅
│   ├── logging_middleware.py               ✅
│   ├── cors_middleware.py                  ✅
│   └── performance_middleware.py           ✅
├── utils/
│   ├── __init__.py                         ✅
│   ├── logger.py                           ✅
│   └── background_tasks.py                 ✅
└── api/v1/endpoints/
    └── dependencies_demo.py                ✅

test_dependencies_middleware.py              ✅

Files to Update:

app/main.py                                  ✅ Add middleware
app/api/v1/api.py                           ✅ Add demo router

🧪 Running the Tests

Step 1: Start the Server

python main.py

Watch the console logs! You’ll see structured JSON logs for every request.

Step 2: Run Tests (in another terminal)

# Activate venv
source venv/bin/activate  # or venv\Scripts\activate on Windows

# Run tests
python test_dependencies_middleware.py

What You’ll See:

In the test terminal:

  • ✅ Successful requests with proper dependencies
  • ❌ Failed auth attempts
  • 🔄 Rate limiting in action
  • 📧 Background task confirmation

In the server logs:

  • 📝 Request logging with IDs
  • 💾 Database session lifecycle
  • 💰 Cached dependency (computed once)
  • 📧 Background tasks executing
  • ⏱️ Request timing information

📊 Understanding the Request Lifecycle

Complete Request Flow:

1. Client sends request
   ↓
2. PerformanceMiddleware (start timer)
   ↓
3. LoggingMiddleware (log incoming request, add request ID)
   ↓
4. CORSMiddleware (handle CORS headers)
   ↓
5. FastAPI routing (match endpoint)
   ↓
6. Dependency resolution
   ├── get_current_timestamp() → executed
   ├── verify_token() → executed
   ├── CommonQueryParams() → instantiated
   └── get_db_session() → setup phase
   ↓
7. Your endpoint function executes
   ↓
8. Dependency cleanup
   └── get_db_session() → cleanup phase
   ↓
9. Response creation
   ↓
10. Exception handlers (if error occurred)
   ↓
11. CORSMiddleware (add CORS headers)
   ↓
12. LoggingMiddleware (log response)
   ↓
13. PerformanceMiddleware (check timing, add headers)
   ↓
14. Response sent to client
   ↓
15. Background tasks execute (async)
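
To make the ordering of steps 6 to 15 concrete, here is a small plain-Python simulation. It does not use FastAPI at all: a context manager plays the role of a yield dependency, and a simple list plays the role of BackgroundTasks. All names are hypothetical. Note that the exact moment when FastAPI runs a yield dependency's cleanup relative to sending the response has changed between FastAPI versions, so treat this as a picture of the flow described above rather than a guarantee.

```python
from contextlib import contextmanager
from typing import Callable, Dict, List

events: List[str] = []


@contextmanager
def get_db_session():
    """Plays the role of a yield dependency: setup, hand over, cleanup."""
    events.append("dependency setup")
    try:
        yield {"session": "fake"}
    finally:
        # Runs even if the endpoint raises
        events.append("dependency cleanup")


def endpoint(db: Dict[str, str], background: List[Callable[[], None]]) -> dict:
    events.append("endpoint")
    # Queue work to run later, like BackgroundTasks.add_task()
    background.append(lambda: events.append("background task"))
    return {"ok": True}


def handle_request() -> dict:
    background: List[Callable[[], None]] = []
    with get_db_session() as db:
        response = endpoint(db, background)
    events.append("response sent")
    # Background work only starts once the client already has its response
    for task in background:
        task()
    return response


response = handle_request()
print(events)
```

The point to take away is that the client never waits for the queued task, and the session cleanup runs no matter how the endpoint exits.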

🎓 Real-World Use Cases

1. Multi-Tenant SaaS Application

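from fastapi import Depends, Header

# Tenant, Database, db, and app are application-specific placeholders
# (not defined in this tutorial); substitute your own models and clients.

async def get_tenant(x_tenant_id: str = Header(...)):
    """Identify tenant from the X-Tenant-ID header"""
    return await db.get_tenant(x_tenant_id)

async def get_tenant_db(tenant: Tenant = Depends(get_tenant)):
    """Connect to tenant-specific database"""
    return Database(tenant.db_connection_string)

@app.get("/customers")
async def list_customers(
    tenant: Tenant = Depends(get_tenant),
    db: Database = Depends(get_tenant_db)
):
    # get_tenant runs only once per request: FastAPI caches a dependency's
    # result within a request, so both parameters see the same tenant.
    # Automatically queries correct tenant database
    return await db.customers.list()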

2. API Rate Limiting with Redis

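from fastapi import HTTPException, Request, status

class RedisRateLimiter:
    def __init__(self, redis_client, max_requests: int = 100):
        # redis_client is assumed to be an async Redis client
        self.redis = redis_client
        self.max_requests = max_requests
    
    async def __call__(self, request: Request, user_id: int):
        key = f"rate_limit:{user_id}"
        # INCR creates the key at 1 if it does not exist yet
        current = await self.redis.incr(key)
        
        if current == 1:
            # First request in this window: start the 1 hour expiry
            await self.redis.expire(key, 3600)
        
        if current > self.max_requests:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded"
            )
        
        return True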
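
If you want to see the counting logic without running Redis, the sketch below reproduces the same fixed-window idea (increment a counter, start the expiry on the first hit, reject once the counter passes the limit) with an in-memory dictionary. FixedWindowCounter and its clock parameter are hypothetical names for this illustration; the injectable clock is only there so the window can be fast-forwarded in a test.

```python
import time
from typing import Callable, Dict, Tuple


class FixedWindowCounter:
    """In-memory stand-in for the Redis INCR + EXPIRE pattern."""

    def __init__(
        self,
        max_requests: int,
        window_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clock = clock
        # key -> (request count, time at which the window expires)
        self._counters: Dict[str, Tuple[int, float]] = {}

    def allow(self, key: str) -> bool:
        now = self.clock()
        count, expires_at = self._counters.get(key, (0, 0.0))
        if now >= expires_at:
            # Window is over (or never started): begin a fresh one
            count, expires_at = 0, now + self.window_seconds
        count += 1
        self._counters[key] = (count, expires_at)
        return count <= self.max_requests


# Fake clock so the example does not need to sleep
fake_now = [0.0]
limiter = FixedWindowCounter(
    max_requests=2, window_seconds=60, clock=lambda: fake_now[0]
)

results = [limiter.allow("rate_limit:42") for _ in range(3)]
fake_now[0] = 61.0  # jump past the end of the window
results.append(limiter.allow("rate_limit:42"))
print(results)
```

A fixed window is the simplest option and can let a burst through at the window boundary; whether that matters depends on what you are protecting.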

3. Request Tracing Across Microservices

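import uuid

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.logger import logger

class TracingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Reuse the caller's trace ID if present, otherwise start a new trace
        trace_id = request.headers.get("X-Trace-ID") or str(uuid.uuid4())
        request.state.trace_id = trace_id
        
        # Log with trace ID
        logger.info("Request", extra={"trace_id": trace_id})
        
        response = await call_next(request)
        
        # Return the trace ID to the caller. Outgoing calls to other
        # services should send the same X-Trace-ID header.
        response.headers["X-Trace-ID"] = trace_id
        
        return response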
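
For the trace to actually cross service boundaries, the same ID has to travel on outgoing requests too. One possible way to carry it around, sketched below with the standard library only, is a contextvars.ContextVar plus a logging filter. The names trace_id_var, start_trace, outgoing_headers, and TraceIdFilter are made up for this example, and the incoming headers are modeled as a plain (case-sensitive) dict for simplicity.

```python
import contextvars
import logging
import uuid
from typing import Dict, Optional

# Holds the trace ID for the request currently being handled
trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "trace_id", default="-"
)


class TraceIdFilter(logging.Filter):
    """Attach the current trace ID to every log record."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.trace_id = trace_id_var.get()
        return True


def start_trace(incoming_headers: Dict[str, str]) -> str:
    """Reuse the caller's trace ID, or start a new trace."""
    trace_id = incoming_headers.get("X-Trace-ID") or str(uuid.uuid4())
    trace_id_var.set(trace_id)
    return trace_id


def outgoing_headers(extra: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Headers to send when calling the next service."""
    headers = dict(extra or {})
    headers["X-Trace-ID"] = trace_id_var.get()
    return headers


current = start_trace({"X-Trace-ID": "abc-123"})
print(outgoing_headers({"Accept": "application/json"}))
```

You would pass the result of outgoing_headers() to whatever HTTP client you use for service-to-service calls, so each hop logs the same ID.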

What is the main benefit of Dependency Injection (DI) in FastAPI?

The primary benefit of DI is decoupling. Instead of hardcoding logic like database connections or user authentication inside your route, you “inject” it using Depends(). This makes your code modular, easier to test (by swapping real dependencies with mocks), and allows for better separation of concerns.

What is the difference between Middleware and Dependencies?

Middleware is global and runs for every single request and response entering or leaving your application. It is ideal for cross-cutting concerns like adding security headers, performance timing, or logging. Dependencies are local; you apply them specifically to certain routes or routers where they are needed (e.g., only on admin endpoints).

Why use structured JSON logging?

In production environments, logs are often aggregated by tools like ELK Stack or Datadog. Standard text logs are difficult for these tools to parse. Structured JSON logging converts every log entry into a searchable object containing metadata like request_id, user_id, and execution_time, so you can filter thousands of requests by field when debugging a complex issue.

When should you use BackgroundTasks?

Use BackgroundTasks for operations that shouldn’t make the user wait for a response. Common examples include sending welcome emails, processing analytics, or generating PDF reports. These tasks run after the response has been sent to the client, which keeps the endpoint fast for the user.

Why use the Application Factory pattern?

The Application Factory pattern (creating the app inside a function, such as create_application() in app/main.py) allows you to dynamically attach middleware and exception handlers based on the environment (e.g., adding more verbose logging in “Development” but strict CORS in “Production”). It centralizes your app’s configuration, making it more manageable as it grows.
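
To illustrate the structured logging answer above, here is a minimal JSON formatter built only on the standard library. This tutorial installs python-json-logger for the real thing; the JsonFormatter class below is a simplified stand-in written for this example, and the list of context fields is an assumption based on the metadata named above.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as one JSON object per line."""

    # Extra fields copied into the output when present on the record
    CONTEXT_FIELDS = ("request_id", "user_id", "execution_time")

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        for field in self.CONTEXT_FIELDS:
            if hasattr(record, field):
                payload[field] = getattr(record, field)
        return json.dumps(payload)


logger = logging.getLogger("json_sketch")
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Values passed via `extra` become attributes on the log record
logger.info(
    "Request completed",
    extra={"request_id": "req-123", "execution_time": 0.042},
)
```

Because every entry is a JSON object, a log aggregator can index request_id as a field, and finding all lines for one request becomes a simple filter.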

🎉 What You’ve Accomplished!

Mastered Dependency Injection

  • Simple function dependencies
  • Class-based dependencies
  • Generator dependencies with cleanup
  • Cached dependencies
  • Nested sub-dependencies

Created Production Middleware

  • Request logging with unique IDs
  • Performance monitoring
  • Custom CORS handling
  • Request state management

Implemented Structured Logging

  • JSON-formatted logs
  • Context-aware logging
  • Request tracking
  • Debugging-friendly output

Async Background Tasks

  • Non-blocking email sending
  • Analytics processing
  • Cleanup tasks
  • Activity logging

Professional Code Organization

  • Reusable patterns
  • Testable components
  • Clear separation of concerns
  • Production-ready architecture

🚀 Coming Up Next

Blog Post 6: AI Integration with Ollama

We’ll finally integrate AI into our backend:

  • Installing and setting up Ollama locally
  • Creating an AI service abstraction layer
  • Implementing chat endpoints
  • Streaming responses in real-time
  • Model management and switching
  • Building a simple HTML chat interface

You now have a solid FastAPI foundation. Time to make it AI-powered! 🤖
