Skip to content
Snippets Groups Projects
Commit 59fe66e8 authored by Rohith's avatar Rohith
Browse files

- fixing up to use github.com/satori/go.uuid instead of and internal one, lose 20ns but hey :-)

parent 93bec30b
No related branches found
No related tags found
No related merge requests found
......@@ -41,7 +41,7 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler {
// @step: add the proxy forwarding headers
req.Header.Add("X-Forwarded-For", realIP(req))
req.Header.Set("X-Forwarded-Host", req.URL.Host)
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto"))
// @step: add any custom headers to the request
......
......@@ -26,6 +26,7 @@ import (
"github.com/PuerkitoBio/purell"
"github.com/gambol99/go-oidc/jose"
"github.com/go-chi/chi/middleware"
uuid "github.com/satori/go.uuid"
"github.com/unrolled/secure"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
......@@ -71,7 +72,12 @@ func (r *oauthProxy) requestIDMiddleware(header string) func(http.Handler) http.
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if v := req.Header.Get(header); v == "" {
req.Header.Set(header, randomUUID())
uid, err := uuid.NewV1()
if err != nil {
r.log.Error("failed to generatet correlation id for request", zap.Error(err))
} else {
req.Header.Set(header, uid.String())
}
}
next.ServeHTTP(w, req)
......
......@@ -164,6 +164,7 @@ func (r *oauthProxy) createReverseProxy() error {
engine.Use(middleware.Recoverer)
// @check if the request tracking id middleware is enabled
if r.config.EnableRequestID {
r.log.Info("enabled the correlation request id middlware")
engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader))
}
// @step: enable the entrypoint middleware
......
......@@ -28,7 +28,6 @@ import (
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
"net/http"
"net/url"
......@@ -78,55 +77,6 @@ func getRequestHostURL(r *http.Request) string {
return fmt.Sprintf("%s://%s", scheme, hostname)
}
const (
letterBytes = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz0123456789"
letterIdxBits = 6
letterIdxMask = 1<<letterIdxBits - 1
letterIdxMax = 63 / letterIdxBits
)
var randomSource = mrand.NewSource(time.Now().UnixNano())
// randomBytes returns a random array of bytes
// @note: code taken from https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-golang
func randomBytes(n int) []byte {
b := make([]byte, n)
for i, cache, remain := n-1, randomSource.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = randomSource.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return b
}
// randomString returns a random string of x length
func randomString(length int) string {
return string(randomBytes(length))
}
// randomUUID returns a uuid from the random string
func randomUUID() string {
uuid := make([]byte, 36)
r := randomBytes(32)
i := 0
for x := range []int{8, 4, 4, 4, 12} {
copy(uuid, r[i:i+x])
if x != 12 {
copy(uuid, []byte("-"))
i = i + x
}
}
return string(uuid)
}
// readConfigFile reads and parses the configuration file
func readConfigFile(filename string, config *Config) error {
content, err := ioutil.ReadFile(filename)
......
......@@ -26,7 +26,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
......@@ -65,45 +65,10 @@ func TestDecodeKeyPairs(t *testing.T) {
}
}
func TestRandom(t *testing.T) {
s := randomBytes(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}
func TestRandomString(t *testing.T) {
s := randomString(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}
func TestRandomUUID(t *testing.T) {
s := randomUUID()
assert.NotEmpty(t, s)
assert.Equal(t, 36, len(s))
}
func BenchmarkRandomBytes36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}
func BenchmarkRandomString36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}
func BenchmarkUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
uuid.New()
}
}
func BenchmarkRandomUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
randomUUID()
s, _ := uuid.NewV1()
s.String()
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment