FastAPI handles cross-cutting concerns with two complementary tools: dependency injection for logic that belongs to specific routes, and middleware for processing that applies to every request and response. With Depends(), you inject reusable components such as database sessions or authentication checks directly into endpoints. For global operations such as structured JSON logging, performance monitoring, and CORS, middleware (for example Starlette's BaseHTTPMiddleware, which FastAPI builds on) intercepts every request, which makes it the natural place to assign a unique request ID and measure timing. Background tasks round this out by deferring slow work, such as sending an email, until after the response has been sent.
🎓 What You’ll Learn
By the end of this tutorial, you’ll understand:
- Advanced dependency injection patterns in FastAPI
- Creating and using custom middleware
- Implementing structured logging with context
- Background tasks for async processing
- Request/response lifecycle in FastAPI
- Performance monitoring and request tracking
📖 Understanding Dependency Injection
What is Dependency Injection?
Simple Definition: Instead of creating objects yourself, you declare what you need and FastAPI provides it.
Real-world analogy:
- ❌ Without DI: Going to kitchen, opening fridge, getting ingredients, cooking
- ✅ With DI: Saying “I need pasta” and someone brings you a cooked plate
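To make "declare what you need and it gets provided" concrete, here is a toy resolver written with only the standard library. It is a simplified model for illustration, not FastAPI's implementation: the `Depends` marker class and the `solve()` function are hypothetical stand-ins. A function declares its needs through parameter defaults, and the resolver builds those values (recursively) before calling it, caching each provider's result for the duration of one `solve()` call.

```python
import inspect
from typing import Any, Callable, Dict, Optional


class Depends:
    """Hypothetical marker mimicking fastapi.Depends: wraps a provider callable."""

    def __init__(self, provider: Callable[..., Any]):
        self.provider = provider


def solve(func: Callable[..., Any],
          cache: Optional[Dict[Callable[..., Any], Any]] = None) -> Any:
    """Call func after resolving every parameter whose default is a Depends marker."""
    cache = {} if cache is None else cache
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        marker = param.default
        if isinstance(marker, Depends):
            # Resolve each provider once per solve() call and reuse the result
            if marker.provider not in cache:
                cache[marker.provider] = solve(marker.provider, cache)
            kwargs[name] = cache[marker.provider]
    return func(**kwargs)


calls = []


def get_config() -> dict:
    calls.append("get_config")
    return {"dsn": "sqlite://"}


def get_db(config: dict = Depends(get_config)) -> str:
    return f"db({config['dsn']})"


def endpoint(db: str = Depends(get_db), config: dict = Depends(get_config)) -> str:
    # The endpoint never builds its own config or db; it only declares them
    return f"{db} / {config['dsn']}"


print(solve(endpoint))  # db(sqlite://) / sqlite://
print(calls)            # ['get_config']  (shared provider ran only once)
```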
Why Use Dependency Injection?
| Benefit | Example |
|---|---|
| Reusability | Same dependency in multiple endpoints |
| Testability | Easy to mock dependencies in tests |
| Separation of Concerns | Business logic separate from infrastructure |
| Code Organization | Clear dependencies, no hidden coupling |
| Async Support | Works seamlessly with async/await |
🛠️ Step-by-Step Implementation
Step 1: Create Advanced Dependencies
Create app/dependencies.py:
"""
Application dependencies
Reusable dependencies for FastAPI endpoints
"""
import time
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Annotated, AsyncIterator, Optional

from fastapi import Depends, Header, HTTPException, Query, Request, status

# ============================================
# SIMPLE DEPENDENCIES
# ============================================

async def get_current_timestamp() -> datetime:
    """
    Get current timestamp
    Simple dependency that returns current time
    """
    return datetime.now()


async def verify_token(x_token: Annotated[Optional[str], Header()] = None) -> str:
    """
    Verify API token from header
    Args:
        x_token: Token from X-Token header (FastAPI maps x_token to "x-token")
    Returns:
        Token if valid
    Raises:
        HTTPException: If token is invalid or missing
    """
    if x_token is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="X-Token header missing"
        )
    # In production, verify against a database or a JWT instead of hard-coded values
    valid_tokens = ["secret-token-123", "admin-token-456"]
    if x_token not in valid_tokens:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token"
        )
    return x_token


async def verify_api_key(api_key: Annotated[Optional[str], Query()] = None) -> str:
    """
    Verify API key from query parameter
    Args:
        api_key: API key from query parameter
    Returns:
        API key if valid
    Raises:
        HTTPException: If API key is invalid
    """
    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required"
        )
    valid_keys = ["key-12345", "key-67890"]
    if api_key not in valid_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return api_key

# ============================================
# DEPENDENCY CLASSES
# ============================================

class CommonQueryParams:
    """
    Common query parameters for list endpoints
    Reusable pagination and filtering parameters
    """
    def __init__(
        self,
        skip: int = Query(0, ge=0, description="Number of records to skip"),
        limit: int = Query(10, ge=1, le=100, description="Max records to return"),
        sort_by: Optional[str] = Query(None, description="Field to sort by"),
        # `pattern` replaces the deprecated `regex` argument in recent FastAPI versions
        sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort order")
    ):
        self.skip = skip
        self.limit = limit
        self.sort_by = sort_by
        self.sort_order = sort_order


class RateLimiter:
    """
    Simple sliding-window rate limiter dependency
    Limits requests per client within a time window.
    State lives in process memory, so each worker process enforces its own limit.
    """
    def __init__(self, max_requests: int = 10, window_seconds: int = 60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests: dict[str, list[float]] = {}  # {client_id: [timestamp, ...]}

    async def __call__(self, request: Request) -> bool:
        """
        Check rate limit for client
        Args:
            request: FastAPI request object
        Raises:
            HTTPException: If rate limit exceeded
        """
        # Use client IP as identifier (in production, prefer the authenticated user ID).
        # request.client can be None, e.g. under some test clients or proxies.
        client_id = request.client.host if request.client else "unknown"
        current_time = time.time()
        # Keep only the timestamps that are still inside the window
        recent = [
            req_time for req_time in self.requests.get(client_id, [])
            if current_time - req_time < self.window_seconds
        ]
        self.requests[client_id] = recent
        # Check if limit exceeded
        if len(recent) >= self.max_requests:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded. Max {self.max_requests} requests per {self.window_seconds} seconds."
            )
        # Record the current request
        recent.append(current_time)
        return True

# ============================================
# NESTED DEPENDENCIES
# ============================================

async def get_user_agent(user_agent: Annotated[Optional[str], Header()] = None) -> Optional[str]:
    """Get user agent from the User-Agent header"""
    return user_agent


async def get_request_id(x_request_id: Annotated[Optional[str], Header()] = None) -> str:
    """
    Get or generate request ID
    Used for request tracking across services
    """
    return x_request_id or str(uuid.uuid4())


class RequestContext:
    """
    Request context with multiple dependencies
    Aggregates multiple dependencies into a single object
    """
    def __init__(
        self,
        request: Request,
        request_id: Optional[str] = None,
        user_agent: Optional[str] = None,
        timestamp: Optional[datetime] = None
    ):
        self.request = request
        self.request_id = request_id or "unknown"
        self.user_agent = user_agent or "unknown"
        self.timestamp = timestamp or datetime.now()
        self.client_ip = request.client.host if request.client else "unknown"
        self.method = request.method
        self.url = str(request.url)


async def get_request_context(
    request: Request,
    request_id: Annotated[str, Depends(get_request_id)],
    user_agent: Annotated[Optional[str], Depends(get_user_agent)],
    timestamp: Annotated[datetime, Depends(get_current_timestamp)]
) -> RequestContext:
    """
    Build request context from other dependencies
    FastAPI resolves get_request_id, get_user_agent and get_current_timestamp
    first, then passes their results in here (dependency composition).
    """
    return RequestContext(
        request=request,
        request_id=request_id,
        user_agent=user_agent,
        timestamp=timestamp
    )

# ============================================
# DEPENDENCY WITH CLEANUP (Generator Pattern)
# ============================================

async def get_db_session() -> AsyncIterator[dict]:
    """
    Database session dependency (example)
    This pattern is used for resources that need cleanup
    In a real app, this would create and close database sessions
    """
    # Setup: Create session
    db = {"connected": True, "session_id": "session-123"}
    print("📂 Database session created")
    try:
        # Yield the resource to the endpoint
        yield db
    finally:
        # Cleanup: runs even if the endpoint raised an exception
        print("📂 Database session closed")
        db["connected"] = False

# ============================================
# SUB-DEPENDENCIES
# ============================================

def get_settings():
    """Get application settings (wrapped so it can be used with Depends)"""
    from app.core.config import get_settings as load_settings
    return load_settings()


async def get_admin_user(
    token: Annotated[str, Header(alias="X-Admin-Token")],
    app_settings=Depends(get_settings)
) -> dict:
    """
    Verify admin token
    Depends on get_settings, which FastAPI resolves first (a sub-dependency)
    """
    # In production, compare against a secret loaded from app_settings or a database
    if token != "admin-secret-token":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required"
        )
    return {"role": "admin", "token": token}

# ============================================
# CACHED DEPENDENCIES
# ============================================

@lru_cache()
def get_expensive_resource() -> dict:
    """
    Expensive resource that should be cached
    @lru_cache computes this once per process and reuses it for every request
    """
    print("💰 Computing expensive resource (this should only happen once)")
    return {
        "data": "expensive computation result",
        "computed_at": datetime.now()
    }
🔍 Key Dependency Patterns Explained:
- Simple function: returns a value (get_current_timestamp)
- Class with __init__: parameters read from the request (CommonQueryParams)
- Class instance with __call__: stateful dependencies (RateLimiter)
- Generator with yield: resources that need cleanup (get_db_session)
- Cached with @lru_cache: expensive computations, done once per process (get_expensive_resource)
- Nested dependencies: dependencies that depend on other dependencies
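The RateLimiter above rebuilds each client's list on every call. Below is a sketch of the same sliding-window idea using `collections.deque`, where expired timestamps are always at the left end and can be popped in order. `SlidingWindowLimiter` and `allow()` are hypothetical names, the clock is injectable so the logic can be tested without waiting, and, like the original, state lives in process memory.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class SlidingWindowLimiter:
    """Hypothetical in-memory sliding-window limiter (per process, not shared)."""

    def __init__(self, max_requests: int, window_seconds: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clock = clock
        self.hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, client_id: str) -> bool:
        """Record a hit and return True, or return False if the client is over the limit."""
        now = self.clock()
        hits = self.hits[client_id]
        # Timestamps are appended in order, so expired ones sit at the left end
        while hits and now - hits[0] >= self.window_seconds:
            hits.popleft()
        if len(hits) >= self.max_requests:
            return False  # rejected requests are not recorded, as in RateLimiter
        hits.append(now)
        return True


# Usage with a fake clock: 2 requests per 10 seconds
fake_now = [0.0]
limiter = SlidingWindowLimiter(2, 10, clock=lambda: fake_now[0])
print([limiter.allow("1.2.3.4") for _ in range(3)])  # [True, True, False]
fake_now[0] = 10.0
print(limiter.allow("1.2.3.4"))  # True: both earlier hits have expired
```

Inside a FastAPI dependency you would call `allow()` and raise `HTTPException(status_code=429)` when it returns False.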
Step 2: Create Logging Configuration
Create app/utils/logger.py:
"""
Structured logging configuration
Provides context-aware logging throughout the application
"""
import logging
import sys
from datetime import datetime, timezone
from typing import Any, Dict, Optional

from pythonjsonlogger import jsonlogger

# ============================================
# CUSTOM LOG FORMATTER
# ============================================

class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """
    Custom JSON log formatter
    Adds standard fields to all log records
    """
    def add_fields(self, log_record: Dict, record: logging.LogRecord, message_dict: Dict):
        """Add custom fields to log record"""
        super().add_fields(log_record, record, message_dict)
        # Time of the log event itself, in UTC (datetime.utcnow() is deprecated)
        log_record['timestamp'] = datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()
        # Add log level
        log_record['level'] = record.levelname
        # Add logger name
        log_record['logger'] = record.name
        # Add module/function info
        log_record['module'] = record.module
        log_record['function'] = record.funcName
        log_record['line'] = record.lineno

# ============================================
# LOGGER SETUP
# ============================================

def setup_logger(name: str = "fastapi_ai_backend", level: int = logging.INFO) -> logging.Logger:
    """
    Setup structured logger
    Args:
        name: Logger name
        level: Logging level
    Returns:
        Configured logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Remove existing handlers (avoids duplicate output on reload)
    logger.handlers = []
    # Console handler with JSON formatting
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    # JSON formatter for structured logs
    formatter = CustomJsonFormatter(
        '%(timestamp)s %(level)s %(name)s %(message)s'
    )
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    # Prevent propagation to root logger
    logger.propagate = False
    return logger

# ============================================
# CONTEXT LOGGER
# ============================================

class ContextLogger:
    """
    Logger with request context
    Automatically includes request ID, user info, etc. in all logs
    """
    def __init__(self, logger: logging.Logger, context: Optional[Dict[str, Any]] = None):
        self.logger = logger
        self.context = context or {}

    def _log_with_context(self, level: int, message: str, **kwargs):
        """Log message with context attached as structured fields"""
        # Merge context with kwargs; keys must not clash with built-in
        # LogRecord attributes such as "message", "name" or "module"
        log_data = {**self.context, **kwargs}
        # Pass fields via `extra` so the JSON formatter emits them as top-level keys
        # (json.dumps-ing them into the message would nest escaped JSON inside "message").
        # stacklevel=3 attributes module/function/line to the caller of info(), error(), ...
        self.logger.log(level, message, extra=log_data, stacklevel=3)

    def debug(self, message: str, **kwargs):
        """Log debug message"""
        self._log_with_context(logging.DEBUG, message, **kwargs)

    def info(self, message: str, **kwargs):
        """Log info message"""
        self._log_with_context(logging.INFO, message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log warning message"""
        self._log_with_context(logging.WARNING, message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log error message"""
        self._log_with_context(logging.ERROR, message, **kwargs)

    def critical(self, message: str, **kwargs):
        """Log critical message"""
        self._log_with_context(logging.CRITICAL, message, **kwargs)

    def with_context(self, **additional_context) -> 'ContextLogger':
        """Create new logger with additional context"""
        new_context = {**self.context, **additional_context}
        return ContextLogger(self.logger, new_context)

# ============================================
# GLOBAL LOGGER INSTANCE
# ============================================

# Create global logger
logger = setup_logger()

# Example usage:
# logger.info("Application started")
#
# With context:
# context_logger = ContextLogger(logger, {"request_id": "123", "user_id": "456"})
# context_logger.info("User logged in")
Install the JSON logging library:
pip install python-json-logger
pip freeze > requirements.txt
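To see what a JSON formatter does with `extra=` fields without installing anything, here is a standard-library-only sketch of the same idea. `StdlibJsonFormatter` is a hypothetical name, and detecting extra fields by comparing against the attributes every `LogRecord` has is an assumption of this sketch, not how python-json-logger works internally.

```python
import io
import json
import logging
from datetime import datetime, timezone

# Attributes every LogRecord has; anything else on a record came from `extra=`
_STANDARD_ATTRS = set(vars(logging.LogRecord("", 0, "", 0, "", (), None))) | {"message", "asctime"}


class StdlibJsonFormatter(logging.Formatter):
    """Hypothetical JSON formatter built only on the standard library."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }
        # Copy the structured fields supplied via `extra=`
        for key, value in vars(record).items():
            if key not in _STANDARD_ATTRS:
                payload[key] = value
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        # default=str keeps non-JSON values (datetimes, UUIDs) from raising
        return json.dumps(payload, default=str)


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(StdlibJsonFormatter())
demo_logger = logging.getLogger("json_demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False

demo_logger.info("Request completed", extra={"request_id": "abc-123", "status_code": 200})
print(stream.getvalue())
```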
Step 3: Create Custom Middleware
Create app/middleware/logging_middleware.py:
"""
Logging middleware
Logs all requests and responses with timing information
"""
import time
import uuid
from typing import Awaitable, Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.logger import logger


class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to log all requests and responses
    Adds request ID and timing information
    """
    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """
        Process request and log details
        Args:
            request: Incoming request
            call_next: Next middleware/endpoint
        Returns:
            Response from endpoint
        """
        # Reuse the caller's request ID if present, otherwise generate one
        request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
        # Start timer (perf_counter is monotonic, so it suits duration measurements)
        start_time = time.perf_counter()
        # Log incoming request
        logger.info(
            "Incoming request",
            extra={
                "request_id": request_id,
                "method": request.method,
                "url": str(request.url),
                "client_ip": request.client.host if request.client else None,
                "user_agent": request.headers.get("user-agent"),
            }
        )
        # Add request ID to request state (accessible in endpoints)
        request.state.request_id = request_id
        # Process request
        try:
            response = await call_next(request)
        except Exception as e:
            # Log the error with its traceback, then let it propagate
            logger.exception(
                "Request failed",
                extra={
                    "request_id": request_id,
                    "error": str(e),
                    "error_type": type(e).__name__
                }
            )
            raise
        # Calculate duration
        duration = time.perf_counter() - start_time
        # Log response
        logger.info(
            "Request completed",
            extra={
                "request_id": request_id,
                "status_code": response.status_code,
                "duration_ms": round(duration * 1000, 2)
            }
        )
        # Add headers to response
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(round(duration, 4))
        return response
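`request.state.request_id` is only reachable where the `Request` object is. A common complement, sketched here assuming you want every log line emitted during a request to carry its ID, is a `contextvars.ContextVar` that the middleware sets and a `logging.Filter` that copies it onto each record. `request_id_var`, `RequestIdFilter` and `handle_request` are hypothetical names, and the demo sets the variable by hand instead of through middleware so it stays self-contained. Starlette documents that BaseHTTPMiddleware stops context-variable changes from propagating upwards from the endpoint, so set the variable in the middleware before `call_next`, not in the endpoint.

```python
import contextvars
import io
import logging
import uuid
from typing import Optional

# ID of the request currently being handled in this context (hypothetical name)
request_id_var = contextvars.ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
    """Attach the current request ID to every record passing through the handler."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.request_id = request_id_var.get()
        return True  # never drop records, only enrich them


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.addFilter(RequestIdFilter())
handler.setFormatter(logging.Formatter("%(request_id)s %(levelname)s %(message)s"))
log = logging.getLogger("ctx_demo")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.propagate = False


def handle_request(incoming_id: Optional[str]) -> None:
    """Stand-in for middleware + endpoint: set the ID, log, then restore the old value."""
    token = request_id_var.set(incoming_id or str(uuid.uuid4()))
    try:
        log.info("handling request")  # no request_id passed explicitly
    finally:
        request_id_var.reset(token)


handle_request("abc-123")
log.info("outside any request")
print(stream.getvalue())
# abc-123 INFO handling request
# - INFO outside any request
```

In the middleware itself, the same `set()` would go right before `await call_next(request)`, with `reset(token)` in a `finally` block.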
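Create app/middleware/cors_middleware.py:
"""
Custom CORS middleware
More control over CORS than the default middleware
"""
from typing import List, Optional

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response


class CustomCORSMiddleware(BaseHTTPMiddleware):
    """
    Custom CORS middleware with additional features
    """
    def __init__(
        self,
        app,
        allow_origins: Optional[List[str]] = None,
        allow_methods: Optional[List[str]] = None,
        allow_headers: Optional[List[str]] = None,
        expose_headers: Optional[List[str]] = None,
        allow_credentials: bool = False
    ):
        super().__init__(app)
        self.allow_origins = allow_origins or ["*"]
        self.allow_methods = allow_methods or ["*"]
        self.allow_headers = allow_headers or ["*"]
        self.expose_headers = expose_headers or []
        self.allow_credentials = allow_credentials

    async def dispatch(self, request: Request, call_next):
        """Process request and add CORS headers"""
        origin = request.headers.get("origin")
        # A CORS preflight is an OPTIONS request carrying Origin and Access-Control-Request-Method
        is_preflight = (
            request.method == "OPTIONS"
            and origin is not None
            and "access-control-request-method" in request.headers
        )
        response = Response() if is_preflight else await call_next(request)

        if origin is None:
            # Not a cross-origin browser request: nothing to add
            return response

        if origin in self.allow_origins:
            # Echo back an explicitly allowed origin; only this case may carry credentials
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers.add_vary_header("Origin")
            if self.allow_credentials:
                response.headers["Access-Control-Allow-Credentials"] = "true"
        elif "*" in self.allow_origins:
            # Wildcard is never combined with credentials: browsers reject that pair, and
            # reflecting any origin with credentials would let every site make
            # authenticated requests on behalf of your users
            response.headers["Access-Control-Allow-Origin"] = "*"
        else:
            # Origin not allowed: omit CORS headers so the browser blocks the response
            return response

        if is_preflight:
            # With credentials, browsers treat "*" here literally, so list names explicitly
            response.headers["Access-Control-Allow-Methods"] = ", ".join(self.allow_methods)
            response.headers["Access-Control-Allow-Headers"] = ", ".join(self.allow_headers)
        if self.expose_headers:
            response.headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
        return response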
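The header decisions are easier to reason about, and to unit test, as a pure function. A sketch assuming the policy that a wildcard origin is never combined with credentials, which matches what browsers enforce (a credentialed response must name an explicit origin). `cors_headers` is a hypothetical helper, not part of Starlette, and it covers only simple (non-preflight) requests.

```python
from typing import Dict, Iterable, Optional


def cors_headers(origin: Optional[str], allow_origins: Iterable[str],
                 allow_credentials: bool = False) -> Dict[str, str]:
    """Return the CORS headers to add to a simple (non-preflight) response."""
    allowed = set(allow_origins)
    if origin is None:
        return {}  # same-origin or non-browser request: nothing to add
    if origin in allowed:
        # Explicitly listed origin: echo it back; the response now varies by Origin
        headers = {"Access-Control-Allow-Origin": origin, "Vary": "Origin"}
        if allow_credentials:
            headers["Access-Control-Allow-Credentials"] = "true"
        return headers
    if "*" in allowed:
        # Any origin may read the response, but never with credentials
        return {"Access-Control-Allow-Origin": "*"}
    return {}  # origin not allowed: omit headers so the browser blocks the read


print(cors_headers("https://app.example.com", ["https://app.example.com"], allow_credentials=True))
# {'Access-Control-Allow-Origin': 'https://app.example.com', 'Vary': 'Origin', 'Access-Control-Allow-Credentials': 'true'}
print(cors_headers("https://other.example", ["*"], allow_credentials=True))
# {'Access-Control-Allow-Origin': '*'}
```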
Create app/middleware/performance_middleware.py:
"""
Performance monitoring middleware
Tracks slow requests and performance metrics
"""
import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.logger import logger


class PerformanceMiddleware(BaseHTTPMiddleware):
    """
    Monitor request performance
    Logs warnings for slow requests
    """
    def __init__(self, app, slow_request_threshold: float = 1.0):
        """
        Initialize performance middleware
        Args:
            app: ASGI application
            slow_request_threshold: Threshold in seconds for slow request warning
        """
        super().__init__(app)
        self.slow_request_threshold = slow_request_threshold

    async def dispatch(self, request: Request, call_next):
        """Monitor request performance"""
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()
        # call_next returns once the response headers are ready,
        # so time spent streaming a large body is not included
        response = await call_next(request)
        duration = time.perf_counter() - start_time
        # Log slow requests
        if duration > self.slow_request_threshold:
            logger.warning(
                "Slow request detected",
                extra={
                    "method": request.method,
                    "url": str(request.url),
                    "duration_seconds": round(duration, 2),
                    "threshold_seconds": self.slow_request_threshold
                }
            )
        return response
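BaseHTTPMiddleware is convenient, but Starlette's documentation also describes its limitations and shows how to write plain ASGI middleware, a callable that wraps the downstream app directly. Below is a sketch of the slow-request check in that style. `TimingASGIMiddleware` and its `on_slow` callback are hypothetical names, and the demo drives it with a fake app and fake `receive`/`send` callables instead of a server, so it needs no third-party packages. Unlike the version above, this one measures until the downstream app has finished sending the whole response body.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Dict, List, Tuple

Scope = Dict[str, Any]
Message = Dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]


class TimingASGIMiddleware:
    """Hypothetical pure-ASGI middleware: times HTTP requests and reports slow ones."""

    def __init__(self, app: ASGIApp, slow_request_threshold: float = 1.0,
                 on_slow: Callable[[str, float], None] = print):
        self.app = app
        self.slow_request_threshold = slow_request_threshold
        self.on_slow = on_slow

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":  # pass lifespan/websocket traffic straight through
            await self.app(scope, receive, send)
            return
        start = time.perf_counter()
        try:
            await self.app(scope, receive, send)
        finally:
            duration = time.perf_counter() - start
            if duration > self.slow_request_threshold:
                self.on_slow(scope["path"], duration)


async def slow_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Fake downstream app that takes about 50 ms to respond."""
    await asyncio.sleep(0.05)
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


async def demo() -> Tuple[List[Message], List[str]]:
    sent: List[Message] = []
    reports: List[str] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    app = TimingASGIMiddleware(slow_app, slow_request_threshold=0.01,
                               on_slow=lambda path, duration: reports.append(path))
    await app({"type": "http", "path": "/items"}, receive, send)
    return sent, reports


sent, reports = asyncio.run(demo())
print(reports)            # ['/items']
print(sent[0]["status"])  # 200
```

With FastAPI, a class like this can be registered the same way as the others, e.g. `app.add_middleware(TimingASGIMiddleware, slow_request_threshold=1.0)`.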
Step 4: Create Background Tasks Example
Create app/utils/background_tasks.py:
"""
Background task utilities
Functions that run after the response has been sent
"""
import asyncio
from datetime import datetime
from typing import Any, Dict, Optional

from app.utils.logger import logger


async def send_welcome_email(email: str, username: str):
    """
    Send welcome email (simulated)
    In production, this would integrate with an email service
    """
    logger.info(
        "Sending welcome email",
        extra={"email": email, "username": username}
    )
    # Simulate email sending delay without blocking the event loop
    await asyncio.sleep(2)
    logger.info(
        "Welcome email sent",
        extra={"email": email, "username": username}
    )


async def process_user_analytics(user_id: int, event: str, metadata: Dict[str, Any]):
    """
    Process user analytics event
    This runs in the background without delaying the response
    """
    logger.info(
        "Processing analytics event",
        extra={
            "user_id": user_id,
            "event": event,
            "metadata": metadata
        }
    )
    # Simulate analytics processing
    await asyncio.sleep(1)
    logger.info(
        "Analytics event processed",
        extra={"user_id": user_id, "event": event}
    )


async def cleanup_old_data(days: int = 30):
    """
    Cleanup old data (example background task)
    """
    logger.info(f"Starting data cleanup (older than {days} days)")
    # Simulate cleanup process
    await asyncio.sleep(3)
    logger.info("Data cleanup completed")


def log_user_activity(user_id: int, action: str, details: Optional[Dict[str, Any]] = None):
    """
    Log user activity (synchronous background task)
    Plain functions work too; FastAPI runs them in a thread pool
    """
    logger.info(
        "User activity logged",
        extra={
            "user_id": user_id,
            "action": action,
            "details": details or {},
            # Not named "timestamp": CustomJsonFormatter overwrites that key
            "activity_time": datetime.now().isoformat()
        }
    )
Step 5: Create Endpoints Demonstrating Dependencies
Create app/api/v1/endpoints/dependencies_demo.py:
"""
Endpoints demonstrating dependency injection patterns
"""
from datetime import datetime
from typing import Annotated

from fastapi import APIRouter, BackgroundTasks, Depends, Request

from app.dependencies import (
    get_current_timestamp,
    verify_token,
    verify_api_key,
    CommonQueryParams,
    RateLimiter,
    get_request_context,
    RequestContext,
    get_db_session,
    get_expensive_resource
)
from app.utils.background_tasks import (
    send_welcome_email,
    process_user_analytics,
    log_user_activity
)

router = APIRouter(prefix="/dependencies", tags=["Dependencies Demo"])

# ============================================
# SIMPLE DEPENDENCY EXAMPLES
# ============================================

@router.get("/timestamp")
async def get_timestamp(
    timestamp: Annotated[datetime, Depends(get_current_timestamp)]
):
    """
    Example: Simple dependency that returns a value
    """
    return {
        "message": "Current timestamp from dependency",
        "timestamp": timestamp
    }


@router.get("/protected")
async def protected_endpoint(
    token: Annotated[str, Depends(verify_token)]
):
    """
    Example: Authentication dependency
    Try with header: X-Token: secret-token-123
    """
    return {
        "message": "Access granted",
        "token": token
    }


@router.get("/api-key-protected")
async def api_key_protected(
    api_key: Annotated[str, Depends(verify_api_key)]
):
    """
    Example: API key authentication
    Try with query parameter: ?api_key=key-12345
    """
    return {
        "message": "API access granted",
        "api_key": api_key
    }

# ============================================
# CLASS-BASED DEPENDENCY
# ============================================

@router.get("/items")
async def list_items(
    params: Annotated[CommonQueryParams, Depends()]
):
    """
    Example: Class-based dependency for common parameters
    Try: ?skip=0&limit=10&sort_by=name&sort_order=desc
    """
    return {
        "message": "Listing items with parameters",
        "skip": params.skip,
        "limit": params.limit,
        "sort_by": params.sort_by,
        "sort_order": params.sort_order,
        "items": [f"item_{i}" for i in range(params.skip, params.skip + params.limit)]
    }

# ============================================
# RATE LIMITER DEPENDENCY
# ============================================

# Create rate limiter instance (5 requests per 60 seconds)
rate_limiter = RateLimiter(max_requests=5, window_seconds=60)


@router.get("/rate-limited")
async def rate_limited_endpoint(
    rate_limit_check: Annotated[bool, Depends(rate_limiter)]
):
    """
    Example: Rate-limited endpoint
    Try making more than 5 requests in 60 seconds
    """
    return {
        "message": "Request allowed",
        "note": "Max 5 requests per minute"
    }

# ============================================
# REQUEST CONTEXT DEPENDENCY
# ============================================

@router.get("/context")
async def get_context_info(
    context: Annotated[RequestContext, Depends(get_request_context)]
):
    """
    Example: Request context with multiple dependencies
    Aggregates request information
    """
    return {
        "request_id": context.request_id,
        "user_agent": context.user_agent,
        "client_ip": context.client_ip,
        "method": context.method,
        "url": context.url,
        "timestamp": context.timestamp
    }

# ============================================
# GENERATOR DEPENDENCY (WITH CLEANUP)
# ============================================

@router.get("/database")
async def use_database(
    db: Annotated[dict, Depends(get_db_session)]
):
    """
    Example: Dependency with cleanup (generator pattern)
    Watch the logs to see session creation and cleanup
    """
    return {
        "message": "Database operation completed",
        "session_id": db.get("session_id"),
        "connected": db.get("connected")
    }

# ============================================
# CACHED DEPENDENCY
# ============================================

@router.get("/expensive")
async def use_expensive_resource(
    resource: Annotated[dict, Depends(get_expensive_resource)]
):
    """
    Example: Cached expensive dependency
    First call computes, subsequent calls use cache
    Watch logs - computation message appears only once
    """
    return {
        "message": "Using expensive resource",
        "data": resource["data"],
        "computed_at": resource["computed_at"]
    }

# ============================================
# BACKGROUND TASKS
# ============================================

@router.post("/signup")
async def signup_user(
    email: str,
    username: str,
    background_tasks: BackgroundTasks
):
    """
    Example: Background task (send email after response)
    The email "sending" happens after the response is sent
    """
    # Queue background tasks; they run after the response is sent
    background_tasks.add_task(send_welcome_email, email, username)
    background_tasks.add_task(
        process_user_analytics,
        user_id=123,
        event="user_signup",
        metadata={"email": email, "username": username}
    )
    # Response is sent immediately
    return {
        "message": "Signup successful",
        "email": email,
        "username": username,
        "note": "Welcome email will be sent in background"
    }


@router.post("/activity/{user_id}")
async def log_activity(
    user_id: int,
    action: str,
    background_tasks: BackgroundTasks
):
    """
    Example: Synchronous background task
    """
    # Add synchronous task (FastAPI runs it in a thread pool)
    background_tasks.add_task(
        log_user_activity,
        user_id=user_id,
        action=action,
        details={"timestamp": datetime.now()}
    )
    return {
        "message": "Activity logged",
        "user_id": user_id,
        "action": action
    }

# ============================================
# MULTIPLE DEPENDENCIES
# ============================================

@router.get("/complex")
async def complex_endpoint(
    token: Annotated[str, Depends(verify_token)],
    params: Annotated[CommonQueryParams, Depends()],
    context: Annotated[RequestContext, Depends(get_request_context)],
    db: Annotated[dict, Depends(get_db_session)]
):
    """
    Example: Multiple dependencies in one endpoint
    Combines authentication, parameters, context, and database
    """
    return {
        "message": "Complex endpoint with multiple dependencies",
        "authenticated": True,
        "token": token,
        "pagination": {
            "skip": params.skip,
            "limit": params.limit
        },
        "request_id": context.request_id,
        "database_session": db.get("session_id")
    }

# ============================================
# REQUEST STATE ACCESS
# ============================================

@router.get("/request-state")
async def access_request_state(request: Request):
    """
    Example: Accessing request state set by middleware
    The logging middleware sets request_id in request.state
    """
    request_id = getattr(request.state, "request_id", "not-set")
    return {
        "message": "Accessing request state",
        "request_id": request_id,
        "note": "Request ID was set by logging middleware"
    }
Step 6: Update Main Application with Middleware
Update app/main.py:
"""
FastAPI AI Backend - Main Application
Updated with middleware and advanced features
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError

from app.core.config import settings
from app.api.v1.api import api_router
from app.core.exceptions import AppException
from app.core.error_handlers import (
    app_exception_handler,
    validation_exception_handler,
    generic_exception_handler
)
from app.middleware.logging_middleware import LoggingMiddleware
from app.middleware.performance_middleware import PerformanceMiddleware
from app.utils.logger import logger


def create_application() -> FastAPI:
    """
    Application factory pattern
    Creates and configures the FastAPI application
    """
    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="""
A production-ready FastAPI backend for AI applications.
## Features
* Advanced dependency injection patterns
* Custom middleware for logging and monitoring
* Structured logging with context
* Background task processing
* Request/response lifecycle management
* Rate limiting
* Performance monitoring
* Advanced data validation with Pydantic
* Comprehensive error handling
* User management with CRUD operations
* RESTful API design
## Middleware Stack
1. Performance monitoring (slow request detection)
2. Request logging (with request ID tracking)
3. CORS handling
## Coming Soon
* AI model integration (Ollama)
* Authentication & Authorization (JWT)
* Database integration (SQLAlchemy)
* WebSocket support for streaming
* Caching layer (Redis)
""",
        debug=settings.DEBUG,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json"
    )

    # ============================================
    # MIDDLEWARE (ORDER MATTERS!)
    # ============================================
    # Starlette wraps each newly added middleware around the existing stack,
    # so the LAST one added is the OUTERMOST. Register innermost first.

    # 3. CORS (built-in FastAPI middleware, innermost of the three)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Request-ID", "X-Process-Time"]
    )

    # 2. Request logging
    app.add_middleware(LoggingMiddleware)

    # 1. Performance monitoring (added last, so it is outermost and times everything)
    app.add_middleware(
        PerformanceMiddleware,
        slow_request_threshold=1.0  # Log if request takes > 1 second
    )

    # ============================================
    # EXCEPTION HANDLERS
    # ============================================
    app.add_exception_handler(AppException, app_exception_handler)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    app.add_exception_handler(ValidationError, validation_exception_handler)
    app.add_exception_handler(Exception, generic_exception_handler)

    # ============================================
    # ROUTERS
    # ============================================
    app.include_router(
        api_router,
        prefix=settings.API_V1_PREFIX
    )
    return app


# Create the application instance
app = create_application()

# ============================================
# LIFECYCLE EVENTS
# ============================================
# Note: @app.on_event still works but is deprecated in recent FastAPI
# releases in favor of lifespan handlers passed to FastAPI(lifespan=...)

@app.on_event("startup")
async def startup_event():
    """Code to run when the application starts"""
    logger.info(
        "Application startup",
        extra={
            "app_name": settings.APP_NAME,
            "version": settings.APP_VERSION,
            "environment": settings.ENVIRONMENT,
            "debug": settings.DEBUG
        }
    )
    print(f"🚀 Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    print(f"📝 Environment: {settings.ENVIRONMENT}")
    print(f"🔧 Debug mode: {settings.DEBUG}")
    print(f"📚 API Docs: http://{settings.HOST}:{settings.PORT}/docs")
    print("✅ Middleware configured: Logging, Performance, CORS")
    print("✅ Error handling configured")


@app.on_event("shutdown")
async def shutdown_event():
    """Code to run when the application shuts down"""
    logger.info(
        "Application shutdown",
        extra={
            "app_name": settings.APP_NAME
        }
    )
    print(f"👋 Shutting down {settings.APP_NAME}")
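🔍 Middleware Order Explained:
Starlette wraps each newly added middleware around the existing stack, so the last add_middleware() call becomes the outermost layer. To get the order below, register CORSMiddleware first and PerformanceMiddleware last.
Request flow (outer → inner):
1. PerformanceMiddleware (starts timing the whole request)
2. LoggingMiddleware (logs the request and assigns the request ID)
3. CORSMiddleware (handles CORS)
4. Your endpoint
Response flow (inner → outer):
4. Endpoint returns a response
3. CORS headers are added
2. Response is logged
1. Total duration is measured
Client receives the response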
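The wrapping rule can be checked without a server. The toy below mimics how Starlette builds its stack: each `add_middleware` call inserts at the front of a list, and the list is then wrapped starting from the end, so the first entry ends up outermost. It is a simplified model written for illustration, not Starlette's code; `make_layer`, `add_middleware` and the layer names are hypothetical.

```python
from typing import Callable, List

Handler = Callable[[str], str]
Layer = Callable[[Handler], Handler]
trace: List[str] = []


def make_layer(name: str) -> Layer:
    """Build a toy middleware that records when it sees the request and the response."""
    def wrap(next_handler: Handler) -> Handler:
        def handler(request: str) -> str:
            trace.append(f"{name} in")
            response = next_handler(request)
            trace.append(f"{name} out")
            return response
        return handler
    return wrap


def endpoint(request: str) -> str:
    trace.append("endpoint")
    return "response"


user_middleware: List[Layer] = []


def add_middleware(layer: Layer) -> None:
    user_middleware.insert(0, layer)  # newest registration goes to the front


# Registration order that yields Performance -> Logging -> CORS -> endpoint
add_middleware(make_layer("CORS"))
add_middleware(make_layer("Logging"))
add_middleware(make_layer("Performance"))

app: Handler = endpoint
for layer in reversed(user_middleware):  # wrap the innermost layer first
    app = layer(app)

app("GET /items")
print(trace)
# ['Performance in', 'Logging in', 'CORS in', 'endpoint', 'CORS out', 'Logging out', 'Performance out']
```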
Step 7: Update API Router
Update app/api/v1/api.py:
"""
API v1 router aggregator
Combines all v1 endpoint routers
"""
from fastapi import APIRouter
from app.api.v1.endpoints import users, health, users_advanced, dependencies_demo
# Create main v1 router
api_router = APIRouter()
# Include all endpoint routers
api_router.include_router(health.router)
api_router.include_router(users.router)
api_router.include_router(users_advanced.router)
api_router.include_router(dependencies_demo.router) # New!
# Future routers:
# api_router.include_router(ai.router)
# api_router.include_router(chat.router)
Step 8: Create Comprehensive Test Script
Create test_dependencies_middleware.py:
"""
Test script for dependencies and middleware
Tests dependency injection patterns and middleware functionality
"""
import requests
import json
import time
BASE_URL = "http://127.0.0.1:8000/api/v1"
def print_test(title: str, response: requests.Response):
"""Helper to print test results"""
print(f"\n{'='*70}")
print(f"{title}")
print(f"{'='*70}")
print(f"Status: {response.status_code}")
print(f"Headers:")
for key in ['X-Request-ID', 'X-Process-Time', 'Access-Control-Allow-Origin']:
if key in response.headers:
print(f" {key}: {response.headers[key]}")
try:
print(f"Body:\n{json.dumps(response.json(), indent=2, default=str)}")
except:
print(f"Body: {response.text}")

def test_timestamp_dependency():
    """Test simple dependency"""
    response = requests.get(f"{BASE_URL}/dependencies/timestamp")
    print_test("📅 TIMESTAMP DEPENDENCY", response)


def test_protected_endpoint():
    """Test authentication dependency"""
    # Without token
    response = requests.get(f"{BASE_URL}/dependencies/protected")
    print_test("🔒 PROTECTED (NO TOKEN)", response)
    # With valid token
    headers = {"X-Token": "secret-token-123"}
    response = requests.get(f"{BASE_URL}/dependencies/protected", headers=headers)
    print_test("🔓 PROTECTED (VALID TOKEN)", response)
    # With invalid token
    headers = {"X-Token": "invalid-token"}
    response = requests.get(f"{BASE_URL}/dependencies/protected", headers=headers)
    print_test("❌ PROTECTED (INVALID TOKEN)", response)


def test_api_key_dependency():
    """Test API key dependency"""
    response = requests.get(
        f"{BASE_URL}/dependencies/api-key-protected",
        params={"api_key": "key-12345"}
    )
    print_test("🔑 API KEY DEPENDENCY", response)


def test_common_query_params():
    """Test class-based dependency"""
    response = requests.get(
        f"{BASE_URL}/dependencies/items",
        params={
            "skip": 5,
            "limit": 3,
            "sort_by": "name",
            "sort_order": "desc"
        }
    )
    print_test("📊 COMMON QUERY PARAMS", response)


def test_rate_limiter():
    """Test rate limiter dependency"""
    print(f"\n{'='*70}")
    print("⏱️ RATE LIMITER TEST (5 requests/minute)")
    print(f"{'='*70}")
    for i in range(7):
        response = requests.get(f"{BASE_URL}/dependencies/rate-limited")
        status = "✅" if response.status_code == 200 else "❌"
        print(f"Request {i+1}: {status} Status {response.status_code}")
        if response.status_code == 429:
            print(f" Rate limit message: {response.json()['detail']}")
        time.sleep(0.5)


def test_request_context():
    """Test request context dependency"""
    headers = {
        "X-Request-ID": "custom-request-id-123",
        "User-Agent": "Test-Client/1.0"
    }
    response = requests.get(
        f"{BASE_URL}/dependencies/context",
        headers=headers
    )
    print_test("🎯 REQUEST CONTEXT", response)


def test_database_dependency():
    """Test generator dependency with cleanup"""
    print(f"\n{'='*70}")
    print("💾 DATABASE DEPENDENCY (watch server logs for cleanup)")
    print(f"{'='*70}")
    response = requests.get(f"{BASE_URL}/dependencies/database")
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}")
    print("⚠️ Check server logs for 'Database session created/closed' messages")


def test_cached_dependency():
    """Test cached expensive dependency"""
    print(f"\n{'='*70}")
    print("💰 CACHED DEPENDENCY (watch server logs)")
    print(f"{'='*70}")
    for i in range(3):
        response = requests.get(f"{BASE_URL}/dependencies/expensive")
        print(f"\nCall {i+1}:")
        print(f" Status: {response.status_code}")
        print(f" Data: {response.json()['data']}")
        if i == 0:
            print(" ⚠️ First call: should see 'Computing expensive resource' in logs")
        else:
            print(" ✅ Subsequent call: should use cache (no log message)")

def test_background_tasks():
    """Test background tasks"""
    response = requests.post(
        f"{BASE_URL}/dependencies/signup",
        params={
            "email": "test@example.com",
            "username": "testuser"
        }
    )
    print_test("📧 BACKGROUND TASKS (check logs)", response)
    print("⚠️ Check server logs for background task execution (takes ~3 seconds)")


def test_activity_logging():
    """Test activity logging background task"""
    response = requests.post(
        f"{BASE_URL}/dependencies/activity/123",
        params={"action": "login"}
    )
    print_test("📝 ACTIVITY LOGGING", response)


def test_multiple_dependencies():
    """Test endpoint with multiple dependencies"""
    headers = {
        "X-Token": "secret-token-123",
        "X-Request-ID": "multi-dep-test-123"
    }
    params = {
        "skip": 0,
        "limit": 5,
        "sort_by": "name"
    }
    response = requests.get(
        f"{BASE_URL}/dependencies/complex",
        headers=headers,
        params=params
    )
    print_test("🎭 MULTIPLE DEPENDENCIES", response)


def test_request_state():
    """Test accessing request state from middleware"""
    response = requests.get(f"{BASE_URL}/dependencies/request-state")
    print_test("🔍 REQUEST STATE (from middleware)", response)


def test_middleware_headers():
    """Test middleware-added headers"""
    print(f"\n{'='*70}")
    print("🔧 MIDDLEWARE HEADERS TEST")
    print(f"{'='*70}")
    response = requests.get(f"{BASE_URL}/health")
    print("\nHeaders added by middleware:")
    print(f" X-Request-ID: {response.headers.get('X-Request-ID', 'Not found')}")
    print(f" X-Process-Time: {response.headers.get('X-Process-Time', 'Not found')} seconds")
    print(f" Access-Control-Allow-Origin: {response.headers.get('Access-Control-Allow-Origin', 'Not found')}")


def test_slow_request():
    """Test performance monitoring (slow request detection)"""
    print(f"\n{'='*70}")
    print("🐌 SLOW REQUEST TEST (check server logs)")
    print(f"{'='*70}")
    print("Note: this endpoint doesn't exist, so it returns a (normally fast) 404.")
    print(" The performance middleware still times it; a slow-request warning")
    print(" appears in the server logs only if it takes > 1 second.")
    # Unmatched routes still pass through every middleware layer
    response = requests.get(f"{BASE_URL}/slow-endpoint-test")
    print(f"Status: {response.status_code}")
    print(f"Process time: {response.headers.get('X-Process-Time')} seconds")

def run_all_tests():
    """Run all tests"""
    print("\n" + "🧪"*35)
    print("DEPENDENCIES & MIDDLEWARE TEST SUITE")
    print("🧪"*35)
    print("\n" + "="*70)
    print("SIMPLE DEPENDENCIES")
    print("="*70)
    test_timestamp_dependency()
    print("\n" + "="*70)
    print("AUTHENTICATION DEPENDENCIES")
    print("="*70)
    test_protected_endpoint()
    test_api_key_dependency()
    print("\n" + "="*70)
    print("CLASS-BASED DEPENDENCIES")
    print("="*70)
    test_common_query_params()
    print("\n" + "="*70)
    print("ADVANCED DEPENDENCIES")
    print("="*70)
    test_request_context()
    test_database_dependency()
    test_cached_dependency()
    print("\n" + "="*70)
    print("BACKGROUND TASKS")
    print("="*70)
    test_background_tasks()
    test_activity_logging()
    print("\n" + "="*70)
    print("COMPLEX SCENARIOS")
    print("="*70)
    test_multiple_dependencies()
    test_request_state()
    print("\n" + "="*70)
    print("MIDDLEWARE TESTS")
    print("="*70)
    test_middleware_headers()
    test_slow_request()
    # Rate limiting runs last so its 429s don't affect the other tests
    print("\n" + "="*70)
    print("RATE LIMITING")
    print("="*70)
    test_rate_limiter()
    print("\n" + "✅"*35)
    print("ALL TESTS COMPLETED!")
    print("✅"*35)
    print("\n⚠️ Remember to check server logs for:")
    print(" - Request logging with IDs")
    print(" - Background task execution")
    print(" - Database session cleanup")
    print(" - Slow request warnings")
    print(" - Cached dependency computation\n")

if __name__ == "__main__":
    banner = [
        "Dependencies & Middleware Test Suite",
        "",
        "Make sure your server is running:",
        "    python main.py",
        "",
        "Watch the SERVER LOGS while running these tests!",
    ]
    width = 56
    # Pad every line to the same width so the box borders line up
    print("\n╔" + "═" * width + "╗")
    for line in banner:
        print(f"║  {line:<{width - 2}}║")
    print("╚" + "═" * width + "╝\n")
    try:
        response = requests.get(f"{BASE_URL}/health")
        if response.status_code == 200:
            run_all_tests()
        else:
            print(f"❌ Server returned unexpected status code {response.status_code}")
    except requests.exceptions.ConnectionError:
        print("❌ ERROR: Cannot connect to server!")
        print(" Please start the server with: python main.py")
    except Exception as e:
        print(f"❌ ERROR: {e}")
📦 Installation Requirements
Don’t forget to install the JSON logging library:
pip install python-json-logger
pip freeze > requirements.txt
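python-json-logger works by plugging a `logging.Formatter` subclass into a handler, so each `LogRecord` is serialised as one JSON object, including any fields passed via `extra=` (such as a request ID). To make the log lines you'll see below easier to read, here is a stdlib-only sketch of that mechanism. `StdlibJsonFormatter` and `get_json_logger` are illustrative names, not python-json-logger's API, and this is not a replacement for the library installed above.

```python
import json
import logging


class StdlibJsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object (illustrative sketch)."""

    # Attribute names every LogRecord already has; anything else came from `extra=`.
    _STANDARD = set(vars(logging.LogRecord("", 0, "", 0, "", (), None))) | {"message", "asctime"}

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),  # applies %-style args
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        # Copy context passed through `extra=` (e.g. request_id) into the payload.
        for key, value in vars(record).items():
            if key not in self._STANDARD:
                payload[key] = value
        return json.dumps(payload, default=str)


def get_json_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking handlers if called twice
        handler = logging.StreamHandler()
        handler.setFormatter(StdlibJsonFormatter())
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


if __name__ == "__main__":
    log = get_json_logger("app")
    log.info("Request completed", extra={"request_id": "abc-123", "status_code": 200})
```

Because every line is valid JSON, log aggregators can filter on `request_id` or `status_code` directly instead of parsing free text.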
🎯 Implementation Checklist
New Files to Create:
app/
├── dependencies.py ✅ Central dependencies
├── middleware/
│ ├── __init__.py ✅
│ ├── logging_middleware.py ✅
│ ├── cors_middleware.py ✅
│ └── performance_middleware.py ✅
├── utils/
│ ├── __init__.py ✅
│ ├── logger.py ✅
│ └── background_tasks.py ✅
└── api/v1/endpoints/
└── dependencies_demo.py ✅
test_dependencies_middleware.py ✅
Files to Update:
app/main.py ✅ Add middleware
app/api/v1/api.py ✅ Add demo router
🧪 Running the Tests
Step 1: Start the Server
python main.py
Watch the console logs! You’ll see structured JSON logs for every request.
Step 2: Run Tests (in another terminal)
# Activate venv
source venv/bin/activate # or venv\Scripts\activate on Windows
# Run tests
python test_dependencies_middleware.py
What You’ll See:
In the test terminal:
- ✅ Successful requests with proper dependencies
- ❌ Failed auth attempts
- 🔄 Rate limiting in action
- 📧 Background task confirmation
In the server logs:
- 📝 Request logging with IDs
- 💾 Database session lifecycle
- 💰 Cached dependency (computed once)
- 📧 Background tasks executing
- ⏱️ Request timing information
📊 Understanding the Request Lifecycle
Complete Request Flow:
1. Client sends request
↓
2. PerformanceMiddleware (start timer)
↓
3. LoggingMiddleware (log incoming request, add request ID)
↓
4. CORSMiddleware (handle CORS headers)
↓
5. FastAPI routing (match endpoint)
↓
6. Dependency resolution
├── get_current_timestamp() → executed
├── verify_token() → executed
├── CommonQueryParams() → instantiated
└── get_db_session() → setup phase
↓
7. Your endpoint function executes
↓
8. Dependency cleanup
└── get_db_session() → cleanup phase
↓
9. Response creation
↓
10. Exception handlers (if error occurred)
↓
11. CORSMiddleware (add CORS headers)
↓
12. LoggingMiddleware (log response)
↓
13. PerformanceMiddleware (check timing, add headers)
↓
14. Response sent to client
↓
15. Background tasks execute (async)
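The flow above is an "onion": each middleware runs its code before `call_next`, the request travels inward to the endpoint, and the response travels back out through the same layers in reverse. The toy model below (plain asyncio, not Starlette's real machinery; `make_middleware`, `build_app` and `db_session` are made-up names) reproduces steps 2 to 13 and records the order in which each layer runs. One detail worth checking in your own `app/main.py`: in Starlette, and therefore FastAPI, each `app.add_middleware(...)` call wraps the existing stack, so the middleware added last is the outermost and sees the request first. The exact moment a `yield` dependency's cleanup runs relative to sending the response has also changed between FastAPI versions, so treat that part of the model as approximate.

```python
import asyncio
from contextlib import asynccontextmanager

events = []  # records the order in which each layer runs


def make_middleware(name):
    """Return a dispatch(request, call_next) coroutine in the style of BaseHTTPMiddleware."""
    async def dispatch(request, call_next):
        events.append(f"{name}: before")  # request on its way in
        response = await call_next(request)
        events.append(f"{name}: after")   # response on its way out
        return response
    return dispatch


@asynccontextmanager
async def db_session():
    # Models a `yield` dependency: setup before the endpoint, cleanup after it.
    events.append("db: setup")
    try:
        yield "session"
    finally:
        events.append("db: cleanup")


async def endpoint(request):
    async with db_session() as session:
        events.append(f"endpoint runs with {session}")
        return {"ok": True}


def build_app(handler, middlewares):
    """Wrap handler so that middlewares[0] becomes the outermost layer."""
    for dispatch in reversed(middlewares):
        async def layer(request, dispatch=dispatch, inner=handler):
            return await dispatch(request, inner)
        handler = layer
    return handler


app = build_app(endpoint, [
    make_middleware("Performance"),
    make_middleware("Logging"),
    make_middleware("CORS"),
])

if __name__ == "__main__":
    asyncio.run(app({"path": "/customers"}))
    print("\n".join(events))
```

Reordering the list passed to `build_app` reorders the printout, which mirrors reordering `add_middleware` calls, remembering that Starlette's effective order is the reverse of the call order.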
🎓 Real-World Use Cases
1. Multi-Tenant SaaS Application
from fastapi import Depends, FastAPI, Header

app = FastAPI()

# Tenant, Database and db are placeholders for your own models and data layer.

async def get_tenant(x_tenant_id: str = Header(...)):
    """Identify the tenant from the X-Tenant-ID header"""
    return await db.get_tenant(x_tenant_id)

async def get_tenant_db(tenant: Tenant = Depends(get_tenant)):
    """Connect to the tenant-specific database"""
    return Database(tenant.db_connection_string)

@app.get("/customers")
async def list_customers(
    tenant: Tenant = Depends(get_tenant),
    db: Database = Depends(get_tenant_db),
):
    # Automatically queries the correct tenant database
    return await db.customers.list()
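Both `list_customers` and `get_tenant_db` declare `Depends(get_tenant)`, yet FastAPI calls `get_tenant` only once per request: by default it caches each dependency's result for the lifetime of the request (pass `Depends(get_tenant, use_cache=False)` to opt out). The toy resolver below models that caching in plain asyncio. `RequestScope` and its `resolve` method are invented for illustration and are not how FastAPI resolves dependencies internally.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Dependency = Callable[..., Awaitable[Any]]


class RequestScope:
    """Toy per-request resolver: each dependency runs at most once per request."""

    def __init__(self, headers: Dict[str, str]):
        self.headers = headers
        self._cache: Dict[Dependency, Any] = {}

    async def resolve(self, dep: Dependency) -> Any:
        # First use in this request: run it and remember the result.
        if dep not in self._cache:
            self._cache[dep] = await dep(self)
        return self._cache[dep]


calls: List[str] = []  # records which dependencies actually executed


async def get_tenant(scope: RequestScope) -> str:
    calls.append("get_tenant")
    return scope.headers["x-tenant-id"]


async def get_tenant_db(scope: RequestScope) -> str:
    tenant = await scope.resolve(get_tenant)  # sub-dependency, served from cache
    calls.append("get_tenant_db")
    return f"db-for-{tenant}"


async def list_customers(scope: RequestScope) -> tuple:
    # Both routes to get_tenant share one cached result within this request.
    tenant = await scope.resolve(get_tenant)
    db = await scope.resolve(get_tenant_db)
    return tenant, db


if __name__ == "__main__":
    print(asyncio.run(list_customers(RequestScope({"x-tenant-id": "acme"}))))
    print(calls)
```

A fresh `RequestScope` per request means the cache never leaks one tenant's data into another tenant's request.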
2. API Rate Limiting with Redis
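from fastapi import HTTPException, Request

class RedisRateLimiter:
    """Fixed-window limiter: at most max_requests per user per hour."""

    def __init__(self, redis_client, max_requests: int = 100):
        self.redis = redis_client  # e.g. a redis.asyncio.Redis instance
        self.max_requests = max_requests

    async def __call__(self, request: Request, user_id: int):
        # user_id is read from the query string when used as a dependency
        key = f"rate_limit:{user_id}"
        current = await self.redis.incr(key)
        if current == 1:
            await self.redis.expire(key, 3600)  # 1-hour window starts on the first hit
        if current > self.max_requests:
            raise HTTPException(status_code=429, detail="Rate limit exceeded")
        return True

# Usage: limiter = RedisRateLimiter(redis_client, max_requests=100)
# @app.get("/search", dependencies=[Depends(limiter)])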
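The limiter above is a fixed-window counter: the first request in a window creates the key and starts its one-hour TTL, and every request beyond `max_requests` is rejected until the key expires. To check that logic without a Redis server, here is a sketch that runs the same steps against a hypothetical in-memory stand-in. `InMemoryCounter`, `check_rate_limit` and `RateLimitExceeded` are illustrative names, not part of redis-py or FastAPI, and the window is shrunk to 0.2 seconds so the demo finishes quickly.

```python
import asyncio
import time
from typing import Dict


class InMemoryCounter:
    """Hypothetical stand-in for Redis INCR/EXPIRE, for local experiments only."""

    def __init__(self) -> None:
        self._values: Dict[str, int] = {}
        self._deadlines: Dict[str, float] = {}

    def _drop_if_expired(self, key: str) -> None:
        deadline = self._deadlines.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            self._values.pop(key, None)
            self._deadlines.pop(key, None)

    async def incr(self, key: str) -> int:
        # Like Redis INCR: a missing key counts as 0 before incrementing.
        self._drop_if_expired(key)
        self._values[key] = self._values.get(key, 0) + 1
        return self._values[key]

    async def expire(self, key: str, seconds: float) -> bool:
        # Like Redis EXPIRE: set a time-to-live on an existing key.
        self._drop_if_expired(key)
        if key not in self._values:
            return False
        self._deadlines[key] = time.monotonic() + seconds
        return True


class RateLimitExceeded(Exception):
    """Plays the role of HTTPException(429) outside FastAPI."""


async def check_rate_limit(store, user_id: int, max_requests: int, window: float) -> bool:
    # Same fixed-window steps as RedisRateLimiter.__call__.
    key = f"rate_limit:{user_id}"
    current = await store.incr(key)
    if current == 1:
        await store.expire(key, window)  # first hit opens the window
    if current > max_requests:
        raise RateLimitExceeded(f"user {user_id}: {current} requests in window")
    return True


async def demo() -> list:
    store = InMemoryCounter()
    results = []
    for _ in range(4):  # 3 allowed, the 4th is rejected
        try:
            results.append(await check_rate_limit(store, 1, max_requests=3, window=0.2))
        except RateLimitExceeded:
            results.append(False)
    await asyncio.sleep(0.25)  # let the window expire
    results.append(await check_rate_limit(store, 1, max_requests=3, window=0.2))
    return results


if __name__ == "__main__":
    print(asyncio.run(demo()))  # [True, True, True, False, True]
```

A fixed window costs one INCR per request, but a client can send up to twice `max_requests` by bunching calls on either side of a window boundary; sliding-window or token-bucket schemes smooth that out at the cost of more Redis work.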
3. Request Tracing Across Microservices
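import logging
import uuid

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

class TracingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Reuse the caller's trace ID, or start a new trace
        trace_id = request.headers.get("X-Trace-ID") or str(uuid.uuid4())
        request.state.trace_id = trace_id
        # Log with trace ID
        logger.info("Request", extra={"trace_id": trace_id})
        response = await call_next(request)
        # Return the trace ID so the caller can correlate its logs with ours
        response.headers["X-Trace-ID"] = trace_id
        return response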
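Setting `request.state.trace_id` makes the ID available within this request, but log calls deep in your code and any outbound HTTP calls to the next service still need it. A common approach, sketched below with only the standard library, is to keep the current trace ID in a `contextvars.ContextVar`, let a `logging.Filter` stamp it onto every record, and copy it into the headers of outgoing requests. `trace_id_var`, `TraceIdFilter`, `start_trace` and `outgoing_headers` are hypothetical names for this sketch.

```python
import contextvars
import logging
import uuid
from typing import Dict, Optional

# Holds the current request's trace ID; each asyncio task sees its own copy.
trace_id_var = contextvars.ContextVar("trace_id", default="-")


class TraceIdFilter(logging.Filter):
    """Attach the current trace ID to every record that passes through."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.trace_id = trace_id_var.get()
        return True  # never drop records, only annotate them


def start_trace(incoming: Optional[str]) -> contextvars.Token:
    """Reuse the caller's X-Trace-ID or mint a new one; returns a token for reset()."""
    return trace_id_var.set(incoming or str(uuid.uuid4()))


def outgoing_headers() -> Dict[str, str]:
    """Headers to attach when calling a downstream service."""
    return {"X-Trace-ID": trace_id_var.get()}


if __name__ == "__main__":
    handler = logging.StreamHandler()
    handler.addFilter(TraceIdFilter())
    handler.setFormatter(logging.Formatter("%(levelname)s [trace=%(trace_id)s] %(message)s"))
    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    token = start_trace(None)
    log.info("calling inventory service with %s", outgoing_headers())
    trace_id_var.reset(token)
```

Inside `TracingMiddleware.dispatch` you would call `token = start_trace(request.headers.get("X-Trace-ID"))` before `call_next` and `trace_id_var.reset(token)` afterwards. Because asyncio tasks copy the current context when they are created, a value set before `call_next` should be visible to the endpoint even though `BaseHTTPMiddleware` runs the rest of the app in a separate task; values set inside the endpoint do not flow back out.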
🎉 What You’ve Accomplished!
✅ Mastered Dependency Injection
- Simple function dependencies
- Class-based dependencies
- Generator dependencies with cleanup
- Cached dependencies
- Nested sub-dependencies
✅ Created Production Middleware
- Request logging with unique IDs
- Performance monitoring
- Custom CORS handling
- Request state management
✅ Implemented Structured Logging
- JSON-formatted logs
- Context-aware logging
- Request tracking
- Debugging-friendly output
✅ Async Background Tasks
- Non-blocking email sending
- Analytics processing
- Cleanup tasks
- Activity logging
✅ Professional Code Organization
- Reusable patterns
- Testable components
- Clear separation of concerns
- Production-ready architecture
🚀 Coming Up Next
Blog Post 6: AI Integration with Ollama
We’ll finally integrate AI into our backend:
- Installing and setting up Ollama locally
- Creating an AI service abstraction layer
- Implementing chat endpoints
- Streaming responses in real-time
- Model management and switching
- Building a simple HTML chat interface
You now have a solid FastAPI foundation. Time to make it AI-powered! 🤖