diff --git a/handlers.go b/handlers.go index ea7ca661b21832d4c1b5fef81e720ed66bad442c..b9cdb65cb6621c8e36d072e0eb68b4da49b6c720 100644 --- a/handlers.go +++ b/handlers.go @@ -31,7 +31,7 @@ import ( // // The logic is broken into four handlers just to simplify the code // -// a) entrypointHandler checks if the the uri requires authentication +// a) entryPointHandler checks if the the uri requires authentication // b) authenticationHandler verifies the access token // c) admissionHandler verifies that the token is authorized to access to uri resource // c) proxyHandler is responsible for handling the reverse proxy to the upstream endpoint @@ -55,9 +55,9 @@ func (r *KeycloakProxy) loggingHandler() gin.HandlerFunc { "client_ip": cx.ClientIP(), "method": cx.Request.Method, "status": cx.Writer.Status(), - "path": cx.Request.RequestURI, + "path": cx.Request.URL.Path, "latency": latency.String(), - }).Infof("[%d] |%s| |%13v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path) + }).Infof("[%d] |%s| |%10v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path) } } @@ -88,37 +88,34 @@ func (r *KeycloakProxy) securityHandler() gin.HandlerFunc { // // entrypointHandler checks to see if the request requires authentication // -func (r *KeycloakProxy) entrypointHandler() gin.HandlerFunc { +func (r *KeycloakProxy) entryPointHandler() gin.HandlerFunc { return func(cx *gin.Context) { - // @@TODO need to fix this login - // step: ensure we don't block oauth - if strings.HasPrefix(cx.Request.RequestURI, oauthURL) { - if cx.Request.RequestURI != callbackURL && cx.Request.RequestURI != authorizationURL { - log.WithFields(log.Fields{"uri": cx.Request.RequestURI}).Warningf("client attempting to do something strange with oauth handlers") - - r.redirectToAuthorization(cx) - return - } + if strings.HasPrefix(cx.Request.URL.Path, oauthURL) { cx.Next() - return } - if !strings.HasPrefix(cx.Request.RequestURI, oauthURL) { - // step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes - for _, resource := range r.config.Resources { - if strings.HasPrefix(cx.Request.RequestURI, resource.URL) { - // step: has the resource been white listed? - if resource.WhiteListed { - break - } - // step: inject the resource into the context, saves us from doing this again - if containedIn(cx.Request.Method, resource.Methods) || containedIn("ANY", resource.Methods) { - cx.Set(cxEnforce, resource) - } + + // step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes + for _, resource := range r.config.Resources { + if strings.HasPrefix(cx.Request.URL.Path, resource.URL) { + // step: has the resource been white listed? + if resource.WhiteListed { break } + // step: inject the resource into the context, saves us from doing this again + if containedIn(cx.Request.Method, resource.Methods) || containedIn("ANY", resource.Methods) { + cx.Set(cxEnforce, resource) + } + break } } + // step: pass into the authentication and admission handlers + cx.Next() + + // step: check the request has not been aborted and if not, proxy request + if !cx.IsAborted() { + r.proxyHandler(cx) + } } } @@ -240,6 +237,7 @@ func (r *KeycloakProxy) authenticationHandler() gin.HandlerFunc { } cx.Next() + } } @@ -258,7 +256,6 @@ func (r *KeycloakProxy) admissionHandler() gin.HandlerFunc { // step: if authentication is required on this, grab the resource spec ur, found := cx.Get(cxEnforce) if !found { - cx.Next() return } @@ -342,57 +339,53 @@ func (r *KeycloakProxy) admissionHandler() gin.HandlerFunc { "expires": identity.expiresAt.Sub(time.Now()).String(), "bearer": identity.bearerToken, }).Debugf("resource access permitted: %s", cx.Request.RequestURI) - - cx.Next() } } // // proxyHandler is responsible to proxy the requests on to the upstream endpoint // -func (r *KeycloakProxy) proxyHandler() gin.HandlerFunc { - return func(cx *gin.Context) { - // step: double check, if enforce is true and no user context it's a internal error - if _, found := cx.Get(cxEnforce); found { - if _, found := cx.Get(userContextName); !found { - log.Errorf("no user context found for a secure request") - cx.AbortWithStatus(http.StatusInternalServerError) - return - } +func (r *KeycloakProxy) proxyHandler(cx *gin.Context) { + // step: double check, if enforce is true and no user context it's a internal error + if _, found := cx.Get(cxEnforce); found { + if _, found := cx.Get(userContextName); !found { + log.Errorf("no user context found for a secure request") + cx.AbortWithStatus(http.StatusInternalServerError) + return } + } - // step: retrieve the user context - if identity, found := cx.Get(userContextName); found { - id := identity.(*userContext) - cx.Request.Header.Add("X-Auth-UserId", id.id) - cx.Request.Header.Add("X-Auth-Subject", id.preferredName) - cx.Request.Header.Add("X-Auth-Username", id.name) - cx.Request.Header.Add("X-Auth-Email", id.email) - cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) - cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) - cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) - } + // step: retrieve the user context + if identity, found := cx.Get(userContextName); found { + id := identity.(*userContext) + cx.Request.Header.Add("X-Auth-UserId", id.id) + cx.Request.Header.Add("X-Auth-Subject", id.preferredName) + cx.Request.Header.Add("X-Auth-Username", id.name) + cx.Request.Header.Add("X-Auth-Email", id.email) + cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) + cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) + cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) + } - // step: add the default headers - cx.Request.Header.Set("X-Forwarded-For", cx.Request.RemoteAddr) - cx.Request.Header.Set("X-Forwarded-Agent", "keycloak-proxy") + // step: add the default headers + cx.Request.Header.Set("X-Forwarded-For", cx.Request.RemoteAddr) + cx.Request.Header.Set("X-Forwarded-Agent", "keycloak-proxy") - // step: is this connection upgrading? - if isUpgradedConnection(cx.Request) { - log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) - if err := r.tryUpdateConnection(cx); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") - - cx.AbortWithStatus(http.StatusInternalServerError) - return - } - cx.Abort() + // step: is this connection upgrading? + if isUpgradedConnection(cx.Request) { + log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) + if err := r.tryUpdateConnection(cx); err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") + cx.AbortWithStatus(http.StatusInternalServerError) return } + cx.Abort() - r.proxy.ServeHTTP(cx.Writer, cx.Request) + return } + + r.proxy.ServeHTTP(cx.Writer, cx.Request) } // --- @@ -409,6 +402,10 @@ func (r *KeycloakProxy) oauthAuthorizationHandler(cx *gin.Context) { return } + log.WithFields(log.Fields{ + "client_ip": cx.ClientIP(), + }).Infof("incoming authorization request") + // step: grab the oauth client oac, err := r.openIDClient.OAuthClient() if err != nil { diff --git a/handlers_test.go b/handlers_test.go index e024783fa38bf1c0e34ae5e48deec33acca94847..5fa9aef2c490f4469ffadc984e8cec3fec04a236 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -40,7 +40,7 @@ func TestEntrypointHandlerSecure(t *testing.T) { }, }) - handler := proxy.entrypointHandler() + handler := proxy.entryPointHandler() tests := []struct { Context *gin.Context @@ -82,7 +82,7 @@ func TestEntrypointMethods(t *testing.T) { }, }) - handler := proxy.entrypointHandler() + handler := proxy.entryPointHandler() tests := []struct { Context *gin.Context @@ -122,7 +122,7 @@ func TestEntrypointWhiteListing(t *testing.T) { Methods: []string{"ANY"}, }, }) - handler := proxy.entrypointHandler() + handler := proxy.entryPointHandler() tests := []struct { Context *gin.Context @@ -148,7 +148,7 @@ func TestEntrypointWhiteListing(t *testing.T) { func TestEntrypointHandler(t *testing.T) { proxy := newFakeKeycloakProxy(t) - handler := proxy.entrypointHandler() + handler := proxy.entryPointHandler() tests := []struct { Context *gin.Context @@ -247,7 +247,7 @@ func TestAdmissionHandlerRoles(t *testing.T) { for i, c := range tests { // step: find the resource and inject into the context for _, r := range proxy.config.Resources { - if strings.HasPrefix(c.Context.Request.RequestURI, r.URL) { + if strings.HasPrefix(c.Context.Request.URL.Path, r.URL) { c.Context.Set(cxEnforce, r) break } diff --git a/resource.go b/resource.go index 55f440484d4160e17666dd62b8148319696f5860..d8ad0a2e56f769f5c62f45e9c8dde33b78f1e768 100644 --- a/resource.go +++ b/resource.go @@ -30,6 +30,10 @@ func (r *Resource) isValid() error { r.RolesAllowed = make([]string, 0) } + if strings.HasPrefix(r.URL, oauthURL) { + return fmt.Errorf("this is used by the oauth handlers") + } + // step: check we have a if r.URL == "" { return fmt.Errorf("resource does not have url") @@ -72,7 +76,7 @@ func (r Resource) String() string { if len(r.Methods) <= 0 { methods = "ANY" } else { - roles = strings.Join(r.Methods, ",") + methods = strings.Join(r.Methods, ",") } return fmt.Sprintf("uri: %s, methods: %s, required: %s", r.URL, methods, roles) diff --git a/server.go b/server.go index 5499cc61189943eec84ad28f17c9e59c457bcc39..499dd8e102b6ae1cb1f6cd70fcf0319d498fbb04 100644 --- a/server.go +++ b/server.go @@ -18,7 +18,6 @@ package main import ( "fmt" "net/http" - "net/http/httputil" "net/url" "os" "sync" @@ -38,11 +37,15 @@ type KeycloakProxy struct { // the oidc client openIDClient *oidc.Client // the proxy client - proxy *httputil.ReverseProxy + proxy reverseProxy // the upstream endpoint upstreamURL *url.URL } +type reverseProxy interface { + ServeHTTP(rw http.ResponseWriter, req *http.Request) +} + // newKeycloakProxy create's a new keycloak proxy from configuration func newKeycloakProxy(cfg *Config) (*KeycloakProxy, error) { // step: set the logging level @@ -112,15 +115,23 @@ func newKeycloakProxy(cfg *Config) (*KeycloakProxy, error) { router.Use(service.securityHandler()) } - router.Use(service.entrypointHandler(), service.authenticationHandler(), service.admissionHandler()) + // step: add the routing router.GET(authorizationURL, service.oauthAuthorizationHandler) router.GET(callbackURL, service.oauthCallbackHandler) - - router.Use(service.proxyHandler()) + router.GET(healthURL, service.healthHandler) + router.Use(service.entryPointHandler(), service.authenticationHandler(), service.admissionHandler()) return service, nil } +func (r *KeycloakProxy) abortAll() gin.HandlerFunc { + return func(cx *gin.Context) { + fmt.Println("HELLO") + cx.Next() + cx.Abort() + } +} + // initializeTemplates loads the custom template func (r *KeycloakProxy) initializeTemplates() { var list []string diff --git a/server_test.go b/server_test.go index cb753eb51fe363e2f778ea9bb333e5d457761c35..7ac22d26998b915d6a0b7ebe708a0422e72fa7d4 100644 --- a/server_test.go +++ b/server_test.go @@ -94,6 +94,7 @@ func newFakeKeycloakProxy(t *testing.T) *KeycloakProxy { }, }, }, + proxy: new(fakeReverseProxy), } return kc @@ -160,14 +161,19 @@ func newFakeGinContext(method, uri string) *gin.Context { URL: &url.URL{ Scheme: "http", Host: "127.0.0.1", - Path: "uri", + Path: uri, }, + Header: make(http.Header, 0), RemoteAddr: "127.0.0.1:8989", }, Writer: newFakeResponse(), } } +type fakeReverseProxy struct{} + +func (r fakeReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {} + type fakeResponse struct { size int status int diff --git a/utils.go b/utils.go index ef435697f9658cf970764193d930b690b95a5c06..b09d1f26fe4526cf536d873dfcf4512aa47c6003 100644 --- a/utils.go +++ b/utils.go @@ -156,7 +156,7 @@ func decodeKeyPairs(list []string) (map[string]string, error) { } // initializeReverseProxy create a reverse http proxy from the upstream -func initializeReverseProxy(upstream *url.URL) (*httputil.ReverseProxy, error) { +func initializeReverseProxy(upstream *url.URL) (reverseProxy, error) { proxy := httputil.NewSingleHostReverseProxy(upstream) // step: we don't care about the cert verification here proxy.Transport = &http.Transport{