Skip to content
Snippets Groups Projects
Commit 93bec30b authored by Rohith's avatar Rohith
Browse files

Request ID

- adding a request id middleware injection via the --enable-request-id
- fixing up the Makefile
parent fb74ab05
Branches
Tags
No related merge requests found
......@@ -3,6 +3,7 @@
FEATURES:
* Added the ability to use a "any" operation on the roles rather then just "and" with the inclusion of a `require-any-role` [#PR389](https://github.com/gambol99/keycloak-proxy/pull/389)
* Added a `--enable-request-id` option to inject a request id into the upstream request [#PR392](https://github.com/gambol99/keycloak-proxy/pull/392)
#### **2.2.2**
......
......@@ -83,6 +83,12 @@
packages = ["proto"]
revision = "1643683e1b54a9e88ad26d98f81400c8c9d9f4f9"
[[projects]]
name = "github.com/google/uuid"
packages = ["."]
revision = "064e2069ce9c359c118179501254f67d7d37ba24"
version = "0.2"
[[projects]]
name = "github.com/jonboulle/clockwork"
packages = ["."]
......@@ -240,6 +246,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "b139b3100a36c39eb42630da4afe77e4f69afef6fbf7b03c7f12e5608d317b04"
inputs-digest = "417f80c4b978ef6044e98c07cae1c0a72869108028ea5bc22305a8c1e7312acb"
solver-name = "gps-cdcl"
solver-version = 1
......@@ -122,7 +122,7 @@ format:
bench:
@echo "--> Running go bench"
@go test -bench=.
@go test -bench=. -benchmem
coverage:
@echo "--> Running go coverage"
......@@ -134,7 +134,7 @@ cover:
@go test --cover
spelling:
@echo "--> Chekcing the spelling"
@echo "--> Checking the spelling"
@which misspell 2>/dev/null ; if [ $$? -eq 1 ]; then \
go get -u github.com/client9/misspell/cmd/misspell; \
fi
......
......@@ -128,6 +128,12 @@ func getCommandLineOptions() []cli.Flag {
Name: optName,
Usage: usage,
})
case reflect.Int:
flags = append(flags, cli.IntFlag{
Name: optName,
Usage: usage,
EnvVar: envName,
})
case reflect.Int64:
switch t.String() {
case "time.Duration":
......@@ -170,6 +176,8 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) {
reflect.ValueOf(config).Elem().FieldByName(field.Name).SetString(cx.String(name))
case reflect.Slice:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.StringSlice(name)))
case reflect.Int:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.Int(name)))
case reflect.Int64:
switch field.Type.String() {
case "time.Duration":
......
......@@ -41,6 +41,7 @@ func newDefaultConfig() *Config {
OAuthURI: "/oauth",
OpenIDProviderTimeout: 30 * time.Second,
PreserveHost: false,
RequestIDHeader: "X-Request-ID",
ResponseHeaders: make(map[string]string),
SecureCookie: true,
ServerIdleTimeout: 120 * time.Second,
......
......@@ -178,9 +178,13 @@ type Config struct {
Headers map[string]string `json:"headers" yaml:"headers" usage:"custom headers to the upstream request, key=value"`
// PreserveHost preserves the host header of the proxied request in the upstream request
PreserveHost bool `json:"preserve-host" yaml:"preserve-host" usage:"preserve the host header of the proxied request in the upstream request"`
// RequestIDHeader is the header name for request ids
RequestIDHeader string `json:"request-id-header" yaml:"request-id-header" usage:"the http header name for request id" env:"REQUEST_ID_HEADER"`
// ResponseHeader is a map of response headers to add to the response
ResponseHeaders map[string]string `json:"response-headers" yaml:"response-headers" usage:"custom headers to added to the http response key=value"`
// EnableRequestID indicates the proxy should add request id if none if found
EnableRequestID bool `json:"enable-request-id" yaml:"enable-request-id" usage:"indicates we should add a request id if none found" env:"ENABLE_REQUEST_ID"`
// EnableLogoutRedirect indicates we should redirect to the identity provider for logging out
EnableLogoutRedirect bool `json:"enable-logout-redirect" yaml:"enable-logout-redirect" usage:"indicates we should redirect to the identity provider for logging out"`
// EnableDefaultDeny indicates we should deny by default all requests
......
......@@ -66,6 +66,19 @@ func entrypointMiddleware(next http.Handler) http.Handler {
})
}
// requestIDMiddleware is responsible for adding a request id if none found
func (r *oauthProxy) requestIDMiddleware(header string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if v := req.Header.Get(header); v == "" {
req.Header.Set(header, randomUUID())
}
next.ServeHTTP(w, req)
})
}
}
// loggingMiddleware is a custom http logger
func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
......
......@@ -60,9 +60,10 @@ type fakeRequest struct {
ExpectedContentContains string
ExpectedCookies map[string]string
ExpectedHeaders map[string]string
ExpectedProxyHeaders map[string]string
ExpectedLocation string
ExpectedNoProxyHeaders []string
ExpectedProxy bool
ExpectedProxyHeaders map[string]string
}
type fakeProxy struct {
......@@ -230,9 +231,20 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) {
if c.ExpectedProxyHeaders != nil && len(c.ExpectedProxyHeaders) > 0 {
for k, v := range c.ExpectedProxyHeaders {
headers := upstream.Headers
switch v {
case "":
assert.NotEmpty(t, headers.Get(k), "case %d, expected the proxy header: %s to exist", i, k)
default:
assert.Equal(t, v, headers.Get(k), "case %d, expected proxy header %s=%s, got: %s", i, k, v, headers.Get(k))
}
}
}
if len(c.ExpectedNoProxyHeaders) > 0 {
for _, k := range c.ExpectedNoProxyHeaders {
assert.Empty(t, upstream.Headers.Get(k), "case %d, header: %s was not expected to exist", i, k)
}
}
if c.ExpectedContent != "" {
e := string(resp.Body())
assert.Equal(t, c.ExpectedContent, e, "case %d, expected content: %s, got: %s", i, c.ExpectedContent, e)
......
......@@ -162,6 +162,11 @@ func (r *oauthProxy) createReverseProxy() error {
engine.MethodNotAllowed(emptyHandler)
engine.NotFound(emptyHandler)
engine.Use(middleware.Recoverer)
// @check if the request tracking id middleware is enabled
if r.config.EnableRequestID {
engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader))
}
// @step: enable the entrypoint middleware
engine.Use(entrypointMiddleware)
if r.config.EnableLogging {
......
......@@ -147,6 +147,43 @@ func TestForbiddenTemplate(t *testing.T) {
newFakeProxy(cfg).RunTests(t, requests)
}
func TestRequestIDHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableRequestID = true
requests := []fakeRequest{
{
URI: "/auth_all/test",
HasLogin: true,
ExpectedProxy: true,
Redirects: true,
ExpectedHeaders: map[string]string{
"X-Request-ID": "",
},
ExpectedCode: http.StatusOK,
},
}
newFakeProxy(c).RunTests(t, requests)
}
func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)
requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedNoProxyHeaders: []string{"X-Auth-Token"},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}
func TestAudienceHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.NoRedirects = false
......@@ -371,27 +408,6 @@ func TestAuthTokenHeaderEnabled(t *testing.T) {
p.RunTests(t, requests)
}
func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)
requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedProxyHeaders: map[string]string{
"X-Auth-Token": "",
},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}
func TestDisableAuthorizationCookie(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableAuthorizationCookies = false
......
......@@ -28,6 +28,7 @@ import (
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
"net/http"
"net/url"
......@@ -77,6 +78,55 @@ func getRequestHostURL(r *http.Request) string {
return fmt.Sprintf("%s://%s", scheme, hostname)
}
const (
letterBytes = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz0123456789"
letterIdxBits = 6
letterIdxMask = 1<<letterIdxBits - 1
letterIdxMax = 63 / letterIdxBits
)
var randomSource = mrand.NewSource(time.Now().UnixNano())
// randomBytes returns a random array of bytes
// @note: code taken from https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-golang
func randomBytes(n int) []byte {
b := make([]byte, n)
for i, cache, remain := n-1, randomSource.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = randomSource.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return b
}
// randomString returns a random string of x length
func randomString(length int) string {
return string(randomBytes(length))
}
// randomUUID returns a uuid from the random string
func randomUUID() string {
uuid := make([]byte, 36)
r := randomBytes(32)
i := 0
for x := range []int{8, 4, 4, 4, 12} {
copy(uuid, r[i:i+x])
if x != 12 {
copy(uuid, []byte("-"))
i = i + x
}
}
return string(uuid)
}
// readConfigFile reads and parses the configuration file
func readConfigFile(filename string, config *Config) error {
content, err := ioutil.ReadFile(filename)
......
......@@ -26,6 +26,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
......@@ -64,6 +65,48 @@ func TestDecodeKeyPairs(t *testing.T) {
}
}
func TestRandom(t *testing.T) {
s := randomBytes(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}
func TestRandomString(t *testing.T) {
s := randomString(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}
func TestRandomUUID(t *testing.T) {
s := randomUUID()
assert.NotEmpty(t, s)
assert.Equal(t, 36, len(s))
}
func BenchmarkRandomBytes36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}
func BenchmarkRandomString36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}
func BenchmarkUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
uuid.New()
}
}
func BenchmarkRandomUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
randomUUID()
}
}
func TestDefaultTo(t *testing.T) {
cs := []struct {
Value string
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment