shield

package
v0.1.0 Latest
Warning

This package is not in the latest version of its module.

Published: Mar 1, 2026 License: MIT Imports: 14 Imported by: 0

README

shield — HTTP security middleware

shield provides composable HTTP middlewares for security headers, rate limiting, request tracing, body limits, and flash messages.

Quick start

// Create tables (BO only; FO receives them via dbsync).
if err := shield.Init(db); err != nil {
    log.Fatal(err)
}

// Pre-built stacks.
foStack, mm := shield.DefaultFOStack(db) // maintenance + HEAD→GET + headers + body limit + trace + rate limit + flash
mm.StartReloader(done)                   // poll maintenance flag every 5s
boStack := shield.DefaultBOStack()       // headers + body limit + trace + flash (no rate limit)

// A stack is a slice of middlewares: wrap in reverse so foStack[0] runs first.
h := handler
for i := len(foStack) - 1; i >= 0; i-- {
    h = foStack[i](h)
}
mux.Handle("/", h)

Middlewares

Middleware                                  Description
NewMaintenanceMode(db, excludePrefixes...)  Serve a 503 page while the maintenance flag in SQLite is active
SecurityHeaders(cfg)                        Set CSP, X-Frame-Options, X-Content-Type-Options, Referrer-Policy, Permissions-Policy
MaxFormBody(maxBytes)                       Limit the size of application/x-www-form-urlencoded request bodies
TraceID                                     Generate a random trace ID, set the X-Trace-ID header, enrich the context logger
NewRateLimiter(db, excludePrefixes...)      Per-IP, per-endpoint rate limiting with rules from SQLite
Flash                                       Read the one-time flash cookie and inject it into the context
HeadToGet                                   Forward HEAD requests as GET

Rate limiting

Rules are stored in a rate_limits SQLite table and reloaded every 60s once StartReloader is running. Buckets are tracked in memory, one per (IP, endpoint) pair, and expired buckets are garbage-collected every 5 minutes.

rl := shield.NewRateLimiter(db, "/static/") // excluded path prefixes are variadic
rl.StartReloader(done)
mux.Handle("/api/", rl.Middleware(handler))

API endpoints (/api/*) get a 429 JSON response with a Retry-After header. Other (page) endpoints get a redirect carrying a flash message.

Flash messages

// Set a flash (cookie-based, read once on the next request).
shield.SetFlash(w, "success", "Saved!")

// On the next request, once the Flash middleware has run:
if flash := shield.GetFlash(r.Context()); flash != nil {
    // flash.Type == "success", flash.Message == "Saved!"
}

Exported API

Symbol                                      Description
SecurityHeaders(cfg)                        Security header middleware
DefaultHeaders()                            Sensible default header config
MaxFormBody(maxBytes)                       Body size limit for form-encoded POSTs
TraceID                                     Request tracing middleware
GetLogger(ctx)                              Per-request logger from context (falls back to slog.Default())
ExtractIP(r)                                Client IP from X-Forwarded-For or RemoteAddr
NewRateLimiter(db, excludePrefixes...)      SQLite-backed rate limiter
Flash                                       Flash message middleware
SetFlash(w, flashType, message)             Write flash cookie
GetFlash(ctx)                               Read flash from context
HeadToGet                                   HEAD → GET conversion
NewMaintenanceMode(db, excludePrefixes...)  SQLite-backed maintenance mode
Schema                                      DDL for the rate_limits and maintenance tables
Init(db)                                    Create shield tables (idempotent)
DefaultFOStack(db)                          FO middleware stack (returns stack + *MaintenanceMode)
DefaultBOStack()                            BO middleware stack

Documentation

Overview

Package shield provides reusable HTTP security middleware for HOROS services. It consolidates security headers, rate limiting, body limits, request tracing, flash messages, and HEAD method handling into a single importable package.

Usage:

r := chi.NewRouter()
rl := shield.NewRateLimiter(db)
rl.StartReloader(done) // periodic rule reload and bucket GC
r.Use(shield.SecurityHeaders(shield.DefaultHeaders()))
r.Use(shield.MaxFormBody(64 * 1024)) // 64 KiB
r.Use(shield.TraceID)
r.Use(rl.Middleware)
r.Use(shield.Flash)
r.Use(shield.HeadToGet)

Or apply the default FO stack in one call:

stack, mm := shield.DefaultFOStack(db)
mm.StartReloader(done)
for _, mw := range stack {
    r.Use(mw)
}


Constants

const (
	// LoggerKey is the context key for the per-request structured logger.
	LoggerKey contextKey = "shield_logger"

	// FlashKey is the context key for flash messages.
	FlashKey contextKey = "shield_flash"
)
const Schema = `` /* 552-byte string literal not displayed */

Schema defines the SQLite tables used by shield middlewares:

  • rate_limits: per-endpoint rate limiting rules (used by RateLimiter)
  • maintenance: global maintenance mode flag (used by MaintenanceMode)

Apply with Init(db) or execute manually. All statements are idempotent (CREATE TABLE IF NOT EXISTS). The BO should create these tables; dbsync replicates them to FO instances via FilterSpec.FullTables.

Variables

This section is empty.

Functions

func DefaultBOStack

func DefaultBOStack() []func(http.Handler) http.Handler

DefaultBOStack returns the standard middleware stack for a HOROS BO service. Same as FO but without rate limiting (BO is not publicly exposed).

func ExtractIP

func ExtractIP(r *http.Request) string

ExtractIP returns the client IP from X-Forwarded-For or RemoteAddr.

func Flash

func Flash(next http.Handler) http.Handler

Flash reads the "flash" cookie, parses the type prefix ("success:" or "error:"), stores the FlashMessage in the context under FlashKey, and clears the cookie.

func GetLogger

func GetLogger(ctx context.Context) *slog.Logger

GetLogger retrieves the per-request logger from the context. Returns slog.Default() if no logger was set.

func HeadToGet

func HeadToGet(next http.Handler) http.Handler

HeadToGet converts HEAD requests to GET so that route handlers registered with r.Get() respond with 200 instead of 405 (Method Not Allowed). Go's net/http automatically strips the body for HEAD responses.

func Init

func Init(db *sql.DB) error

Init creates the shield tables if they don't exist.

func MaxFormBody

func MaxFormBody(maxBytes int64) func(http.Handler) http.Handler

MaxFormBody returns middleware that limits the request body size for form-encoded POST requests. Other content types are passed through.

func SecurityHeaders

func SecurityHeaders(cfg HeaderConfig) func(http.Handler) http.Handler

SecurityHeaders returns middleware that sets the configured security headers on every response. Use DefaultHeaders() for the standard HOROS configuration.

func SetFlash

func SetFlash(w http.ResponseWriter, flashType, message string)

SetFlash sets a flash cookie with the given type and message. The cookie is HttpOnly and SameSite=Lax with a 10-second TTL.

func TraceID

func TraceID(next http.Handler) http.Handler

TraceID generates a random trace ID for each request and injects it into the context, response headers, and a per-request structured logger. The trace ID is stored under kit.TraceIDKey and the logger under LoggerKey.

Types

type FlashMessage

type FlashMessage struct {
	Type    string // "success" or "error"
	Message string
}

FlashMessage represents a one-time notification shown to the user.

func GetFlash

func GetFlash(ctx context.Context) *FlashMessage

GetFlash retrieves the flash message from the request context.

type HeaderConfig

type HeaderConfig struct {
	CSP                 string
	XFrameOptions       string
	XContentTypeOptions string
	ReferrerPolicy      string
	PermissionsPolicy   string
}

HeaderConfig defines the security headers applied to every response.

func DefaultHeaders

func DefaultHeaders() HeaderConfig

DefaultHeaders returns the standard HOROS security header configuration.

type MaintenanceMode

type MaintenanceMode struct {
	// contains filtered or unexported fields
}

MaintenanceMode provides a middleware that returns a 503 Service Unavailable page when maintenance mode is active. The flag is stored in a SQLite table (replicated to FO instances by dbsync) and cached in memory.

Expected schema:

CREATE TABLE IF NOT EXISTS maintenance (
    id INTEGER PRIMARY KEY CHECK (id = 1),
    active INTEGER NOT NULL DEFAULT 0,
    message TEXT NOT NULL DEFAULT 'Maintenance en cours, veuillez patienter.'
);

Only one row (id=1) is expected. If the table does not exist or is empty, maintenance mode is off.

func DefaultFOStack

func DefaultFOStack(db *sql.DB) ([]func(http.Handler) http.Handler, *MaintenanceMode)

DefaultFOStack returns the standard middleware stack for a HOROS FO service. Middleware is ordered: Maintenance → HeadToGet → SecurityHeaders → MaxFormBody → TraceID → RateLimiter → Flash. The returned MaintenanceMode handle allows callers to set a custom page and call StartReloader. Health checks (/healthz) bypass maintenance.

func NewMaintenanceMode

func NewMaintenanceMode(db *sql.DB, excludePrefixes ...string) *MaintenanceMode

NewMaintenanceMode creates a maintenance mode checker. Paths matching any of excludePrefixes are never blocked (useful for health checks, static assets).

func (*MaintenanceMode) Active

func (m *MaintenanceMode) Active() bool

Active reports whether maintenance mode is currently on.

func (*MaintenanceMode) Message

func (m *MaintenanceMode) Message() string

Message returns the current maintenance message.

func (*MaintenanceMode) Middleware

func (m *MaintenanceMode) Middleware(next http.Handler) http.Handler

Middleware returns an HTTP middleware that blocks requests with a 503 page when maintenance mode is active. Excluded prefixes pass through.

func (*MaintenanceMode) SetDB

func (m *MaintenanceMode) SetDB(db *sql.DB)

SetDB replaces the database connection and reloads the flag. Used in FO mode when the dbsync subscriber swaps the database.

func (*MaintenanceMode) SetPage

func (m *MaintenanceMode) SetPage(html []byte)

SetPage sets custom HTML to serve during maintenance. If not set, a minimal default page is used. The HTML is served as-is with Content-Type text/html.

func (*MaintenanceMode) StartReloader

func (m *MaintenanceMode) StartReloader(done <-chan struct{})

StartReloader starts a background goroutine that reloads the maintenance flag every 5 seconds. Stops when done is closed.

type RateLimitConfig

type RateLimitConfig struct {
	MaxRequests   int
	WindowSeconds int
	Enabled       bool
}

RateLimitConfig defines the rate limit for a single endpoint.

type RateLimiter

type RateLimiter struct {
	// contains filtered or unexported fields
}

RateLimiter provides per-IP, per-endpoint rate limiting backed by a SQLite rate_limits table. Rules are reloaded periodically and expired buckets are garbage collected.

Expected schema:

CREATE TABLE IF NOT EXISTS rate_limits (
    endpoint TEXT PRIMARY KEY,
    max_requests INTEGER NOT NULL DEFAULT 60,
    window_seconds INTEGER NOT NULL DEFAULT 60,
    enabled INTEGER NOT NULL DEFAULT 1
);

func NewRateLimiter

func NewRateLimiter(db *sql.DB, excludePrefixes ...string) *RateLimiter

NewRateLimiter creates a rate limiter that reads rules from the rate_limits table in db. Call StartReloader to enable periodic rule refresh and GC.

func (*RateLimiter) Middleware

func (rl *RateLimiter) Middleware(next http.Handler) http.Handler

Middleware is the HTTP middleware that enforces rate limits. API paths (/api/*) get a 429 JSON response; other paths get a redirect with a flash message.

func (*RateLimiter) SetDB

func (rl *RateLimiter) SetDB(db *sql.DB)

SetDB replaces the database connection and reloads rules. Used in FO mode when the dbsync subscriber swaps the database.

func (*RateLimiter) StartReloader

func (rl *RateLimiter) StartReloader(done <-chan struct{})

StartReloader starts background goroutines for rule reloading (every 60s) and bucket GC (every 5min). Stops when done is closed.
