middleware

package
v0.1.0
Warning

This package is not in the latest version of its module.

Published: Mar 28, 2026 License: MIT Imports: 17 Imported by: 0

Overview

Package middleware provides built-in middleware for the wt WebTransport framework.

Middleware Ordering

Middleware executes in the order it's added. Global middleware runs before route-specific middleware. The handler runs last.

server.Use(A)  // runs 1st
server.Use(B)  // runs 2nd
server.Handle("/path", handler, C, D)
// Execution: A → B → C → D → handler → D-after → C-after → B-after → A-after

// A recommended order for the built-in middleware:
server.Use(middleware.Recover(nil))            // 1. Catch panics (outermost)
server.Use(middleware.RequestID())             // 2. Assign request ID
server.Use(middleware.DefaultLogger())         // 3. Log with request ID
server.Use(middleware.RateLimit(100))          // 4. Reject excess before auth
server.Use(middleware.BearerAuth(fn))          // 5. Authenticate
server.Use(middleware.MaxSessions(1000, nil))  // 6. Global capacity
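The execution order above is the usual onion model: each middleware wraps the next, so work done before calling the inner handler runs on the way in and work done after it runs on the way out. The sketch below reproduces that ordering in plain standard-library Go. Handler, Middleware, named, and chain are hypothetical stand-ins chosen for illustration, not the wt API.

```go
package main

import "fmt"

// Handler is a hypothetical stand-in for a session handler.
type Handler func(trace *[]string)

// Middleware wraps a Handler and returns a new Handler.
type Middleware func(next Handler) Handler

// named returns a middleware that records its name before and after
// calling the next handler in the chain.
func named(name string) Middleware {
	return func(next Handler) Handler {
		return func(trace *[]string) {
			*trace = append(*trace, name)
			next(trace)
			*trace = append(*trace, name+"-after")
		}
	}
}

// chain applies global middleware first, then route middleware. Wrapping
// from the last entry inward makes the first middleware added the outermost.
func chain(h Handler, global, route []Middleware) Handler {
	all := append(append([]Middleware{}, global...), route...)
	for i := len(all) - 1; i >= 0; i-- {
		h = all[i](h)
	}
	return h
}

// run builds the A, B (global) and C, D (route) chain and records the order.
func run() []string {
	var trace []string
	handler := func(trace *[]string) { *trace = append(*trace, "handler") }
	global := []Middleware{named("A"), named("B")}
	route := []Middleware{named("C"), named("D")}
	chain(handler, global, route)(&trace)
	return trace
}

func main() {
	fmt.Println(run())
}
```

Because the outermost middleware is the first to run and the last to finish, placing Recover first lets it observe panics raised anywhere further down the chain.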

Available Middleware

Authentication:

  • BearerAuth — validate Bearer token from Authorization header
  • QueryAuth — validate token from query parameter
  • RequireKey — check static API key

Rate Limiting:

Observability:

Resilience:

Networking:

Session:

Origin Control:

  • CORS — origin validation

Constants

const TraceIDKey = "_trace_id"

TraceIDKey is the context key for trace IDs.

Variables

This section is empty.

Functions

func Bandwidth

func Bandwidth() wt.MiddlewareFunc

Bandwidth returns middleware that installs a bandwidth tracker on each session.

func BearerAuth

func BearerAuth(validate func(token string) (any, error)) wt.MiddlewareFunc

BearerAuth returns middleware that validates Bearer tokens from the Authorization header of the handshake request. The validate function receives the token and returns the user identity (stored in context as "user") or an error to reject the connection.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	validateToken := func(token string) (any, error) {
		if token == "valid" {
			return "user-123", nil
		}
		return nil, fmt.Errorf("invalid")
	}

	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.BearerAuth(validateToken))
	fmt.Println("Auth middleware added")
}
Output:
Auth middleware added

func BlockUserAgent

func BlockUserAgent(blocked ...string) wt.MiddlewareFunc

BlockUserAgent returns middleware that rejects connections from clients with specific User-Agent strings. Useful for blocking known bad bots or specific client versions.

func CORS

func CORS(config CORSConfig) wt.MiddlewareFunc

CORS returns middleware that validates the Origin header on the WebTransport handshake.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.CORS(middleware.CORSConfig{
		AllowedOrigins: []string{"https://example.com"},
	}))
	fmt.Println("CORS configured")
}
Output:
CORS configured
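CORSConfig documents "*" as allowing every origin. A minimal sketch of the comparison such a check might perform follows. originAllowed is a hypothetical helper, and exact string matching on the full serialized origin (scheme, host, and port) is an assumption about how the package matches entries.

```go
package main

import "fmt"

// originAllowed reports whether origin appears in allowed. The entry "*"
// allows any origin; otherwise the comparison is an exact string match on
// the serialized origin, e.g. "https://example.com".
func originAllowed(origin string, allowed []string) bool {
	for _, a := range allowed {
		if a == "*" || a == origin {
			return true
		}
	}
	return false
}

func main() {
	allowed := []string{"https://example.com"}
	fmt.Println(originAllowed("https://example.com", allowed))
	fmt.Println(originAllowed("https://evil.example", allowed))
}
```

Exact matching means "http://example.com" and "https://example.com:8443" are treated as different origins from "https://example.com", which is the safer default.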

func CheckDepth

func CheckDepth(c *wt.Context) bool

CheckDepth should be called before spawning a new stream handler goroutine. Returns true if under the limit, false if the limit is reached.

func Compress

func Compress(c Compressor, logger *slog.Logger) wt.MiddlewareFunc

Compress returns middleware that stores a compressor in the context for handlers to use when sending large messages.

func DefaultLogger

func DefaultLogger() wt.MiddlewareFunc

DefaultLogger returns a Logger middleware using slog.Default().

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.DefaultLogger())
	fmt.Println("Logger middleware added")
}
Output:
Logger middleware added

func DepthGuard

func DepthGuard(maxConcurrent int) wt.MiddlewareFunc

DepthGuard limits the number of concurrent stream handlers per session. Prevents a single client from opening too many streams and overwhelming the server's goroutine pool.

Usage:

server.Handle("/app", handler, middleware.DepthGuard(50))
// Each session limited to 50 concurrent stream handlers
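DepthGuard, CheckDepth, and ReleaseDepth together describe a per-session counting semaphore: check before spawning a stream goroutine, release when it finishes. The sketch below models that pairing with a buffered channel. The streamGuard type and its methods are hypothetical names for illustration, not the package's implementation.

```go
package main

import (
	"fmt"
	"sync"
)

// streamGuard is a non-blocking counting semaphore: tryAcquire plays the
// role of CheckDepth and release the role of ReleaseDepth.
type streamGuard struct{ slots chan struct{} }

func newStreamGuard(max int) *streamGuard {
	return &streamGuard{slots: make(chan struct{}, max)}
}

// tryAcquire reserves a slot and reports false when all slots are in use.
func (g *streamGuard) tryAcquire() bool {
	select {
	case g.slots <- struct{}{}:
		return true
	default:
		return false
	}
}

// release frees a slot previously reserved by tryAcquire.
func (g *streamGuard) release() { <-g.slots }

func main() {
	g := newStreamGuard(2)
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		if !g.tryAcquire() {
			fmt.Println("stream", i, "rejected: limit reached")
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer g.release() // pair every successful tryAcquire with a release
			// ... handle the stream here ...
		}()
	}
	wg.Wait()
}
```

Whether the third stream is rejected depends on how quickly the first two goroutines finish, which is exactly the behavior a concurrency limit (as opposed to a total count) should have.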

func ExtractHeader

func ExtractHeader(headerName, contextKey string) wt.MiddlewareFunc

ExtractHeader extracts a specific HTTP header from the WebTransport handshake request and stores it in the session context. Useful for extracting custom headers like X-Request-ID, X-Forwarded-For, etc.

func ExtractHeaders

func ExtractHeaders(keys map[string]string) wt.MiddlewareFunc

ExtractHeaders extracts multiple headers into context. keys maps header name → context key.

func GetLogger

func GetLogger(c *wt.Context) *slog.Logger

GetLogger retrieves the session-scoped logger from the context. Returns slog.Default() if SlogAttrs middleware wasn't applied.

func GetRateLimiter

func GetRateLimiter(c *wt.Context) *tokenBucket

GetRateLimiter retrieves the token bucket rate limiter from the context. Returns nil if TokenBucket middleware wasn't applied.

func GetRequestID

func GetRequestID(c *wt.Context) string

GetRequestID retrieves the request ID from the context.

func GetTraceContext

func GetTraceContext(c *wt.Context) context.Context

GetTraceContext retrieves the trace context for creating sub-spans.

func GetTraceID

func GetTraceID(c *wt.Context) string

GetTraceID retrieves the trace ID from a context.

func IPWhitelist

func IPWhitelist(allowed ...string) wt.MiddlewareFunc

IPWhitelist returns middleware that only allows connections from the given IPs or CIDRs. Useful for admin endpoints or internal services.

Usage:

server.Handle("/admin", handler, middleware.IPWhitelist("10.0.0.0/8", "192.168.1.0/24", "127.0.0.1"))

func IdleTimeout

func IdleTimeout(d time.Duration) wt.MiddlewareFunc

IdleTimeout returns middleware that closes sessions if no datagrams are received within the given duration. Resets on each datagram. Note: This wraps the handler — it doesn't intercept individual datagrams. For production use, implement idle detection within your handler.

func Logger

func Logger(logger *slog.Logger) wt.MiddlewareFunc

Logger returns middleware that logs session lifecycle events.

func MaxSessions

func MaxSessions(max int, logger *slog.Logger) wt.MiddlewareFunc

MaxSessions returns middleware that limits the total number of concurrent sessions.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.MaxSessions(1000, nil))
	fmt.Println("Max sessions set to 1000")
}
Output:
Max sessions set to 1000

func OTelTracing

func OTelTracing(tracer Tracer) wt.MiddlewareFunc

OTelTracing returns middleware that creates trace spans for each session using the provided Tracer interface. This allows OpenTelemetry integration without importing OTel as a dependency of the framework.

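Usage with OpenTelemetry (using the go.opentelemetry.io/otel, otel/attribute, otel/codes, and otel/trace packages):

type myTracer struct{ tracer trace.Tracer }

func (t *myTracer) StartSpan(ctx context.Context, op string) (context.Context, middleware.TraceSpan) {
    ctx, span := t.tracer.Start(ctx, op)
    return ctx, &mySpan{span}
}

// mySpan adapts an OpenTelemetry trace.Span to middleware.TraceSpan.
type mySpan struct{ span trace.Span }
func (s *mySpan) SetAttribute(key string, value any) { s.span.SetAttributes(attribute.String(key, fmt.Sprint(value))) }
func (s *mySpan) End()                               { s.span.End() }
func (s *mySpan) SetStatus(err error) {
    if err != nil {
        s.span.SetStatus(codes.Error, err.Error())
        return
    }
    s.span.SetStatus(codes.Ok, "")
}

server.Use(middleware.OTelTracing(&myTracer{tracer: otel.Tracer("wt")}))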

func PerPathRateLimit

func PerPathRateLimit(maxPerPath int) wt.MiddlewareFunc

PerPathRateLimit creates a rate limiter that tracks limits per unique path. Useful when applied globally but you want each resolved path to have its own independent limit.

Usage:

limiter := middleware.PerPathRateLimit(50)
server.Use(limiter)
// Each unique path (e.g., /chat/room1, /chat/room2) gets up to 50 concurrent sessions

func QueryAuth

func QueryAuth(param string, validate func(token string) (any, error)) wt.MiddlewareFunc

QueryAuth returns middleware that validates tokens from a query parameter. Useful when headers can't be set (e.g., browser WebTransport API).

func RateLimit

func RateLimit(maxPerIP int) wt.MiddlewareFunc

RateLimit returns middleware that limits the number of concurrent sessions per IP.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.RateLimit(100)) // max 100 connections per IP
	fmt.Println("Rate limit set to 100/IP")
}
Output:
Rate limit set to 100/IP
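RateLimit counts concurrent sessions per IP rather than requests per second, so a session gives its slot back when it ends. One way to express that is a mutex-guarded map of per-IP counters, sketched below. ipLimiter and its methods are hypothetical and are not the package's internals.

```go
package main

import (
	"fmt"
	"sync"
)

// ipLimiter caps the number of concurrent sessions per client IP.
type ipLimiter struct {
	mu     sync.Mutex
	max    int
	active map[string]int
}

func newIPLimiter(maxPerIP int) *ipLimiter {
	return &ipLimiter{max: maxPerIP, active: make(map[string]int)}
}

// acquire reserves a session slot for ip and reports false if ip is at the cap.
func (l *ipLimiter) acquire(ip string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.active[ip] >= l.max {
		return false
	}
	l.active[ip]++
	return true
}

// release frees a slot when the session ends, deleting empty entries so
// the map does not grow without bound.
func (l *ipLimiter) release(ip string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.active[ip] <= 1 {
		delete(l.active, ip)
		return
	}
	l.active[ip]--
}

func main() {
	l := newIPLimiter(1)
	fmt.Println(l.acquire("203.0.113.7")) // first session
	fmt.Println(l.acquire("203.0.113.7")) // second concurrent session
	l.release("203.0.113.7")
	fmt.Println(l.acquire("203.0.113.7"))
}
```

In middleware form, acquire would run before the handler and release would be deferred until the handler returns.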

func Recover

func Recover(logger *slog.Logger) wt.MiddlewareFunc

Recover returns middleware that recovers from panics in session handlers.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.Recover(nil))
	fmt.Println("Recover middleware added")
}
Output:
Recover middleware added

func ReleaseDepth

func ReleaseDepth(c *wt.Context)

ReleaseDepth should be called when a stream handler goroutine completes.

func RequestID

func RequestID() wt.MiddlewareFunc

RequestID assigns a unique request ID to each session and stores it in context. Useful for log correlation across distributed systems.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.RequestID())
	fmt.Println("RequestID middleware added")
}
Output:
RequestID middleware added

func RequireKey

func RequireKey(headerName, expectedKey string) wt.MiddlewareFunc

RequireKey returns middleware that checks for a static API key in a header.

func RequireUserAgent

func RequireUserAgent() wt.MiddlewareFunc

RequireUserAgent returns middleware that rejects connections without a User-Agent header.

func RouteRateLimit

func RouteRateLimit(maxConcurrent int) wt.MiddlewareFunc

RouteRateLimit returns middleware that limits concurrent sessions per route path. Different from RateLimit (which limits per IP), this limits by the route pattern.

Usage:

server.Handle("/heavy/{id}", handler, middleware.RouteRateLimit(10))
// Only 10 concurrent sessions allowed on /heavy/*

func SessionData

func SessionData() wt.MiddlewareFunc

SessionData returns middleware that initializes common session metadata from the HTTP handshake request (user agent, origin, query params). This saves handlers from having to extract these individually.

func SessionTimeoutWithWarning

func SessionTimeoutWithWarning(timeout, warningBefore time.Duration, logger *slog.Logger) wt.MiddlewareFunc

SessionTimeoutWithWarning returns middleware that closes sessions after the given duration, but sends a warning datagram before closing. This gives clients a chance to save state or reconnect.

func SlogAttrs

func SlogAttrs() wt.MiddlewareFunc

SlogAttrs returns middleware that adds common session attributes to the slog default logger for all log calls within the handler. This means any slog.Info/Warn/Error call inside the handler will automatically include session_id, remote_addr, and path.

func Timeout

func Timeout(d time.Duration) wt.MiddlewareFunc

Timeout returns middleware that closes sessions after the given duration. Useful for preventing zombie sessions.

Example
package main

import (
	"fmt"
	"time"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	// Sessions automatically close after 30 minutes
	server.Use(middleware.Timeout(30 * time.Minute))
	fmt.Println("Timeout set")
}
Output:
Timeout set

func TokenBucket

func TokenBucket(rate float64, burst int) wt.MiddlewareFunc

TokenBucket returns middleware that rate-limits datagrams/messages using a token bucket. This middleware stores a rate limiter in the context under the key "_ratelimiter".

func TraceIDFromContext

func TraceIDFromContext(ctx context.Context) string

TraceIDFromContext retrieves a trace ID from a Go context.

func Tracing

func Tracing(logger *slog.Logger) wt.MiddlewareFunc

Tracing returns middleware that adds basic tracing to sessions. It generates a trace ID and logs session start/end with timing.

For OpenTelemetry integration, wrap this with your own middleware that creates OTel spans using the trace ID from context.

Example
package main

import (
	"fmt"
	"log/slog"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(middleware.Tracing(slog.Default()))
	fmt.Println("Tracing middleware added")
}
Output:
Tracing middleware added

func Webhook

func Webhook(url string, logger *slog.Logger) wt.MiddlewareFunc

Webhook returns middleware that POSTs session events to an external URL. Useful for analytics, logging pipelines, or triggering external workflows.

Usage:

server.Use(middleware.Webhook("https://hooks.example.com/wt-events", slog.Default()))

func WithTraceContext

func WithTraceContext(ctx context.Context, c *wt.Context) context.Context

WithTraceContext returns a Go context.Context with the trace ID attached. Useful for passing trace context to downstream services.

Types

type BandwidthTracker

type BandwidthTracker struct {
	BytesSent     atomic.Int64
	BytesReceived atomic.Int64
}

BandwidthTracker tracks bytes sent and received per session. Access the tracker via GetBandwidthTracker(c) within handlers.

func GetBandwidthTracker

func GetBandwidthTracker(c *wt.Context) *BandwidthTracker

GetBandwidthTracker retrieves the bandwidth tracker from the context.

func (*BandwidthTracker) RecordReceived

func (bt *BandwidthTracker) RecordReceived(n int64)

RecordReceived records bytes received.

func (*BandwidthTracker) RecordSent

func (bt *BandwidthTracker) RecordSent(n int64)

RecordSent records bytes sent.

func (*BandwidthTracker) Stats

func (bt *BandwidthTracker) Stats() (sent, received int64)

Stats returns current bandwidth stats.

type CORSConfig

type CORSConfig struct {
	// AllowedOrigins is the list of allowed origins.
	// Use "*" to allow all origins (not recommended for production).
	AllowedOrigins []string
}

CORSConfig configures origin validation.

type CircuitBreaker

type CircuitBreaker struct {
	// contains filtered or unexported fields
}

CircuitBreaker implements the circuit breaker pattern for WebTransport sessions. When too many sessions fail (panic/error), the breaker opens and rejects new connections. After a cooldown period, it enters half-open state and allows limited connections to test recovery.

func NewCircuitBreaker

func NewCircuitBreaker(threshold int, cooldown time.Duration) *CircuitBreaker

NewCircuitBreaker creates a circuit breaker. threshold is the number of failures before the breaker opens; cooldown is how long it stays open before testing recovery.

func (*CircuitBreaker) Middleware

func (cb *CircuitBreaker) Middleware() wt.MiddlewareFunc

Middleware returns wt middleware that applies the circuit breaker.

func (*CircuitBreaker) Reset

func (cb *CircuitBreaker) Reset()

Reset forces the circuit back to closed state.

func (*CircuitBreaker) State

func (cb *CircuitBreaker) State() CircuitState

State returns the current circuit state.

type CircuitState

type CircuitState int

CircuitState represents the state of a circuit breaker.

const (
	CircuitClosed   CircuitState = iota // Normal operation
	CircuitOpen                         // Rejecting connections
	CircuitHalfOpen                     // Testing with limited connections
)

type Compressor

type Compressor interface {
	Compress(data []byte) ([]byte, error)
	Decompress(data []byte) ([]byte, error)
	Name() string
}

Compressor defines the interface for message compression.

func GetCompressor

func GetCompressor(c *wt.Context) Compressor

GetCompressor retrieves the compressor from the context. Returns nil if no compression middleware was applied.

type ConcurrencyStats

type ConcurrencyStats struct {
	ActiveSessions atomic.Int64
	PeakSessions   atomic.Int64
	TotalAccepted  atomic.Int64
	TotalRejected  atomic.Int64
}

ConcurrencyStats tracks global server concurrency metrics in real-time. Lighter than PrometheusMetrics — just atomic counters, no HTTP handler.

func NewConcurrencyStats

func NewConcurrencyStats() *ConcurrencyStats

NewConcurrencyStats creates a new concurrency tracker.

func (*ConcurrencyStats) Middleware

func (cs *ConcurrencyStats) Middleware() wt.MiddlewareFunc

Middleware returns wt middleware that tracks concurrency.

func (*ConcurrencyStats) Snapshot

func (cs *ConcurrencyStats) Snapshot() struct {
	Active   int64
	Peak     int64
	Accepted int64
	Rejected int64
}

Snapshot returns current stats.
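ConcurrencyStats keeps PeakSessions alongside ActiveSessions as plain atomic counters. Keeping a peak correct under concurrency needs a compare-and-swap loop rather than a plain Store, because another goroutine may raise the peak between the load and the write. A sketch of that technique with sync/atomic follows; the stats type and its enter and exit methods are hypothetical and do not claim to mirror the package's code.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type stats struct {
	active atomic.Int64
	peak   atomic.Int64
}

// enter records a new session and raises peak if active now exceeds it.
// The CAS loop retries only when another goroutine changed peak meanwhile.
func (s *stats) enter() {
	n := s.active.Add(1)
	for {
		p := s.peak.Load()
		if n <= p || s.peak.CompareAndSwap(p, n) {
			return
		}
	}
}

// exit records a finished session. The peak is never lowered.
func (s *stats) exit() { s.active.Add(-1) }

func main() {
	var s stats
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.enter()
			s.exit()
		}()
	}
	wg.Wait()
	fmt.Println("active:", s.active.Load(), "peak:", s.peak.Load())
}
```

After all goroutines finish, active is back to zero while peak holds the highest concurrency actually observed, which varies from run to run.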

type DeflateCompressor

type DeflateCompressor struct {
	// contains filtered or unexported fields
}

DeflateCompressor implements deflate compression.

func NewDeflateCompressor

func NewDeflateCompressor() *DeflateCompressor

NewDeflateCompressor creates a deflate compressor.

func (*DeflateCompressor) Compress

func (d *DeflateCompressor) Compress(data []byte) ([]byte, error)

func (*DeflateCompressor) Decompress

func (d *DeflateCompressor) Decompress(data []byte) ([]byte, error)

func (*DeflateCompressor) Name

func (d *DeflateCompressor) Name() string

type GzipCompressor

type GzipCompressor struct {
	// contains filtered or unexported fields
}

GzipCompressor implements gzip compression.

func NewGzipCompressor

func NewGzipCompressor() *GzipCompressor

NewGzipCompressor creates a gzip compressor with writer pooling.

func (*GzipCompressor) Compress

func (g *GzipCompressor) Compress(data []byte) ([]byte, error)

func (*GzipCompressor) Decompress

func (g *GzipCompressor) Decompress(data []byte) ([]byte, error)

func (*GzipCompressor) Name

func (g *GzipCompressor) Name() string

type IPBlacklist

type IPBlacklist struct {
	// contains filtered or unexported fields
}

IPBlacklist blocks connections from specific IPs or CIDRs. Its Middleware method returns the wt middleware, and the list can be updated at runtime via the Add and Remove methods.

func NewIPBlacklist

func NewIPBlacklist(initial ...string) *IPBlacklist

NewIPBlacklist creates a new runtime-updatable IP blacklist.

func (*IPBlacklist) Add

func (bl *IPBlacklist) Add(ipOrCIDR string)

Add adds an IP or CIDR to the blacklist.

func (*IPBlacklist) IsBlocked

func (bl *IPBlacklist) IsBlocked(ip net.IP) bool

IsBlocked checks if an IP is blacklisted.

func (*IPBlacklist) Middleware

func (bl *IPBlacklist) Middleware() wt.MiddlewareFunc

Middleware returns wt middleware that blocks blacklisted IPs.

func (*IPBlacklist) Remove

func (bl *IPBlacklist) Remove(ip string)

Remove removes an IP from the blacklist (does not support removing CIDRs).

type IdleMonitor

type IdleMonitor struct {
	// contains filtered or unexported fields
}

IdleMonitor provides per-session idle detection. The idle timer resets every time Activity() is called. When the timer expires, the session is closed.

func NewIdleMonitor

func NewIdleMonitor(c *wt.Context, timeout time.Duration, onIdle func(*wt.Context)) *IdleMonitor

NewIdleMonitor creates a monitor that closes the session after the given idle duration. Call Activity() from your handler whenever the session is active (received data, etc.).

func (*IdleMonitor) Activity

func (im *IdleMonitor) Activity()

Activity resets the idle timer.

func (*IdleMonitor) Stop

func (im *IdleMonitor) Stop()

Stop cancels the idle monitor.

type LogTracer

type LogTracer struct {
	Logger *slog.Logger
}

LogTracer is a tracer that logs spans via slog. Useful for development.

func (*LogTracer) StartSpan

func (lt *LogTracer) StartSpan(ctx context.Context, operation string) (context.Context, TraceSpan)

type Metrics

type Metrics struct {
	ActiveSessions   atomic.Int64
	TotalSessions    atomic.Int64
	TotalDatagrams   atomic.Int64
	SessionDurations sync.Map // sessionID -> start time
}

Metrics tracks server-level metrics for monitoring.

func NewMetrics

func NewMetrics() *Metrics

NewMetrics creates a new Metrics instance.

func (*Metrics) Middleware

func (m *Metrics) Middleware() wt.MiddlewareFunc

Middleware returns a middleware that tracks session metrics.

func (*Metrics) SessionDuration

func (m *Metrics) SessionDuration(sessionID string) time.Duration

SessionDuration returns how long a session has been active.

func (*Metrics) Snapshot

func (m *Metrics) Snapshot() MetricsSnapshot

Snapshot returns current metrics values.

type MetricsSnapshot

type MetricsSnapshot struct {
	ActiveSessions int64
	TotalSessions  int64
	TotalDatagrams int64
}

MetricsSnapshot is a point-in-time copy of the metrics, as returned by Metrics.Snapshot.

type NoopTracer

type NoopTracer struct{}

NoopTracer is a tracer that does nothing. Useful for testing.

func (NoopTracer) StartSpan

func (NoopTracer) StartSpan(ctx context.Context, _ string) (context.Context, TraceSpan)

type PrometheusMetrics

type PrometheusMetrics struct {
	// contains filtered or unexported fields
}

PrometheusMetrics exports metrics in Prometheus text format. Serve the Handler() on an HTTP endpoint for Prometheus scraping.

func NewPrometheusMetrics

func NewPrometheusMetrics() *PrometheusMetrics

NewPrometheusMetrics creates a new Prometheus-compatible metrics collector.

Example
package main

import (
	"fmt"

	"github.com/rarebek/wt"
	"github.com/rarebek/wt/middleware"
)

func main() {
	pm := middleware.NewPrometheusMetrics()

	server := wt.New(wt.WithAddr(":4433"), wt.WithSelfSignedTLS())
	server.Use(pm.Middleware())

	// Serve metrics on a separate port
	// go http.ListenAndServe(":9090", pm.Handler())

	fmt.Println("Prometheus metrics enabled")
}
Output:
Prometheus metrics enabled

func (*PrometheusMetrics) Handler

func (pm *PrometheusMetrics) Handler() http.Handler

Handler returns an HTTP handler that serves Prometheus metrics.

Usage:

pm := middleware.NewPrometheusMetrics()
server.Use(pm.Middleware())
http.Handle("/metrics", pm.Handler())

func (*PrometheusMetrics) Middleware

func (pm *PrometheusMetrics) Middleware() wt.MiddlewareFunc

Middleware returns wt middleware that tracks session metrics.

type TraceSpan

type TraceSpan interface {
	// SetAttribute sets a key-value attribute on the span.
	SetAttribute(key string, value any)
	// SetStatus marks the span as error or ok.
	SetStatus(err error)
	// End completes the span.
	End()
}

TraceSpan represents an active trace span.

func GetTraceSpan

func GetTraceSpan(c *wt.Context) TraceSpan

GetTraceSpan retrieves the active trace span from the context. Returns nil if OTelTracing middleware wasn't applied.

type Tracer

type Tracer interface {
	// StartSpan starts a new trace span and returns the span context.
	// The span should be named with the given operation name.
	StartSpan(ctx context.Context, operation string) (context.Context, TraceSpan)
}

Tracer is an interface for distributed tracing integration. Implement this with your tracing library (OpenTelemetry, Jaeger, etc.) without importing those libraries as dependencies of the wt framework.

type WebhookEvent

type WebhookEvent struct {
	Event      string `json:"event"` // "connect", "disconnect"
	SessionID  string `json:"session_id"`
	RemoteAddr string `json:"remote_addr"`
	Path       string `json:"path"`
	Timestamp  int64  `json:"timestamp"`
}

WebhookEvent is sent to an external URL on session lifecycle events.
