diff --git a/go.mod b/go.mod index 4a27d78..95ced80 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/justinas/alice v1.2.0 github.com/kr/text v0.2.0 // indirect + github.com/rs/xid v1.3.0 github.com/rs/zerolog v1.26.1 github.com/stretchr/testify v1.7.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/logginghandler.go b/logginghandler.go index d4dd0a6..114d3ad 100644 --- a/logginghandler.go +++ b/logginghandler.go @@ -8,11 +8,17 @@ import ( "time" "github.com/justinas/alice" + "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" ) +const ( + UUIDKey = "uuid" + UUIDHeader = "X-Request-ID" +) + func init() { //nolint:gochecknoinits zerolog.DefaultContextLogger = &log.Logger } @@ -39,6 +45,43 @@ func FromCtx(ctx context.Context) *zerolog.Logger { return zerolog.Ctx(ctx) } +// RequestIDHandler looks in the header for an existing request id. Else it will create one. +func RequestIDHandler() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := r.Header.Get(UUIDHeader) + + if id != "" { + ctx := r.Context() + + log := zerolog.Ctx(ctx) + + uuid, err := xid.FromString(id) + if err != nil { + log.Error().Err(err).Msg("couldnt parse uuid") + + hlog.RequestIDHandler(UUIDKey, UUIDHeader)(next).ServeHTTP(w, r) + + return + } + + ctx = hlog.CtxWithID(ctx, uuid) + r = r.WithContext(ctx) + + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(UUIDKey, uuid.String()) + }) + + w.Header().Set(UUIDHeader, uuid.String()) + + next.ServeHTTP(w, r) + } else { + hlog.RequestIDHandler(UUIDKey, UUIDHeader)(next).ServeHTTP(w, r) + } + }) + } +} + func Handler(log zerolog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { chain := alice.New( @@ -56,7 +99,7 @@ func Handler(log zerolog.Logger) func(http.Handler) http.Handler { hlog.RemoteAddrHandler("remote"), hlog.UserAgentHandler("user-agent"), hlog.RefererHandler("referer"), - hlog.RequestIDHandler("uuid", "X-Request-ID"), + RequestIDHandler(), ).Then(next) return chain diff --git a/logginghandler_test.go b/logginghandler_test.go index b35cc50..acdd301 100644 --- a/logginghandler_test.go +++ b/logginghandler_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "go.xsfx.dev/logginghandler" @@ -18,13 +19,17 @@ import ( func Example() { logger := log.With().Logger() - handler := logginghandler.Handler(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger := logginghandler.FromRequest(r) + handler := logginghandler.Handler( + logger, + )( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger := logginghandler.FromRequest(r) - logger.Info().Msg("this is a request") + logger.Info().Msg("this is a request") - w.WriteHeader(http.StatusOK) - })) + w.WriteHeader(http.StatusOK) + }), + ) http.Handle("/", handler) log.Fatal().Msg(http.ListenAndServe(":5000", nil).Error()) @@ -45,7 +50,7 @@ func TestUUID(t *testing.T) { handler.ServeHTTP(rr, req) - assert.NotEmpty(rr.Header().Get("X-Request-ID")) + assert.NotEmpty(rr.Header().Get(logginghandler.UUIDHeader)) } func TestFromCtx(t *testing.T) { @@ -58,10 +63,14 @@ func TestFromCtx(t *testing.T) { var output bytes.Buffer rr := httptest.NewRecorder() - handler := logginghandler.Handler(zerolog.New(&output))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := logginghandler.FromCtx(r.Context()) - log.Info().Msg("hello world") - })) + handler := logginghandler.Handler( + zerolog.New(&output), + )( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log := logginghandler.FromCtx(r.Context()) + log.Info().Msg("hello world") + }), + ) handler.ServeHTTP(rr, req) @@ -75,3 +84,69 @@ func TestFromCtx(t *testing.T) { assert.NotEmpty(jOut) } + +func TestRequestIDHandler(t *testing.T) { + t.Parallel() + assert := require.New(t) + + handler := logginghandler.RequestIDHandler()( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log := hlog.FromRequest(r) + log.Info().Msg("hello from TestRequestID") + }), + ) + + id := "cfrj1ro330reqgvfpgu0" + + // Create buffer to store output. + var output bytes.Buffer + + req, err := http.NewRequestWithContext(context.Background(), "GET", "/", nil) + assert.NoError(err) + + rr := httptest.NewRecorder() + + h := hlog.NewHandler(zerolog.New(&output))(handler) + + h.ServeHTTP(rr, req) + + assert.NotEmpty(rr.Header().Get(logginghandler.UUIDHeader)) + assert.NotEqual(rr.Header().Get(logginghandler.UUIDHeader), id) + + // Now test with request id in header. + + nr := httptest.NewRecorder() + + nReq, err := http.NewRequestWithContext(context.Background(), "GET", "/", nil) + assert.NoError(err) + + nReq.Header.Add(logginghandler.UUIDHeader, id) + + h.ServeHTTP(nr, nReq) + + assert.NotEmpty(nr.Header().Get(logginghandler.UUIDHeader)) + assert.Equal(nr.Header().Get(logginghandler.UUIDHeader), id) + + logs := strings.Split(output.String(), "\n") + assert.Len(logs, 3) + + getUUID := func(l string) (string, error) { + var out struct{ UUID string } + + err := json.Unmarshal([]byte(l), &out) + if err != nil { + return "", err + } + + return out.UUID, nil + } + + uuid1, err := getUUID(logs[0]) + assert.NoError(err) + + assert.NotEqual(id, uuid1) + + uuid2, err := getUUID(logs[1]) + assert.NoError(err) + assert.Equal(id, uuid2) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 4884cad..5d42433 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -9,6 +9,7 @@ github.com/justinas/alice # github.com/pmezard/go-difflib v1.0.0 github.com/pmezard/go-difflib/difflib # github.com/rs/xid v1.3.0 +## explicit github.com/rs/xid # github.com/rs/zerolog v1.26.1 ## explicit