Exercise 1
Question
Create a middleware-generating function that puts a timeout into the context. The function should have one parameter, which is the number of milliseconds that a request is allowed to run. It should return a func(http.Handler) http.Handler.
Solution
This is a little tricky, because we are writing a function that returns a function that returns a function (phew!). The Timeout function looks like this:
func Timeout(ms int) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			ctx, cancelFunc := context.WithTimeout(ctx, time.Duration(ms)*time.Millisecond)
			defer cancelFunc()
			r = r.WithContext(ctx)
			h.ServeHTTP(w, r)
		})
	}
}
The signature matches the description from the problem. The body declares two nested functions. The outer function is func(h http.Handler) http.Handler. It returns a func(w http.ResponseWriter, r *http.Request) that is converted to the http.HandlerFunc type, which implements the http.Handler interface.
The body of the inner function does all the context work. We extract the context from the request, wrap it in a new context created by context.WithTimeout, and defer a call to cancelFunc so the new context's resources are released when the handler returns. We then use the WithContext method on the old request to build a replacement *http.Request that carries the new context, and finally call the ServeHTTP method on the http.Handler that was passed in to the middleware.
Notice that nothing in this code enforces the timeout; the middleware only attaches a deadline to the context. Respecting that deadline is the responsibility of the request handler and the business logic.
The request handler should have code that looks something like this:
func sleepy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	message, err := doThing(ctx)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			w.WriteHeader(http.StatusGatewayTimeout)
		} else {
			w.WriteHeader(http.StatusInternalServerError)
		}
	} else {
		w.WriteHeader(http.StatusOK)
	}
	w.Write([]byte(message))
}
The business logic, in turn, should include code that checks the context to make sure that its work hasn't taken too much time:
func doThing(ctx context.Context) (string, error) {
	wait := rand.Intn(200)
	select {
	case <-time.After(time.Duration(wait) * time.Millisecond):
		return "Done!", nil
	case <-ctx.Done():
		return "Too slow!", ctx.Err()
	}
}