# Deploying Machine Learning Models with FastAPI
Machine learning models are only valuable when they're accessible to users and applications. FastAPI makes it straightforward to expose trained models as production-ready REST APIs with little boilerplate and solid performance. In this guide, I'll walk you through deploying ML models with FastAPI, from basic setup to monitoring in production.
## Why FastAPI for ML Deployment?
FastAPI offers several advantages specifically for ML deployment:
- Automatic API documentation with Swagger UI and ReDoc
- Built-in validation using Pydantic for request/response schemas
- Async support for handling concurrent requests efficiently
- Dependency injection for managing model loading and caching
- Type hints that improve code quality and IDE support
## Project Structure
Here's the recommended structure for an ML deployment project:
```
ml-api/
├── main.py              # FastAPI application
├── models/
│   ├── __init__.py
│   ├── model_loader.py  # Model loading logic
│   └── predictor.py     # Prediction logic
├── schemas/
│   ├── __init__.py
│   └── requests.py      # Request/response schemas
├── requirements.txt
├── docker-compose.yml
└── config.py            # Configuration management
```
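The `config.py` module can be a small settings object sourced from environment variables. Here's a minimal sketch using only the standard library; the variable names (`ML_API_MODEL_PATH` and friends) are illustrative, not fixed by this guide, and a library like pydantic-settings is a common alternative. Note that the defaults are read once, at import time.

```python
import os
from dataclasses import dataclass

@dataclass(frozen=True)
class Settings:
    """Settings read from environment variables at import time, with defaults.

    The ML_API_* variable names are illustrative placeholders.
    """
    model_path: str = os.environ.get("ML_API_MODEL_PATH", "models/trained_model.pkl")
    log_level: str = os.environ.get("ML_API_LOG_LEVEL", "INFO")
    port: int = int(os.environ.get("ML_API_PORT", "8000"))

settings = Settings()
```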
## Setting Up FastAPI with Models
Let's start with the basic FastAPI application structure:
```python
from contextlib import asynccontextmanager
import logging
import pickle
import time

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model cache
model = None

class PredictionRequest(BaseModel):
    features: list[float]
    model_type: str = "default"

class PredictionResponse(BaseModel):
    prediction: float
    confidence: float
    latency_ms: float

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the model once, before the app serves traffic
    global model
    logger.info("Loading model on startup...")
    with open("models/trained_model.pkl", "rb") as f:
        model = pickle.load(f)
    logger.info("Model loaded successfully")
    yield
    # Shutdown
    logger.info("Shutting down")

app = FastAPI(title="ML API", version="1.0.0", lifespan=lifespan)

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    start_time = time.time()
    # Convert input to a 2D numpy array (one row)
    features = np.array(request.features).reshape(1, -1)
    # Make prediction
    prediction = model.predict(features)[0]
    confidence = float(model.predict_proba(features).max())
    latency = (time.time() - start_time) * 1000
    return PredictionResponse(
        prediction=float(prediction),
        confidence=confidence,
        latency_ms=latency,
    )
```
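To see the load-and-predict flow end to end without a real training pipeline, you can pickle a small stand-in estimator that exposes the same `predict`/`predict_proba` interface the endpoint expects. The `StubModel` class below is a hypothetical placeholder, not part of any library; its "probabilities" are toy values.

```python
import os
import pickle
import tempfile

class StubModel:
    """Stand-in estimator with a scikit-learn-style predict interface."""

    def predict_proba(self, features):
        # Toy probabilities: positive-class likelihood grows with the feature sum
        probs = []
        for row in features:
            p = min(max(sum(row) / 10.0, 0.0), 1.0)
            probs.append([1.0 - p, p])
        return probs

    def predict(self, features):
        # Predict whichever class has the higher toy probability per row
        return [max(range(2), key=lambda c: self.predict_proba([row])[0][c])
                for row in features]

# Round-trip through pickle, mimicking what lifespan() does at startup
path = os.path.join(tempfile.mkdtemp(), "trained_model.pkl")
with open(path, "wb") as f:
    pickle.dump(StubModel(), f)
with open(path, "rb") as f:
    model = pickle.load(f)

prediction = model.predict([[1.0, 2.0, 3.0]])[0]
confidence = max(model.predict_proba([[1.0, 2.0, 3.0]])[0])
```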
## Model Loading and Caching
Properly managing the model lifecycle is crucial for production systems. Here's a robust approach:
```python
import asyncio
import json
import pickle
from pathlib import Path
from typing import Optional

class ModelManager:
    def __init__(self, model_path: str):
        self.model_path = Path(model_path)
        self._model = None
        self._metadata = None
        self._load_lock = asyncio.Lock()

    async def get_model(self):
        """Get cached model, loading it on first use."""
        if self._model is None:
            async with self._load_lock:
                if self._model is None:  # Double-checked locking
                    self._load_model()
        return self._model

    def _load_model(self):
        """Load model (and optional metadata) from disk."""
        logger.info(f"Loading model from {self.model_path}")
        with open(self.model_path, "rb") as f:
            self._model = pickle.load(f)
        # Load metadata if a sidecar JSON file exists
        metadata_path = self.model_path.with_suffix(".json")
        if metadata_path.exists():
            with open(metadata_path) as f:
                self._metadata = json.load(f)
        logger.info("Model loaded successfully")

    async def reload_model(self):
        """Reload the model (useful for hot-swapping updates).

        Note: load directly while holding the lock; calling get_model()
        here would deadlock, since asyncio.Lock is not reentrant.
        """
        async with self._load_lock:
            self._model = None
            self._load_model()
```
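The double-checked locking above matters when many requests arrive before the first load finishes: only one coroutine should hit the disk. Here's a stripped-down, standard-library-only sketch of the same pattern, with a counter and a short sleep standing in for the expensive load:

```python
import asyncio

class LazyLoader:
    def __init__(self):
        self._value = None
        self._lock = asyncio.Lock()
        self.load_calls = 0  # how many times the "expensive" load actually ran

    async def get(self):
        if self._value is None:                 # fast path, no lock
            async with self._lock:
                if self._value is None:         # re-check under the lock
                    await asyncio.sleep(0.01)   # simulate slow disk I/O
                    self.load_calls += 1
                    self._value = "model"
        return self._value

async def main():
    loader = LazyLoader()
    # 50 concurrent callers race to load the model
    results = await asyncio.gather(*(loader.get() for _ in range(50)))
    return loader.load_calls, results

load_calls, results = asyncio.run(main())
```

Despite 50 concurrent callers, the load runs exactly once; every later caller either takes the lock-free fast path or re-checks under the lock and finds the value already set.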
## Handling Batch Predictions
Many ML applications need to process multiple samples efficiently:
```python
class BatchPredictionRequest(BaseModel):
    samples: list[list[float]]
    return_probabilities: bool = False

class BatchPredictionResponse(BaseModel):
    predictions: list[float]
    probabilities: Optional[list[list[float]]] = None
    processing_time_ms: float

@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    start_time = time.time()
    # Convert to a 2D numpy array: one row per sample
    features = np.array(request.samples)
    # Make predictions for the whole batch at once
    predictions = model.predict(features)
    response_data = {
        "predictions": predictions.tolist(),
        "processing_time_ms": (time.time() - start_time) * 1000,
    }
    if request.return_probabilities:
        response_data["probabilities"] = model.predict_proba(features).tolist()
    return BatchPredictionResponse(**response_data)
```
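For very large batches, it can be worth capping memory by splitting the input into fixed-size chunks before calling `model.predict` and accumulating the results. A minimal chunking helper (the chunk size of 256 is an arbitrary illustrative default, not something FastAPI prescribes):

```python
def iter_chunks(samples, chunk_size=256):
    """Yield successive fixed-size chunks of a list of samples."""
    for start in range(0, len(samples), chunk_size):
        yield samples[start:start + chunk_size]

# 1000 samples split into chunks of at most 256
samples = [[float(i)] for i in range(1000)]
chunks = list(iter_chunks(samples, chunk_size=256))
sizes = [len(c) for c in chunks]
```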
## Health Checks and Monitoring
Production systems need proper health check endpoints:
```python
from datetime import datetime, timezone
import psutil

class HealthStatus(BaseModel):
    status: str
    timestamp: datetime
    model_loaded: bool
    memory_usage_mb: float
    cpu_percent: float

@app.get("/health", response_model=HealthStatus)
async def health_check():
    return HealthStatus(
        status="healthy",
        timestamp=datetime.now(timezone.utc),
        model_loaded=model is not None,
        memory_usage_mb=psutil.Process().memory_info().rss / 1024 / 1024,
        cpu_percent=psutil.cpu_percent(interval=0.1),
    )

@app.get("/metrics")
async def metrics():
    """Metrics endpoint. The three values below are counters/timers
    you maintain elsewhere in the application."""
    return {
        "predictions_total": prediction_counter,
        "prediction_latency_seconds": prediction_latency,
        "model_load_time_seconds": model_load_time,
    }
```
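The `/metrics` route above returns JSON, which Prometheus does not scrape natively; it expects a plain-text exposition format. Here's a sketch of rendering metrics in that format by hand, declaring everything as a counter for simplicity (in practice the `prometheus_client` library handles type annotations, labels, and registries for you; the metric names mirror the placeholders above):

```python
def render_prometheus(metrics: dict) -> str:
    """Render a flat dict of numeric metrics in Prometheus text exposition format."""
    lines = []
    for name, value in metrics.items():
        lines.append(f"# TYPE {name} counter")  # simplification: everything is a counter
        lines.append(f"{name} {value}")
    return "\n".join(lines) + "\n"

body = render_prometheus({
    "predictions_total": 42,
    "prediction_latency_seconds": 0.013,
})
```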
## Docker Deployment
Containerizing your API ensures consistency across environments:
```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Health check (raise_for_status makes it fail on non-2xx responses,
# not just connection errors)
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/health', timeout=2).raise_for_status()"

# Run FastAPI under uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
```
## Best Practices Summary
- **Model Versioning**: Store model versions and be able to switch between them
- **Input Validation**: Always validate inputs using Pydantic schemas
- **Error Handling**: Provide meaningful error messages for invalid requests
- **Logging**: Log predictions and system metrics for debugging
- **Rate Limiting**: Implement rate limiting to prevent abuse
- **Caching**: Cache predictions when appropriate to reduce latency
- **Monitoring**: Set up comprehensive monitoring and alerting
- **Documentation**: Leverage FastAPI's auto-generated docs
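The caching point can be as simple as memoizing predictions for repeated inputs. A sketch using the standard library's `functools.lru_cache`; note that features must be converted to a hashable tuple, and the body of `cached_predict` is a stand-in for the real `model.predict` call:

```python
from functools import lru_cache

calls = 0  # tracks how many times the "model" is actually invoked

@lru_cache(maxsize=1024)
def cached_predict(features: tuple) -> float:
    """Memoized wrapper; in a real service this would call model.predict."""
    global calls
    calls += 1
    return sum(features) / len(features)  # stand-in for a model prediction

# Repeated identical inputs only hit the "model" once
r1 = cached_predict((1.0, 2.0, 3.0))
r2 = cached_predict((1.0, 2.0, 3.0))
```

This only makes sense when identical inputs recur and the model is deterministic; remember to bound the cache size (as `maxsize` does here) so memory stays predictable.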
## Conclusion
FastAPI provides an excellent foundation for deploying machine learning models to production. By following these practices—proper model management, comprehensive error handling, health checks, and monitoring—you'll have a robust system that can serve predictions reliably at scale.