diff --git a/Dockerfile b/Dockerfile index a2fbdf682a8320054ec4543c806275610be1c898..b2472f690c3a70b32b929c26c31e77a21cda9324 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,9 @@ FROM alpine:3.6 -MAINTAINER Rohith <gambol99@gmail.com> +MAINTAINER Rohith Jayawardene <gambol99@gmail.com> +LABEL Name=keycloak-proxy \ + Release=https://github.com/gambol99/keycloak-proxy \ + Url=https://github.com/gambol99/keycloak-proxy \ + Help=https://github.com/gambol99/keycloak-proxy/issues RUN apk add ca-certificates --update diff --git a/Makefile b/Makefile index 882a752ec37bc0071687f05898e5f5e144e5f45c..3e3a68ad12f21038447f9e392786170f34d3cf53 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ golang: @echo "--> Go Version" @go version -build: golang deps +build: golang @echo "--> Compiling the project" @mkdir -p bin go build -ldflags "${LFLAGS}" -o bin/${NAME} diff --git a/README.md b/README.md index d3d847db14fc984393ee7f2d451860d072c2d782..1900f006e7d03d3c6490367c683fa194951ad557 100644 --- a/README.md +++ b/README.md @@ -225,7 +225,7 @@ By default all requests will be proxyed on to the upstream, if you wish to ensur --resource=uri=/* # note, by default unless specified the methods is assumed to be 'any|ANY' ``` -Note the HTTP routing rules following the guidelines from [echo](https://echo.labstack.com/guide/routing). Its also worth nothing the ordering of the resource do not matter, the router will handle that for you. +Note the HTTP routing rules following the guidelines from [chi](https://github.com/go-chi/chi#router-design). Its also worth nothing the ordering of the resource do not matter, the router will handle that for you. #### **Google OAuth** diff --git a/doc.go b/doc.go index d7c9d9043c7ac9cd242571179e79dc617bcdd513..ff0522c5f2b3e4a693073cc58ca1216689060e4b 100644 --- a/doc.go +++ b/doc.go @@ -38,13 +38,12 @@ const ( email = "gambol99@gmail.com" description = "is a proxy using the keycloak service for auth and authorization" - headerUpgrade = "Upgrade" - userContextName = "identity" - revokeContextName = "revoke" authorizationHeader = "Authorization" - versionHeader = "X-Auth-Proxy-Version" + contextScopeName = "context.scope.name" envPrefix = "PROXY_" + headerUpgrade = "Upgrade" httpSchema = "http" + versionHeader = "X-Auth-Proxy-Version" oauthURL = "/oauth" authorizationURL = "/authorize" @@ -55,6 +54,7 @@ const ( logoutURL = "/logout" metricsURL = "/metrics" tokenURL = "/token" + debugURL = "/debug/pprof" claimPreferredName = "preferred_username" claimAudience = "aud" @@ -63,6 +63,15 @@ const ( claimResourceRoles = "roles" ) +const ( + headerXForwardedFor = "X-Forwarded-For" + headerXForwardedProto = "X-Forwarded-Proto" + headerXForwardedProtocol = "X-Forwarded-Protocol" + headerXForwardedSsl = "X-Forwarded-Ssl" + headerXRealIP = "X-Real-IP" + headerXRequestID = "X-Request-ID" +) + var ( // ErrSessionNotFound no session found in the request ErrSessionNotFound = errors.New("authentication session not found") @@ -254,6 +263,14 @@ func getVersion() string { return version } +// RequestScope is a request level context scope passed between middleware +type RequestScope struct { + // AccessDenied indicates the request should not be proxied on + AccessDenied bool + // Identity is the user Identity of the request + Identity *userContext +} + // storage is used to hold the offline refresh token, assuming you don't want to use // the default practice of a encrypted cookie type storage interface { diff --git a/forwarding.go b/forwarding.go index 594784bcbee14f37f7cfe855711941b9092ccc35..e7669fff8aaf640821c6388a2158d9ff7e93f23f 100644 --- a/forwarding.go +++ b/forwarding.go @@ -22,50 +22,50 @@ import ( "github.com/gambol99/go-oidc/jose" "github.com/gambol99/go-oidc/oidc" - "github.com/labstack/echo" "go.uber.org/zap" ) // proxyMiddleware is responsible for handles reverse proxy request to the upstream endpoint -func (r *oauthProxy) proxyMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - next(cx) - // refuse to proxy - if found := cx.Get(revokeContextName); found != nil { - return nil +func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + next.ServeHTTP(w, req) + + // step: retrieve the request scope + scope := req.Context().Value(contextScopeName) + if scope != nil { + sc := scope.(*RequestScope) + if sc.AccessDenied { + return } + } - // is this connection upgrading? - if isUpgradedConnection(cx.Request()) { - r.log.Debug("upgrading the connnection", zap.String("client_ip", cx.RealIP())) - if err := tryUpdateConnection(cx.Request(), cx.Response().Writer, r.endpoint); err != nil { - r.log.Error("failed to upgrade connection", zap.Error(err)) - cx.NoContent(http.StatusInternalServerError) - return nil - } - return nil - } - // add any custom headers to the request - for k, v := range r.config.Headers { - cx.Request().Header.Set(k, v) + if isUpgradedConnection(req) { + r.log.Debug("upgrading the connnection", zap.String("client_ip", req.RemoteAddr)) + if err := tryUpdateConnection(req, w, r.endpoint); err != nil { + r.log.Error("failed to upgrade connection", zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + return } + return + } - // By default goproxy only provides a forwarding proxy, thus all requests have to be absolute - // and we must update the host headers - cx.Request().URL.Host = r.endpoint.Host - cx.Request().URL.Scheme = r.endpoint.Scheme - cx.Request().Host = r.endpoint.Host + // add any custom headers to the request + for k, v := range r.config.Headers { + req.Header.Set(k, v) + } - cx.Request().Header.Add("X-Forwarded-For", cx.RealIP()) - cx.Request().Header.Set("X-Forwarded-Host", cx.Request().URL.Host) - cx.Request().Header.Set("X-Forwarded-Proto", cx.Request().Header.Get("X-Forwarded-Proto")) + // By default goproxy only provides a forwarding proxy, thus all requests have to be absolute + // and we must update the host headers + req.URL.Host = r.endpoint.Host + req.URL.Scheme = r.endpoint.Scheme + req.Host = r.endpoint.Host - r.upstream.ServeHTTP(cx.Response(), cx.Request()) + req.Header.Add("X-Forwarded-For", realIP(req)) + req.Header.Set("X-Forwarded-Host", req.URL.Host) + req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto")) - return nil - } - } + r.upstream.ServeHTTP(w, req) + }) } // forwardProxyHandler is responsible for signing outbound requests @@ -74,7 +74,6 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { if err != nil { r.log.Fatal("failed to create oauth client", zap.Error(err)) } - // the loop state var state struct { // the access token diff --git a/glide.lock b/glide.lock index fe1cab3b5d46d327429634a481c8cd9d4e055b75..aff63b58bbb80e085a1df328d134878f195154f0 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: b61fcd2b523ddfcacf31c69e7f46b2fd9286d0c6555029efe8a45de3552d25b2 -updated: 2017-06-12T14:55:48.532550213+01:00 +hash: 94f0f660e0bc53763d820d0f3e951a1ef67c5e821abdd8c38b88fbfae5be5fb6 +updated: 2017-06-25T01:22:53.985873813+01:00 imports: - name: github.com/armon/go-proxyproto version: 609d6338d3a76ec26ac3fe7045a164d9a58436e7 @@ -15,12 +15,6 @@ imports: - health - httputil - timeutil -- name: github.com/davecgh/go-spew - version: 5215b55f46b2b919f50a1df0eaa5886afe4e3b3d - subpackages: - - spew -- name: github.com/dgrijalva/jwt-go - version: 2268707a8f0843315e2004ee4f1d021dc08baedf - name: github.com/fsnotify/fsnotify version: fd9ec7deca8bf46ecd2a795baaacf2b3a9be1197 - name: github.com/gambol99/go-oidc @@ -33,37 +27,24 @@ imports: - oidc - name: github.com/gambol99/goproxy version: e713e5909438245be49ef559e74dd904833ebe90 -- name: github.com/go-resty/resty - version: 39c3db9c7bb4f9718ac143a83a924441521caf73 +- name: github.com/go-chi/chi + version: 18d990c0d1c023b05a3652d322ae36d8bdb62e07 + subpackages: + - middleware - name: github.com/golang/protobuf version: 0c1f6d65b5a189c2250d10e71a5506f06f9fa0a0 subpackages: - proto - name: github.com/jonboulle/clockwork version: ed104f61ea4877bea08af6f759805674861e968d -- name: github.com/labstack/echo - version: eac431df0dbad8ba6cc313fba37f1d4275c317e8 - subpackages: - - middleware -- name: github.com/labstack/gommon - version: e8995fb26e646187d33cff439b18609cfba23088 - subpackages: - - bytes - - color - - log - - random -- name: github.com/mattn/go-colorable - version: 9cbef7c35391cca05f15f8181dc0b18bc9736dbb -- name: github.com/mattn/go-isatty - version: 56b76bdf51f7708750eac80fa38b952bb9f32639 - name: github.com/matttproud/golang_protobuf_extensions version: c12348ce28de40eed0136aa2b644d0ee0650e56c subpackages: - pbutil -- name: github.com/pmezard/go-difflib - version: 792786c7400a136282c1664665ae0a8db921c6c2 +- name: github.com/pressly/chi + version: 18d990c0d1c023b05a3652d322ae36d8bdb62e07 subpackages: - - difflib + - middleware - name: github.com/prometheus/client_golang version: 488edd04dc224ba64c401747cd0a4b5f05dfb234 subpackages: @@ -84,23 +65,14 @@ imports: version: 0bcb03f4b4d0a9428594752bd2a3b9aa0a9d4bd4 - name: github.com/PuerkitoBio/urlesc version: 5bd2802263f21d8788851d5305584c82a5c75d7e -- name: github.com/stretchr/testify - version: f6abca593680b2315d2075e0f5e2a9751e3f431a - subpackages: - - assert - - require - - vendor/github.com/davecgh/go-spew/spew - - vendor/github.com/pmezard/go-difflib/difflib +- name: github.com/rs/cors + version: 8dd4211afb5d08dbb39a533b9bb9e4b486351df6 - name: github.com/unrolled/secure version: 4b41e52ab568cbfd31eda3612d98192da1575c77 - name: github.com/urfave/cli version: 0bdeddeeb0f650497d603c4ad7b20cfe685682f6 -- name: github.com/valyala/bytebufferpool - version: e746df99fe4a3986f4d4f79e13c1e0117ce9c2f7 -- name: github.com/valyala/fasttemplate - version: dcecefd839c4193db0d35b88ec65b4c12d360ab0 - name: go.uber.org/atomic - version: 0506d69f5564c56e25797bf7183c28921d4c6360 + version: 4e336646b2ef9fc6e47be8e21594178f98e5ebcf - name: go.uber.org/zap version: 54371c67da1bc746325e5582e48521a5db5d64ca subpackages: @@ -110,16 +82,9 @@ imports: - internal/exit - internal/multierror - zapcore -- name: golang.org/x/crypto - version: 453249f01cfeb54c3d549ddb75ff152ca243f9d8 - subpackages: - - acme - - acme/autocert - name: golang.org/x/net version: bc3663df0ac92f928d419e31e0d2af22e683a5a2 subpackages: - - context - - context/ctxhttp - idna - publicsuffix - name: golang.org/x/sys @@ -143,4 +108,19 @@ imports: - internal/pool - name: gopkg.in/yaml.v2 version: 49c95bdc21843256fb6c4e0d370a05f24a0bf213 -testImports: [] +testImports: +- name: github.com/davecgh/go-spew + version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + subpackages: + - spew +- name: github.com/go-resty/resty + version: 39c3db9c7bb4f9718ac143a83a924441521caf73 +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: f6abca593680b2315d2075e0f5e2a9751e3f431a + subpackages: + - assert + - require diff --git a/glide.yaml b/glide.yaml index ef0561192bae2a0e1787bf14b0e2978b6798c2ba..626a1499730edac5017f27d80fd2b92f11f46579 100644 --- a/glide.yaml +++ b/glide.yaml @@ -1,144 +1,33 @@ package: github.com/gambol99/keycloak-proxy import: - package: github.com/PuerkitoBio/purell - version: 0bcb03f4b4d0a9428594752bd2a3b9aa0a9d4bd4 -- package: github.com/PuerkitoBio/urlesc - version: 5bd2802263f21d8788851d5305584c82a5c75d7e - package: github.com/armon/go-proxyproto - version: 609d6338d3a76ec26ac3fe7045a164d9a58436e7 -- package: github.com/beorn7/perks - version: 3ac7bf7a47d159a033b107610db8a1b6575507a4 - subpackages: - - quantile - package: github.com/boltdb/bolt - version: 144418e1475d8bf7abbdc48583500f1a20c62ea7 -- package: github.com/coreos/pkg - version: 447b7ec906e523386d9c53be15b55a8ae86ea944 - subpackages: - - health - - httputil - - timeutil -- package: github.com/davecgh/go-spew - version: 5215b55f46b2b919f50a1df0eaa5886afe4e3b3d - subpackages: - - spew -- package: github.com/dgrijalva/jwt-go - version: 2268707a8f0843315e2004ee4f1d021dc08baedf - package: github.com/fsnotify/fsnotify - version: fd9ec7deca8bf46ecd2a795baaacf2b3a9be1197 - package: github.com/gambol99/go-oidc - version: 2111f98a1397a35f1800f4c3c354a7abebbef75c subpackages: - - http - jose - - key - oauth2 - oidc - package: github.com/gambol99/goproxy - version: e713e5909438245be49ef559e74dd904833ebe90 -- package: github.com/go-resty/resty - version: 39c3db9c7bb4f9718ac143a83a924441521caf73 -- package: github.com/golang/protobuf - version: 0c1f6d65b5a189c2250d10e71a5506f06f9fa0a0 - subpackages: - - proto -- package: github.com/jonboulle/clockwork - version: ed104f61ea4877bea08af6f759805674861e968d -- package: github.com/labstack/echo - version: eac431df0dbad8ba6cc313fba37f1d4275c317e8 +- package: github.com/go-chi/chi subpackages: - middleware -- package: github.com/labstack/gommon - version: e8995fb26e646187d33cff439b18609cfba23088 - subpackages: - - bytes - - color - - log - - random -- package: github.com/mattn/go-colorable - version: 9cbef7c35391cca05f15f8181dc0b18bc9736dbb -- package: github.com/mattn/go-isatty - version: 56b76bdf51f7708750eac80fa38b952bb9f32639 -- package: github.com/matttproud/golang_protobuf_extensions - version: c12348ce28de40eed0136aa2b644d0ee0650e56c - subpackages: - - pbutil -- package: github.com/pmezard/go-difflib - version: 792786c7400a136282c1664665ae0a8db921c6c2 +- package: github.com/pressly/chi subpackages: - - difflib + - middleware - package: github.com/prometheus/client_golang - version: 488edd04dc224ba64c401747cd0a4b5f05dfb234 subpackages: - prometheus -- package: github.com/prometheus/client_model - version: fa8ad6fec33561be4280a8f0514318c79d7f6cb6 - subpackages: - - go -- package: github.com/prometheus/common - version: 3a184ff7dfd46b9091030bf2e56c71112b0ddb0e - subpackages: - - expfmt - - internal/bitbucket.org/ww/goautoneg - - model -- package: github.com/prometheus/procfs - version: abf152e5f3e97f2fafac028d2cc06c1feb87ffa5 -- package: github.com/stretchr/testify - version: f6abca593680b2315d2075e0f5e2a9751e3f431a - subpackages: - - assert - - require - - vendor/github.com/davecgh/go-spew/spew - - vendor/github.com/pmezard/go-difflib/difflib +- package: github.com/rs/cors - package: github.com/unrolled/secure - version: 4b41e52ab568cbfd31eda3612d98192da1575c77 - package: github.com/urfave/cli - version: 0bdeddeeb0f650497d603c4ad7b20cfe685682f6 -- package: github.com/valyala/bytebufferpool - version: e746df99fe4a3986f4d4f79e13c1e0117ce9c2f7 -- package: github.com/valyala/fasttemplate - version: dcecefd839c4193db0d35b88ec65b4c12d360ab0 -- package: go.uber.org/atomic - version: 0506d69f5564c56e25797bf7183c28921d4c6360 - package: go.uber.org/zap - version: 54371c67da1bc746325e5582e48521a5db5d64ca - subpackages: - - buffer - - internal/bufferpool - - internal/color - - internal/exit - - internal/multierror - - zapcore -- package: golang.org/x/crypto - version: 453249f01cfeb54c3d549ddb75ff152ca243f9d8 - subpackages: - - acme - - acme/autocert -- package: golang.org/x/net - version: bc3663df0ac92f928d419e31e0d2af22e683a5a2 - subpackages: - - context - - context/ctxhttp - - idna - - publicsuffix -- package: golang.org/x/sys - version: 833a04a10549a95dc34458c195cbad61bbb6cb4d - subpackages: - - unix -- package: golang.org/x/text - version: f28f36722d5ef2f9655ad3de1f248e3e52ad5ebd - subpackages: - - transform - - unicode/norm - - width -- package: gopkg.in/bsm/ratelimit.v1 - version: db14e161995a5177acef654cb0dd785e8ee8bc22 - package: gopkg.in/redis.v4 - version: 889409de38315d22b114fb5980f705e6fa48c6a2 - subpackages: - - internal - - internal/consistenthash - - internal/hashtag - - internal/pool - package: gopkg.in/yaml.v2 - version: 49c95bdc21843256fb6c4e0d370a05f24a0bf213 +testImport: +- package: github.com/go-resty/resty +- package: github.com/stretchr/testify + subpackages: + - assert + - require diff --git a/handlers.go b/handlers.go index 5268d932720b6327b2f39f49233d030f76b4d530..f5fae2519ba4e3fa52b927b6edada54477da7705 100644 --- a/handlers.go +++ b/handlers.go @@ -18,6 +18,7 @@ package main import ( "bytes" "encoding/base64" + "encoding/json" "errors" "fmt" "io/ioutil" @@ -30,25 +31,25 @@ import ( "time" "github.com/gambol99/go-oidc/oauth2" - "github.com/labstack/echo" + "github.com/pressly/chi" "go.uber.org/zap" ) // getRedirectionURL returns the redirectionURL for the oauth flow -func (r *oauthProxy) getRedirectionURL(cx echo.Context) string { +func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) string { var redirect string switch r.config.RedirectionURL { case "": // need to determine the scheme, cx.Request.URL.Scheme doesn't have it, best way is to default // and then check for TLS scheme := "http" - if !cx.IsTLS() { + if req.TLS != nil { scheme = "https" } // @QUESTION: should I use the X-Forwarded-<header>?? .. redirect = fmt.Sprintf("%s://%s", - defaultTo(cx.Request().Header.Get("X-Forwarded-Proto"), scheme), - defaultTo(cx.Request().Header.Get("X-Forwarded-Host"), cx.Request().Host)) + defaultTo(req.Header.Get("X-Forwarded-Proto"), scheme), + defaultTo(req.Header.Get("X-Forwarded-Host"), req.Host)) default: redirect = r.config.RedirectionURL } @@ -56,56 +57,17 @@ func (r *oauthProxy) getRedirectionURL(cx echo.Context) string { return fmt.Sprintf("%s/oauth/callback", redirect) } -// oauthHandler is required due to the fact the echo router does not run middleware if no handler -// is found for a group https://github.com/labstack/echo/issues/856 -func (r *oauthProxy) oauthHandler(cx echo.Context) error { - handler := fmt.Sprintf("/%s", strings.TrimLeft(cx.Param("name"), "/")) - r.revokeProxy(cx) - switch cx.Request().Method { - case http.MethodGet: - switch handler { - case authorizationURL: - return r.oauthAuthorizationHandler(cx) - case callbackURL: - return r.oauthCallbackHandler(cx) - case expiredURL: - return r.expirationHandler(cx) - case healthURL: - return r.healthHandler(cx) - case logoutURL: - return r.logoutHandler(cx) - case tokenURL: - return r.tokenHandler(cx) - case metricsURL: - if r.config.EnableMetrics { - return r.proxyMetricsHandler(cx) - } - default: - return cx.NoContent(http.StatusNotFound) - } - case http.MethodPost: - switch handler { - case loginURL: - return r.loginHandler(cx) - default: - return cx.NoContent(http.StatusNotFound) - } - default: - return cx.NoContent(http.StatusMethodNotAllowed) - } - - return nil -} - // oauthAuthorizationHandler is responsible for performing the redirection to oauth provider -func (r *oauthProxy) oauthAuthorizationHandler(cx echo.Context) error { +func (r *oauthProxy) oauthAuthorizationHandler(w http.ResponseWriter, req *http.Request) { if r.config.SkipTokenVerification { - return cx.NoContent(http.StatusNotAcceptable) + w.WriteHeader(http.StatusNotAcceptable) + return } - client, err := r.getOAuthClient(r.getRedirectionURL(cx)) + client, err := r.getOAuthClient(r.getRedirectionURL(w, req)) if err != nil { r.log.Error("failed to retrieve the oauth client for authorization", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } // step: set the access type of the session @@ -114,44 +76,50 @@ func (r *oauthProxy) oauthAuthorizationHandler(cx echo.Context) error { accessType = "offline" } - authURL := client.AuthCodeURL(cx.QueryParam("state"), accessType, "") + authURL := client.AuthCodeURL(req.URL.Query().Get("state"), accessType, "") r.log.Debug("incoming authorization request from client address", zap.String("access_type", accessType), zap.String("auth_url", authURL), - zap.String("client_ip", cx.RealIP())) + zap.String("client_ip", req.RemoteAddr)) // step: if we have a custom sign in page, lets display that if r.config.hasCustomSignInPage() { model := make(map[string]string) model["redirect"] = authURL + w.WriteHeader(http.StatusOK) + r.Render(w, path.Base(r.config.SignInPage), mergeMaps(model, r.config.Tags)) - return cx.Render(http.StatusOK, path.Base(r.config.SignInPage), mergeMaps(model, r.config.Tags)) + return } - return r.redirectToURL(authURL, cx) + r.redirectToURL(authURL, w, req) } // oauthCallbackHandler is responsible for handling the response from oauth service -func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { +func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Request) { if r.config.SkipTokenVerification { - return cx.NoContent(http.StatusNotAcceptable) + w.WriteHeader(http.StatusNotAcceptable) + return } // step: ensure we have a authorization code - code := cx.QueryParam("code") + code := req.URL.Query().Get("code") if code == "" { - return cx.NoContent(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) + return } - client, err := r.getOAuthClient(r.getRedirectionURL(cx)) + client, err := r.getOAuthClient(r.getRedirectionURL(w, req)) if err != nil { r.log.Error("unable to create a oauth2 client", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } resp, err := exchangeAuthenticationCode(client, code) if err != nil { r.log.Error("unable to exchange code for access token", zap.Error(err)) - return r.accessForbidden(cx) + r.accessForbidden(w, req) + return } // Flow: once we exchange the authorization code we parse the ID Token; we then check for a access token, @@ -160,7 +128,8 @@ func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { token, identity, err := parseToken(resp.IDToken) if err != nil { r.log.Error("unable to parse id token for identity", zap.Error(err)) - return r.accessForbidden(cx) + r.accessForbidden(w, req) + return } access, id, err := parseToken(resp.AccessToken) if err == nil { @@ -173,7 +142,8 @@ func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { // step: check the access token is valid if err = verifyToken(r.client, token); err != nil { r.log.Error("unable to verify the id token", zap.Error(err)) - return r.accessForbidden(cx) + r.accessForbidden(w, req) + return } accessToken := token.Encode() @@ -181,7 +151,8 @@ func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { if r.config.EnableEncryptedToken { if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { r.log.Error("unable to encode the access token", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } } @@ -196,10 +167,11 @@ func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { encrypted, err = encodeText(resp.RefreshToken, r.config.EncryptionKey) if err != nil { r.log.Error("failed to encrypt the refresh token", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } // drop in the access token - cookie expiration = access token - r.dropAccessTokenCookie(cx.Request(), cx.Response().Writer, accessToken, r.getAccessCookieExpiration(token, resp.RefreshToken)) + r.dropAccessTokenCookie(req, w, accessToken, r.getAccessCookieExpiration(token, resp.RefreshToken)) switch r.useStore() { case true: @@ -210,39 +182,39 @@ func (r *oauthProxy) oauthCallbackHandler(cx echo.Context) error { // notes: not all idp refresh tokens are readable, google for example, so we attempt to decode into // a jwt and if possible extract the expiration, else we default to 10 days if _, ident, err := parseToken(resp.RefreshToken); err != nil { - r.dropRefreshTokenCookie(cx.Request(), cx.Response().Writer, encrypted, time.Duration(240)*time.Hour) + r.dropRefreshTokenCookie(req, w, encrypted, time.Duration(240)*time.Hour) } else { - r.dropRefreshTokenCookie(cx.Request(), cx.Response().Writer, encrypted, time.Until(ident.ExpiresAt)) + r.dropRefreshTokenCookie(req, w, encrypted, time.Until(ident.ExpiresAt)) } } } else { - r.dropAccessTokenCookie(cx.Request(), cx.Response().Writer, accessToken, time.Until(identity.ExpiresAt)) + r.dropAccessTokenCookie(req, w, accessToken, time.Until(identity.ExpiresAt)) } // step: decode the state variable state := "/" - if cx.QueryParam("state") != "" { - decoded, err := base64.StdEncoding.DecodeString(cx.QueryParam("state")) + if req.URL.Query().Get("state") != "" { + decoded, err := base64.StdEncoding.DecodeString(req.URL.Query().Get("state")) if err != nil { r.log.Warn("unable to decode the state parameter", - zap.String("state", cx.QueryParam("state")), + zap.String("state", req.URL.Query().Get("state")), zap.Error(err)) } else { state = string(decoded) } } - return r.redirectToURL(state, cx) + r.redirectToURL(state, w, req) } // loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider -func (r *oauthProxy) loginHandler(cx echo.Context) error { +func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { errorMsg, code, err := func() (string, int, error) { if !r.config.EnableLoginHandler { return "attempt to login when login handler is disabled", http.StatusNotImplemented, errors.New("login handler disabled") } - username := cx.Request().PostFormValue("username") - password := cx.Request().PostFormValue("password") + username := req.PostFormValue("username") + password := req.PostFormValue("password") if username == "" || password == "" { return "request does not have both username and password", http.StatusBadRequest, errors.New("no credentials") } @@ -265,53 +237,54 @@ func (r *oauthProxy) loginHandler(cx echo.Context) error { return "unable to decode the access token", http.StatusNotImplemented, err } - r.dropAccessTokenCookie(cx.Request(), cx.Response().Writer, token.AccessToken, time.Until(identity.ExpiresAt)) + r.dropAccessTokenCookie(req, w, token.AccessToken, time.Until(identity.ExpiresAt)) - cx.JSON(http.StatusOK, tokenResponse{ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(tokenResponse{ IDToken: token.IDToken, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresIn: token.Expires, Scope: token.Scope, - }) + }); err != nil { + return "", http.StatusInternalServerError, err + } return "", http.StatusOK, nil }() if err != nil { r.log.Error(errorMsg, - zap.String("client_ip", cx.RealIP()), + zap.String("client_ip", req.RemoteAddr), zap.Error(err)) - return cx.NoContent(code) + w.WriteHeader(code) } - - return nil } // emptyHandler is responsible for doing nothing -func emptyHandler(cx echo.Context) error { - return nil -} +func emptyHandler(w http.ResponseWriter, req *http.Request) {} // logoutHandler performs a logout // - if it's just a access token, the cookie is deleted // - if the user has a refresh token, the token is invalidated by the provider // - optionally, the user can be redirected by to a url -func (r *oauthProxy) logoutHandler(cx echo.Context) error { +func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { // the user can specify a url to redirect the back - redirectURL := cx.QueryParam("redirect") + redirectURL := req.URL.Query().Get("redirect") // step: drop the access token - user, err := r.getIdentity(cx.Request()) + user, err := r.getIdentity(req) if err != nil { - return cx.NoContent(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) + return } + // step: can either use the id token or the refresh token identityToken := user.token.Encode() - if refresh, _, err := r.retrieveRefreshToken(cx.Request(), user); err == nil { + if refresh, _, err := r.retrieveRefreshToken(req, user); err == nil { identityToken = refresh } - r.clearAllCookies(cx.Request(), cx.Response().Writer) + r.clearAllCookies(req, w) // step: check if the user has a state session and if so revoke it if r.useStore() { @@ -328,20 +301,20 @@ func (r *oauthProxy) logoutHandler(cx echo.Context) error { client, err := r.client.OAuthClient() if err != nil { r.log.Error("unable to retrieve the openid client", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } // step: add the authentication headers - // @TODO need to add the authenticated request to go-oidc encodedID := url.QueryEscape(r.config.ClientID) encodedSecret := url.QueryEscape(r.config.ClientSecret) // step: construct the url for revocation - request, err := http.NewRequest(http.MethodPost, revocationURL, - bytes.NewBufferString(fmt.Sprintf("refresh_token=%s", identityToken))) + request, err := http.NewRequest(http.MethodPost, revocationURL, bytes.NewBufferString(fmt.Sprintf("refresh_token=%s", identityToken))) if err != nil { r.log.Error("unable to construct the revocation request", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } // step: add the authentication headers and content-type request.SetBasicAuth(encodedID, encodedSecret) @@ -350,7 +323,7 @@ func (r *oauthProxy) logoutHandler(cx echo.Context) error { response, err := client.HttpClient().Do(request) if err != nil { r.log.Error("unable to post to revocation endpoint", zap.Error(err)) - return nil + return } // step: check the response @@ -366,47 +339,47 @@ func (r *oauthProxy) logoutHandler(cx echo.Context) error { } // step: should we redirect the user if redirectURL != "" { - return r.redirectToURL(redirectURL, cx) + r.redirectToURL(redirectURL, w, req) } - - return cx.NoContent(http.StatusOK) } // expirationHandler checks if the token has expired -func (r *oauthProxy) expirationHandler(cx echo.Context) error { - user, err := r.getIdentity(cx.Request()) +func (r *oauthProxy) expirationHandler(w http.ResponseWriter, req *http.Request) { + user, err := r.getIdentity(req) if err != nil { - return cx.NoContent(http.StatusUnauthorized) + w.WriteHeader(http.StatusUnauthorized) + return } + if user.isExpired() { - return cx.NoContent(http.StatusUnauthorized) + w.WriteHeader(http.StatusUnauthorized) + return } - - return cx.NoContent(http.StatusOK) + w.WriteHeader(http.StatusOK) } // tokenHandler display access token to screen -func (r *oauthProxy) tokenHandler(cx echo.Context) error { - user, err := r.getIdentity(cx.Request()) +func (r *oauthProxy) tokenHandler(w http.ResponseWriter, req *http.Request) { + user, err := r.getIdentity(req) if err != nil { - return cx.String(http.StatusBadRequest, fmt.Sprintf("unable to retrieve session, error: %s", err)) + w.WriteHeader(http.StatusBadRequest) + return } - cx.Response().Writer.Header().Set("Content-Type", "application/json") - - return cx.String(http.StatusOK, fmt.Sprintf("%s", user.token.Payload)) + w.Header().Set("Content-Type", "application/json") + w.Write(user.token.Payload) } // healthHandler is a health check handler for the service -func (r *oauthProxy) healthHandler(cx echo.Context) error { - cx.Response().Writer.Header().Set(versionHeader, getVersion()) - return cx.String(http.StatusOK, "OK\n") +func (r *oauthProxy) healthHandler(w http.ResponseWriter, req *http.Request) { + w.Header().Set(versionHeader, getVersion()) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK\n")) } // debugHandler is responsible for providing the pprof -func (r *oauthProxy) debugHandler(cx echo.Context) error { - r.revokeProxy(cx) - name := cx.Param("name") - switch cx.Request().Method { +func (r *oauthProxy) debugHandler(w http.ResponseWriter, req *http.Request) { + name := chi.URLParam(req, "name") + switch req.Method { case http.MethodGet: switch name { case "heap": @@ -416,40 +389,37 @@ func (r *oauthProxy) debugHandler(cx echo.Context) error { case "block": fallthrough case "threadcreate": - pprof.Handler(name).ServeHTTP(cx.Response().Writer, cx.Request()) + pprof.Handler(name).ServeHTTP(w, req) case "cmdline": - pprof.Cmdline(cx.Response().Writer, cx.Request()) + pprof.Cmdline(w, req) case "profile": - pprof.Profile(cx.Response().Writer, cx.Request()) + pprof.Profile(w, req) case "trace": - pprof.Trace(cx.Response().Writer, cx.Request()) + pprof.Trace(w, req) case "symbol": - pprof.Symbol(cx.Response().Writer, cx.Request()) + pprof.Symbol(w, req) default: - cx.NoContent(http.StatusNotFound) + w.WriteHeader(http.StatusNotFound) } case http.MethodPost: switch name { case "symbol": - pprof.Symbol(cx.Response().Writer, cx.Request()) + pprof.Symbol(w, req) default: - cx.NoContent(http.StatusNotFound) + w.WriteHeader(http.StatusNotFound) } } - - return nil } // proxyMetricsHandler forwards the request into the prometheus handler -func (r *oauthProxy) proxyMetricsHandler(cx echo.Context) error { +func (r *oauthProxy) proxyMetricsHandler(w http.ResponseWriter, req *http.Request) { if r.config.LocalhostMetrics { - if !net.ParseIP(cx.RealIP()).IsLoopback() { - return r.accessForbidden(cx) + if !net.ParseIP(realIP(req)).IsLoopback() { + r.accessForbidden(w, req) + return } } - r.metricsHandler.ServeHTTP(cx.Response().Writer, cx.Request()) - - return nil + r.metricsHandler.ServeHTTP(w, req) } // retrieveRefreshToken retrieves the refresh token from store or cookie @@ -468,3 +438,8 @@ func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) token, err = decodeText(token, r.config.EncryptionKey) return } + +func methodNotAllowHandlder(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write(nil) +} diff --git a/handlers_test.go b/handlers_test.go index edf304a8c33302f8075555fe197ff241e7160544..80f6a47c26309a3295addc3d849dae883210793b 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -67,7 +67,7 @@ func TestOauthRequestNotProxying(t *testing.T) { requests := []fakeRequest{ {URI: "/oauth/test"}, {URI: "/oauth/..//oauth/test/"}, - {URI: "/oauth/expired", Method: http.MethodPost, ExpectedCode: http.StatusNotFound}, + {URI: "/oauth/expired", Method: http.MethodPost, ExpectedCode: http.StatusMethodNotAllowed}, {URI: "/oauth/expiring", Method: http.MethodPost}, {URI: "/oauth%2F///../test%2F%2Foauth"}, } @@ -79,7 +79,7 @@ func TestLoginHandlerDisabled(t *testing.T) { c.EnableLoginHandler = false requests := []fakeRequest{ {URI: oauthURL + loginURL, Method: http.MethodPost, ExpectedCode: http.StatusNotImplemented}, - {URI: oauthURL + loginURL, ExpectedCode: http.StatusNotFound}, + {URI: oauthURL + loginURL, ExpectedCode: http.StatusMethodNotAllowed}, } newFakeProxy(c).RunTests(t, requests) } @@ -280,7 +280,7 @@ func TestCallbackURL(t *testing.T) { { URI: oauthURL + callbackURL, Method: http.MethodPost, - ExpectedCode: http.StatusNotFound, + ExpectedCode: http.StatusMethodNotAllowed, }, { URI: oauthURL + callbackURL, diff --git a/middleware.go b/middleware.go index 3299c0c9c8b0473c613ebdbf97cc235ed9ffef55..f16a65844115f10b9d696fed538ff7e5bba240ae 100644 --- a/middleware.go +++ b/middleware.go @@ -16,6 +16,7 @@ limitations under the License. package main import ( + "context" "fmt" "net/http" "regexp" @@ -24,78 +25,64 @@ import ( "github.com/PuerkitoBio/purell" "github.com/gambol99/go-oidc/jose" - "github.com/labstack/echo" + "github.com/go-chi/chi/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/unrolled/secure" "go.uber.org/zap" ) -const normalizeFlags purell.NormalizationFlags = purell.FlagRemoveDotSegments | purell.FlagRemoveDuplicateSlashes +const ( + // normalizeFlags is the options to purell + normalizeFlags purell.NormalizationFlags = purell.FlagRemoveDotSegments | purell.FlagRemoveDuplicateSlashes + // httpResponseName is the name of the http response hanlder + httpResponseName = "http.response" +) -// proxyRevokeMiddleware is just a helper to drop all requests proxying -func (r *oauthProxy) proxyRevokeMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - r.revokeProxy(cx) - cx.NoContent(http.StatusForbidden) - return next(cx) - } - } -} +// entrypointMiddleware is custom filtering for incoming requests +func entrypointMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + keep := req.URL.Path + purell.NormalizeURL(req.URL, normalizeFlags) -// filterMiddleware is custom filtering for incoming requests -func (r *oauthProxy) filterMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - // step: keep a copy of the original - keep := cx.Request().URL.Path - purell.NormalizeURL(cx.Request().URL, normalizeFlags) - // step: ensure we have a slash in the url - if !strings.HasPrefix(cx.Request().URL.Path, "/") { - cx.Request().URL.Path = "/" + cx.Request().URL.Path - } - cx.Request().RequestURI = cx.Request().URL.RawPath - cx.Request().URL.RawPath = cx.Request().URL.Path - // step: continue the flow - next(cx) - // step: place back the original uri for proxying request - cx.Request().URL.Path = keep - cx.Request().URL.RawPath = keep - cx.Request().RequestURI = keep - - return nil + // ensure we have a slash in the url + if !strings.HasPrefix(req.URL.Path, "/") { + req.URL.Path = "/" + req.URL.Path } - } + req.RequestURI = req.URL.RawPath + req.URL.RawPath = req.URL.Path + + // continue the flow + scope := &RequestScope{} + resp := middleware.NewWrapResponseWriter(w, 2) + next.ServeHTTP(resp, req.WithContext(context.WithValue(req.Context(), contextScopeName, scope))) + + // place back the original uri for proxying request + req.URL.Path = keep + req.URL.RawPath = keep + req.RequestURI = keep + }) } // loggingMiddleware is a custom http logger -func (r *oauthProxy) loggingMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - start := time.Now() - next(cx) - latency := time.Since(start) - addr := cx.RealIP() - //msg := .Infof("[%d] |%s| |%10v| %-5s %s", cx.Response().Status, addr, latency, cx.Request().Method, cx.Request().URL.Path) - r.log.Info("client request", - zap.Int("response", cx.Response().Status), - zap.String("path", cx.Request().URL.Path), - zap.String("client_ip", addr), - zap.String("method", cx.Request().Method), - zap.Int("status", cx.Response().Status), - zap.Int64("bytes", cx.Response().Size), - zap.String("path", cx.Request().URL.Path), - zap.String("latency", latency.String())) - - return nil - } - } +func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + start := time.Now() + resp := w.(middleware.WrapResponseWriter) + next.ServeHTTP(resp, req) + addr := req.RemoteAddr + r.log.Info("client request", + zap.Duration("latency", time.Since(start)), + zap.Int("status", resp.Status()), + zap.Int("bytes", resp.BytesWritten()), + zap.String("client_ip", addr), + zap.String("method", req.Method), + zap.String("path", req.URL.Path)) + }) } // metricsMiddleware is responsible for collecting metrics -func (r *oauthProxy) metricsMiddleware() echo.MiddlewareFunc { - r.log.Info("enabled the service metrics middleware, available on", - zap.String("path", fmt.Sprintf("%s%s", oauthURL, metricsURL))) +func (r *oauthProxy) metricsMiddleware(next http.Handler) http.Handler { + r.log.Info("enabled the service metrics middleware, available on", zap.String("path", fmt.Sprintf("%s%s", oauthURL, metricsURL))) statusMetrics := prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -106,27 +93,30 @@ func (r *oauthProxy) metricsMiddleware() echo.MiddlewareFunc { ) prometheus.MustRegister(statusMetrics) - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - statusMetrics.WithLabelValues(fmt.Sprintf("%d", cx.Response().Status), cx.Request().Method).Inc() - return next(cx) - } - } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + resp := w.(middleware.WrapResponseWriter) + statusMetrics.WithLabelValues(fmt.Sprintf("%d", resp.Status()), req.Method).Inc() + + next.ServeHTTP(w, req) + }) } // authenticationMiddleware is responsible for verifying the access token -func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - clientIP := cx.RealIP() - - // step: grab the user identity from the request - user, err := r.getIdentity(cx.Request()) +func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + clientIP := req.RemoteAddr + // grab the user identity from the request + user, err := r.getIdentity(req) if err != nil { r.log.Error("no session found in request, redirecting for authorization", zap.Error(err)) - return r.redirectToAuthorization(cx) + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + return } - cx.Set(userContextName, user) + // create the request scope + scope := req.Context().Value(contextScopeName).(*RequestScope) + scope.Identity = user + ctx := context.WithValue(req.Context(), contextScopeName, scope) // step: skip if we are running skip-token-verification if r.config.SkipTokenVerification { @@ -136,7 +126,9 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar zap.String("client_ip", clientIP), zap.String("username", user.name), zap.String("expired_on", user.expiresAt.String())) - return r.redirectToAuthorization(cx) + + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + return } } else { if err := verifyToken(r.client, user.token); err != nil { @@ -147,7 +139,9 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar r.log.Error("access token failed verification", zap.String("client_ip", clientIP), zap.Error(err)) - return r.accessForbidden(cx) + + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } // step: check if we are refreshing the access tokens and if not re-auth @@ -156,7 +150,9 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar zap.String("client_ip", clientIP), zap.String("email", user.name), zap.String("expired_on", user.expiresAt.String())) - return r.redirectToAuthorization(cx) + + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + return } r.log.Info("accces token for user has expired, attemping to refresh the token", @@ -164,13 +160,15 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar zap.String("email", user.email)) // step: check if the user has refresh token - refresh, encrypted, err := r.retrieveRefreshToken(cx.Request(), user) + refresh, encrypted, err := r.retrieveRefreshToken(req.WithContext(ctx), user) if err != nil { r.log.Error("unable to find a refresh token for user", zap.String("client_ip", clientIP), zap.String("email", user.email), zap.Error(err)) - return r.redirectToAuthorization(cx) + + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + return } // attempt to refresh the access token @@ -182,11 +180,13 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar zap.String("client_ip", clientIP), zap.String("email", user.email)) - r.clearAllCookies(cx.Request(), cx.Response().Writer) + r.clearAllCookies(req.WithContext(ctx), w) default: r.log.Error("failed to refresh the access token", zap.Error(err)) } - return r.redirectToAuthorization(cx) + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + + return } // get the expiration of the new access token expiresIn := r.getAccessCookieExpiration(token, refresh) @@ -201,11 +201,12 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar if r.config.EnableEncryptedToken { if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { r.log.Error("unable to encode the access token", zap.Error(err)) - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } } // step: inject the refreshed access token - r.dropAccessTokenCookie(cx.Request(), cx.Response().Writer, accessToken, expiresIn) + r.dropAccessTokenCookie(req.WithContext(ctx), w, accessToken, expiresIn) if r.useStore() { go func(old, new jose.JWT, encrypted string) { @@ -220,27 +221,31 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) echo.Middlewar } // update the with the new access token and inject into the context user.token = token - cx.Set(userContextName, user) + ctx = context.WithValue(req.Context(), contextScopeName, scope) } } - return next(cx) - } + + next.ServeHTTP(w, req.WithContext(ctx)) + }) } } // admissionMiddleware is responsible checking the access token against the protected resource -func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc { +func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) http.Handler { claimMatches := make(map[string]*regexp.Regexp) for k, v := range r.config.MatchClaims { claimMatches[k] = regexp.MustCompile(v) } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - if found := cx.Get(revokeContextName); found != nil { - return nil + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // we don't need to continue is a decision has been made + scope := req.Context().Value(contextScopeName).(*RequestScope) + if scope.AccessDenied { + next.ServeHTTP(w, req) + return } - user := cx.Get(userContextName).(*userContext) + user := scope.Identity // step: we need to check the roles if roles := len(resource.Roles); roles > 0 { @@ -251,13 +256,13 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc zap.String("resource", resource.URL), zap.String("required", resource.getRoles())) - return r.accessForbidden(cx) + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } } // step: if we have any claim matching, lets validate the tokens has the claims for claimName, match := range claimMatches { - // step: if the claim is NOT in the token, we access deny value, found, err := user.claims.StringClaim(claimName) if err != nil { r.log.Error("unable to extract the claim from token", @@ -266,7 +271,8 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc zap.String("resource", resource.URL), zap.Error(err)) - return r.accessForbidden(cx) + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } if !found { @@ -276,7 +282,8 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc zap.String("email", user.email), zap.String("resource", resource.URL)) - return r.accessForbidden(cx) + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } // step: check the claim is the same @@ -289,7 +296,8 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc zap.String("required", match.String()), zap.String("resource", resource.URL)) - return r.accessForbidden(cx) + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } } @@ -299,48 +307,50 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) echo.MiddlewareFunc zap.Duration("expires", time.Until(user.expiresAt)), zap.String("resource", resource.URL)) - return next(cx) - } + next.ServeHTTP(w, req) + }) } } // headersMiddleware is responsible for add the authentication headers for the upstream -func (r *oauthProxy) headersMiddleware(custom []string) echo.MiddlewareFunc { +func (r *oauthProxy) headersMiddleware(custom []string) func(http.Handler) http.Handler { customClaims := make(map[string]string) for _, x := range custom { customClaims[x] = fmt.Sprintf("X-Auth-%s", toHeader(x)) } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - if user := cx.Get(userContextName); user != nil { - id := user.(*userContext) - cx.Request().Header.Set("X-Auth-Email", id.email) - cx.Request().Header.Set("X-Auth-ExpiresIn", id.expiresAt.String()) - cx.Request().Header.Set("X-Auth-Roles", strings.Join(id.roles, ",")) - cx.Request().Header.Set("X-Auth-Subject", id.id) - cx.Request().Header.Set("X-Auth-Token", id.token.Encode()) - cx.Request().Header.Set("X-Auth-Userid", id.name) - cx.Request().Header.Set("X-Auth-Username", id.name) - - // step: add the authorization header if requested + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + scope := req.Context().Value(contextScopeName).(*RequestScope) + if scope.Identity != nil { + user := scope.Identity + req.Header.Set("X-Auth-Email", user.email) + req.Header.Set("X-Auth-ExpiresIn", user.expiresAt.String()) + req.Header.Set("X-Auth-Roles", strings.Join(user.roles, ",")) + req.Header.Set("X-Auth-Subject", user.id) + req.Header.Set("X-Auth-Token", user.token.Encode()) + req.Header.Set("X-Auth-Userid", user.name) + req.Header.Set("X-Auth-Username", user.name) + + // add the authorization header if requested if r.config.EnableAuthorizationHeader { - cx.Request().Header.Set("Authorization", fmt.Sprintf("Bearer %s", id.token.Encode())) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", user.token.Encode())) } - // step: inject any custom claims + // inject any custom claims for claim, header := range customClaims { - if claim, found := id.claims[claim]; found { - cx.Request().Header.Set(header, fmt.Sprintf("%v", claim)) + if claim, found := user.claims[claim]; found { + req.Header.Set(header, fmt.Sprintf("%v", claim)) } } } - return next(cx) - } + + next.ServeHTTP(w, req) + }) } } // securityMiddleware performs numerous security checks on the request -func (r *oauthProxy) securityMiddleware() echo.MiddlewareFunc { +func (r *oauthProxy) securityMiddleware(next http.Handler) http.Handler { r.log.Info("enabling the security filter middleware") secure := secure.New(secure.Options{ AllowedHosts: r.config.Hostnames, @@ -351,13 +361,31 @@ func (r *oauthProxy) securityMiddleware() echo.MiddlewareFunc { SSLRedirect: r.config.EnableHTTPSRedirect, }) - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(cx echo.Context) error { - if err := secure.Process(cx.Response().Writer, cx.Request()); err != nil { - r.log.Error("failed security middleware", zap.Error(err)) - return r.accessForbidden(cx) - } - return next(cx) + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if err := secure.Process(w, req); err != nil { + r.log.Warn("failed security middleware", zap.Error(err)) + next.ServeHTTP(w, req.WithContext(r.accessForbidden(w, req))) + return } - } + + next.ServeHTTP(w, req) + }) +} + +// proxyDenyMiddleware just block everything +func proxyDenyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + sc := req.Context().Value(contextScopeName) + var scope *RequestScope + if sc == nil { + scope = &RequestScope{} + } else { + scope = sc.(*RequestScope) + } + scope.AccessDenied = true + // update the request context + ctx := context.WithValue(req.Context(), contextScopeName, scope) + + next.ServeHTTP(w, req.WithContext(ctx)) + }) } diff --git a/middleware_test.go b/middleware_test.go index 7ff44cd6eb1fa295a516465f401a9ba7a1d05c19..13f9f3cb864cbdf8b8e14ecaa11d260300ce4154 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -27,7 +27,7 @@ import ( "github.com/gambol99/go-oidc/jose" "github.com/go-resty/resty" - "github.com/labstack/echo/middleware" + "github.com/rs/cors" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -79,7 +79,7 @@ func newFakeProxy(c *Config) *fakeProxy { auth := newFakeAuthServer() c.DiscoveryURL = auth.getLocation() c.RevocationEndpoint = auth.getRevocationURL() - c.Verbose = false + c.Verbose = true proxy, err := newProxy(c) if err != nil { panic("failed to create fake proxy service, error: " + err.Error()) @@ -330,12 +330,12 @@ func TestStrangeRoutingError(t *testing.T) { cfg := newFakeKeycloakConfig() cfg.Resources = []*Resource{ { - URL: "/api/v1/event/123456789", + URL: "/api/v1/events/123456789", Methods: allHTTPMethods, Roles: []string{"user"}, }, { - URL: "/api/v1/event/404", + URL: "/api/v1/events/404", Methods: allHTTPMethods, Roles: []string{"monitoring"}, }, @@ -352,26 +352,48 @@ func TestStrangeRoutingError(t *testing.T) { } requests := []fakeRequest{ { // should work - URI: "/api/v1/event/123456789", + URI: "/api/v1/events/123456789", HasToken: true, Redirects: true, Roles: []string{"user"}, ExpectedProxy: true, ExpectedCode: http.StatusOK, }, + { // should break with bad role + URI: "/api/v1/events/123456789", + HasToken: true, + Redirects: true, + Roles: []string{"bad_role"}, + ExpectedCode: http.StatusForbidden, + }, { // good - URI: "/api/v1/event/404", + URI: "/api/v1/events/404", HasToken: true, Redirects: false, - Roles: []string{"monitoring"}, + Roles: []string{"monitoring", "test"}, ExpectedProxy: true, ExpectedCode: http.StatusOK, }, - { // this should fail here + { // this should fail with no roles - hits catch all URI: "/api/v1/event/1000", Redirects: false, ExpectedCode: http.StatusUnauthorized, }, + { // this should fail with bad role - hits catch all + URI: "/api/v1/event/1000", + Redirects: false, + HasToken: true, + Roles: []string{"bad"}, + ExpectedCode: http.StatusForbidden, + }, + { // should work with catch-all + URI: "/api/v1/event/1000", + Redirects: false, + HasToken: true, + Roles: []string{"dev"}, + ExpectedProxy: true, + ExpectedCode: http.StatusOK, + }, } newFakeProxy(cfg).RunTests(t, requests) @@ -381,7 +403,7 @@ func TestNoProxyingRequests(t *testing.T) { c := newFakeKeycloakConfig() c.Resources = []*Resource{ { - URL: "*", + URL: "/*", Methods: allHTTPMethods, }, } @@ -593,12 +615,12 @@ func TestRolePermissionsMiddleware(t *testing.T) { Roles: []string{"bad_role"}, ExpectedCode: http.StatusForbidden, }, - { // token, wrong roles, no 'get' method (5) + { // token, but post method URI: "/test", Method: http.MethodPost, Redirects: false, HasToken: true, - Roles: []string{"bad_role"}, + Roles: []string{fakeTestRole}, ExpectedCode: http.StatusOK, ExpectedProxy: true, }, @@ -747,42 +769,52 @@ func TestRolePermissionsMiddleware(t *testing.T) { func TestCrossSiteHandler(t *testing.T) { cases := []struct { - Cors middleware.CORSConfig + Cors cors.Options Request fakeRequest }{ { - Cors: middleware.CORSConfig{ - AllowOrigins: []string{"*"}, + Cors: cors.Options{ + AllowedOrigins: []string{"*"}, }, Request: fakeRequest{ URI: fakeAuthAllURL, + Headers: map[string]string{ + "Origin": "127.0.0.1", + }, ExpectedHeaders: map[string]string{ "Access-Control-Allow-Origin": "*", }, }, }, { - Cors: middleware.CORSConfig{ - AllowOrigins: []string{"*", "https://examples.com"}, + Cors: cors.Options{ + AllowedOrigins: []string{"*", "https://examples.com"}, }, Request: fakeRequest{ URI: fakeAuthAllURL, + Headers: map[string]string{ + "Origin": "127.0.0.1", + }, ExpectedHeaders: map[string]string{ "Access-Control-Allow-Origin": "*", }, }, }, { - Cors: middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - AllowMethods: []string{"GET", "POST"}, + Cors: cors.Options{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST"}, }, Request: fakeRequest{ URI: fakeAuthAllURL, Method: http.MethodOptions, + Headers: map[string]string{ + "Origin": "127.0.0.1", + "Access-Control-Request-Method": "GET", + }, ExpectedHeaders: map[string]string{ "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET,POST", + "Access-Control-Allow-Methods": "GET", }, }, }, @@ -791,11 +823,11 @@ func TestCrossSiteHandler(t *testing.T) { for _, c := range cases { cfg := newFakeKeycloakConfig() cfg.CorsCredentials = c.Cors.AllowCredentials - cfg.CorsExposedHeaders = c.Cors.ExposeHeaders - cfg.CorsHeaders = c.Cors.AllowHeaders + cfg.CorsExposedHeaders = c.Cors.ExposedHeaders + cfg.CorsHeaders = c.Cors.AllowedHeaders cfg.CorsMaxAge = time.Duration(time.Duration(c.Cors.MaxAge) * time.Second) - cfg.CorsMethods = c.Cors.AllowMethods - cfg.CorsOrigins = c.Cors.AllowOrigins + cfg.CorsMethods = c.Cors.AllowedMethods + cfg.CorsOrigins = c.Cors.AllowedOrigins newFakeProxy(cfg).RunTests(t, []fakeRequest{c.Request}) } @@ -955,6 +987,7 @@ func TestAdmissionHandlerRoles(t *testing.T) { newFakeProxy(cfg).RunTests(t, requests) } +// check to see if custom headers are hitting the upstream func TestCustomHeaders(t *testing.T) { uri := "/admin/test" requests := []struct { @@ -966,7 +999,7 @@ func TestCustomHeaders(t *testing.T) { "TestHeaderOne": "one", }, Request: fakeRequest{ - URI: "/test.html", + URI: "/gambol99.htm", ExpectedProxy: true, ExpectedProxyHeaders: map[string]string{ "TestHeaderOne": "one", diff --git a/misc.go b/misc.go index b0541fad02b2ab69f60a4e2644e9bb6404c9d235..41cfc1f873ad2e848092025cc057a3441c2eff05 100644 --- a/misc.go +++ b/misc.go @@ -16,6 +16,7 @@ limitations under the License. package main import ( + "context" "encoding/base64" "fmt" "net/http" @@ -23,58 +24,63 @@ import ( "time" "github.com/gambol99/go-oidc/jose" - "github.com/labstack/echo" "go.uber.org/zap" ) // revokeProxy is responsible to stopping the middleware from proxying the request -func (r *oauthProxy) revokeProxy(cx echo.Context) { - cx.Set(revokeContextName, true) +func (r *oauthProxy) revokeProxy(w http.ResponseWriter, req *http.Request) context.Context { + var scope *RequestScope + sc := req.Context().Value(contextScopeName) + switch sc { + case nil: + scope = &RequestScope{AccessDenied: true} + default: + scope = sc.(*RequestScope) + } + scope.AccessDenied = true + + return context.WithValue(req.Context(), contextScopeName, scope) } // accessForbidden redirects the user to the forbidden page -func (r *oauthProxy) accessForbidden(cx echo.Context) error { - r.revokeProxy(cx) - +func (r *oauthProxy) accessForbidden(w http.ResponseWriter, req *http.Request) context.Context { + w.WriteHeader(http.StatusForbidden) + // are we using a custom http template for 403? if r.config.hasCustomForbiddenPage() { - tplName := path.Base(r.config.ForbiddenPage) - err := cx.Render(http.StatusForbidden, tplName, r.config.Tags) - if err != nil { - r.log.Error("unable to render the template", - zap.Error(err), - zap.String("template", tplName)) + name := path.Base(r.config.ForbiddenPage) + if err := r.Render(w, name, r.config.Tags); err != nil { + r.log.Error("failed to render the template", zap.Error(err), zap.String("template", name)) } - - return err } - return cx.NoContent(http.StatusForbidden) + return r.revokeProxy(w, req) } // redirectToURL redirects the user and aborts the context -func (r *oauthProxy) redirectToURL(url string, cx echo.Context) error { - r.revokeProxy(cx) +func (r *oauthProxy) redirectToURL(url string, w http.ResponseWriter, req *http.Request) context.Context { + http.Redirect(w, req, url, http.StatusTemporaryRedirect) - return cx.Redirect(http.StatusTemporaryRedirect, url) + return r.revokeProxy(w, req) } // redirectToAuthorization redirects the user to authorization handler -func (r *oauthProxy) redirectToAuthorization(cx echo.Context) error { - r.revokeProxy(cx) - +func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Request) context.Context { if r.config.NoRedirects { - return cx.NoContent(http.StatusUnauthorized) + w.WriteHeader(http.StatusUnauthorized) + return r.revokeProxy(w, req) } // step: add a state referrer to the authorization page - authQuery := fmt.Sprintf("?state=%s", base64.StdEncoding.EncodeToString([]byte(cx.Request().URL.RequestURI()))) + authQuery := fmt.Sprintf("?state=%s", base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI()))) // step: if verification is switched off, we can't authorization if r.config.SkipTokenVerification { r.log.Error("refusing to redirection to authorization endpoint, skip token verification switched on") - return cx.NoContent(http.StatusForbidden) + w.WriteHeader(http.StatusForbidden) + return r.revokeProxy(w, req) } + r.redirectToURL(oauthURL+authorizationURL+authQuery, w, req) - return r.redirectToURL(oauthURL+authorizationURL+authQuery, cx) + return r.revokeProxy(w, req) } // getAccessCookieExpiration calucates the expiration of the access token cookie diff --git a/oauth_test.go b/oauth_test.go index ec1c4ed62ab131c1eeb4ed2aeee4c359eb989b5d..3d7c382a309c48e7c1a7088530bbebd6bdaec63b 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -17,6 +17,7 @@ package main import ( "crypto/x509" + "encoding/json" "encoding/pem" "fmt" "math/rand" @@ -29,7 +30,8 @@ import ( "github.com/gambol99/go-oidc/jose" "github.com/gambol99/go-oidc/oauth2" - "github.com/labstack/echo" + "github.com/pressly/chi" + "github.com/pressly/chi/middleware" "github.com/stretchr/testify/assert" ) @@ -110,14 +112,15 @@ func newFakeAuthServer() *fakeAuthServer { signer: jose.NewSignerRSA("test-kid", *privateKey), } - r := echo.New() - r.GET("auth/realms/hod-test/.well-known/openid-configuration", service.discoveryHandler) - r.GET("auth/realms/hod-test/protocol/openid-connect/certs", service.keysHandler) - r.GET("auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) - r.POST("auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) - r.GET("auth/realms/hod-test/protocol/openid-connect/auth", service.authHandler) - r.POST("auth/realms/hod-test/protocol/openid-connect/logout", service.logoutHandler) - r.GET("auth/realms/hod-test/protocol/openid-connect/userinfo", service.userInfoHandler) + r := chi.NewRouter() + r.Use(middleware.Recoverer) + r.Get("/auth/realms/hod-test/.well-known/openid-configuration", service.discoveryHandler) + r.Get("/auth/realms/hod-test/protocol/openid-connect/certs", service.keysHandler) + r.Get("/auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) + r.Get("/auth/realms/hod-test/protocol/openid-connect/auth", service.authHandler) + r.Get("/auth/realms/hod-test/protocol/openid-connect/userinfo", service.userInfoHandler) + r.Post("/auth/realms/hod-test/protocol/openid-connect/logout", service.logoutHandler) + r.Post("/auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) service.server = httptest.NewServer(r) location, err := url.Parse(service.server.URL) @@ -151,8 +154,8 @@ func (r *fakeAuthServer) setTokenExpiration(tm time.Duration) *fakeAuthServer { return r } -func (r *fakeAuthServer) discoveryHandler(cx echo.Context) error { - return cx.JSON(http.StatusOK, fakeDiscoveryResponse{ +func (r *fakeAuthServer) discoveryHandler(w http.ResponseWriter, req *http.Request) { + renderJSON(http.StatusOK, w, req, fakeDiscoveryResponse{ AuthorizationEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/auth", r.location.Host), EndSessionEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/logout", r.location.Host), Issuer: fmt.Sprintf("http://%s/auth/realms/hod-test", r.location.Host), @@ -169,47 +172,52 @@ func (r *fakeAuthServer) discoveryHandler(cx echo.Context) error { }) } -func (r *fakeAuthServer) keysHandler(cx echo.Context) error { - return cx.JSON(http.StatusOK, jose.JWKSet{Keys: []jose.JWK{r.key}}) +func (r *fakeAuthServer) keysHandler(w http.ResponseWriter, req *http.Request) { + renderJSON(http.StatusOK, w, req, jose.JWKSet{Keys: []jose.JWK{r.key}}) } -func (r *fakeAuthServer) authHandler(cx echo.Context) error { - state := cx.QueryParam("state") - redirect := cx.QueryParam("redirect_uri") +func (r *fakeAuthServer) authHandler(w http.ResponseWriter, req *http.Request) { + state := req.URL.Query().Get("state") + redirect := req.URL.Query().Get("redirect_uri") if redirect == "" { - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } if state == "" { state = "/" } redirectionURL := fmt.Sprintf("%s?state=%s&code=%s", redirect, state, getRandomString(32)) - return cx.Redirect(http.StatusTemporaryRedirect, redirectionURL) + http.Redirect(w, req, redirectionURL, http.StatusTemporaryRedirect) } -func (r *fakeAuthServer) logoutHandler(cx echo.Context) error { - if refreshToken := cx.FormValue("refresh_token"); refreshToken == "" { - return cx.NoContent(http.StatusBadRequest) +func (r *fakeAuthServer) logoutHandler(w http.ResponseWriter, req *http.Request) { + if refreshToken := req.FormValue("refresh_token"); refreshToken == "" { + w.WriteHeader(http.StatusBadRequest) + return } - return cx.NoContent(http.StatusNoContent) + w.WriteHeader(http.StatusNoContent) } -func (r *fakeAuthServer) userInfoHandler(cx echo.Context) error { - items := strings.Split(cx.Request().Header.Get("Authorization"), " ") +func (r *fakeAuthServer) userInfoHandler(w http.ResponseWriter, req *http.Request) { + items := strings.Split(req.Header.Get("Authorization"), " ") if len(items) != 2 { - return echo.ErrUnauthorized + w.WriteHeader(http.StatusUnauthorized) + return } decoded, err := jose.ParseJWT(items[1]) if err != nil { - return echo.ErrUnauthorized + w.WriteHeader(http.StatusUnauthorized) + return } claims, err := decoded.Claims() if err != nil { - return echo.ErrUnauthorized + w.WriteHeader(http.StatusUnauthorized) + return } - return cx.JSON(http.StatusOK, map[string]interface{}{ + renderJSON(http.StatusOK, w, req, map[string]interface{}{ "sub": claims["sub"], "name": claims["name"], "given_name": claims["given_name"], @@ -220,7 +228,7 @@ func (r *fakeAuthServer) userInfoHandler(cx echo.Context) error { }) } -func (r *fakeAuthServer) tokenHandler(cx echo.Context) error { +func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) { expires := time.Now().Add(r.expiration) unsigned := newTestToken(r.getLocation()) unsigned.setExpiration(expires) @@ -228,39 +236,42 @@ func (r *fakeAuthServer) tokenHandler(cx echo.Context) error { // sign the token with the private key token, err := jose.NewSignedJWT(unsigned.claims, r.signer) if err != nil { - return cx.NoContent(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) + return } - switch cx.FormValue("grant_type") { + switch req.FormValue("grant_type") { case oauth2.GrantTypeUserCreds: - username := cx.FormValue("username") - password := cx.FormValue("password") + username := req.FormValue("username") + password := req.FormValue("password") if username == "" || password == "" { - return cx.NoContent(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) + return } if username == validUsername && password == validPassword { - return cx.JSON(http.StatusOK, tokenResponse{ + renderJSON(http.StatusOK, w, req, tokenResponse{ IDToken: token.Encode(), AccessToken: token.Encode(), RefreshToken: token.Encode(), ExpiresIn: expires.UTC().Second(), }) + return } - return cx.JSON(http.StatusUnauthorized, map[string]string{ + renderJSON(http.StatusUnauthorized, w, req, map[string]string{ "error": "invalid_grant", "error_description": "Invalid user credentials", }) case oauth2.GrantTypeRefreshToken: fallthrough case oauth2.GrantTypeAuthCode: - return cx.JSON(http.StatusOK, tokenResponse{ + renderJSON(http.StatusOK, w, req, tokenResponse{ IDToken: token.Encode(), AccessToken: token.Encode(), RefreshToken: token.Encode(), ExpiresIn: expires.Second(), }) default: - return cx.NoContent(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) } } @@ -312,3 +323,12 @@ func getRandomString(n int) string { } return string(b) } + +func renderJSON(code int, w http.ResponseWriter, req *http.Request, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if err := json.NewEncoder(w).Encode(data); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} diff --git a/server.go b/server.go index 2120fc190e098238400eb7b65238c9a8c800c9d8..2c9a37a6756fa44fb479d9a6d8699ed488c3ffbe 100644 --- a/server.go +++ b/server.go @@ -36,9 +36,10 @@ import ( "github.com/armon/go-proxyproto" "github.com/gambol99/go-oidc/oidc" "github.com/gambol99/goproxy" - "github.com/labstack/echo" - "github.com/labstack/echo/middleware" + "github.com/pressly/chi" + "github.com/pressly/chi/middleware" "github.com/prometheus/client_golang/prometheus" + "github.com/rs/cors" "go.uber.org/zap" ) @@ -148,40 +149,60 @@ func (r *oauthProxy) createReverseProxy() error { if err := r.createUpstreamProxy(r.endpoint); err != nil { return err } + engine := chi.NewRouter() + engine.MethodNotAllowed(emptyHandler) + engine.NotFound(emptyHandler) + engine.Use(middleware.Recoverer) + engine.Use(entrypointMiddleware) - // step: create the router - engine := echo.New() - engine.Pre(r.filterMiddleware()) - engine.Use(middleware.Recover()) - - if r.config.EnableProfiling { - r.log.Warn("enabling the debug profiling on /debug/pprof") - engine.Any("/debug/pprof/:name", r.debugHandler) - } if r.config.EnableLogging { - engine.Use(r.loggingMiddleware()) + engine.Use(r.loggingMiddleware) } if r.config.EnableMetrics { - engine.Use(r.metricsMiddleware()) + engine.Use(r.metricsMiddleware) } if r.config.EnableSecurityFilter { - engine.Use(r.securityMiddleware()) + engine.Use(r.securityMiddleware) } + if len(r.config.CorsOrigins) > 0 { - engine.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: r.config.CorsOrigins, - AllowMethods: r.config.CorsMethods, - AllowHeaders: r.config.CorsHeaders, + c := cors.New(cors.Options{ + AllowedOrigins: r.config.CorsOrigins, + AllowedMethods: r.config.CorsMethods, + AllowedHeaders: r.config.CorsHeaders, AllowCredentials: r.config.CorsCredentials, - ExposeHeaders: r.config.CorsExposedHeaders, - MaxAge: int(r.config.CorsMaxAge.Seconds())})) + ExposedHeaders: r.config.CorsExposedHeaders, + MaxAge: int(r.config.CorsMaxAge.Seconds()), + }) + engine.Use(c.Handler) } - // step: add the routing for aouth - engine.Group(oauthURL, r.proxyRevokeMiddleware()) - engine.Any(oauthURL+"/:name", r.oauthHandler) + engine.Use(r.proxyMiddleware) r.router = engine + // step: add the routing for oauth + engine.With(proxyDenyMiddleware).Route(oauthURL, func(e chi.Router) { + e.MethodNotAllowed(methodNotAllowHandlder) + e.Get(authorizationURL, r.oauthAuthorizationHandler) + e.Get(callbackURL, r.oauthCallbackHandler) + e.Get(expiredURL, r.expirationHandler) + e.Get(healthURL, r.healthHandler) + e.Get(logoutURL, r.logoutHandler) + e.Get(tokenURL, r.tokenHandler) + e.Post(loginURL, r.loginHandler) + if r.config.EnableMetrics { + e.Get(metricsURL, r.proxyMetricsHandler) + } + }) + + if r.config.EnableProfiling { + engine.With(proxyDenyMiddleware).Route(debugURL, func(e chi.Router) { + r.log.Warn("enabling the debug profiling on /debug/pprof") + e.Get("/{name}", r.debugHandler) + e.Post("/{name}", r.debugHandler) + }) + } + // step: load the templates if any if err := r.createTemplates(); err != nil { return err @@ -195,17 +216,27 @@ func (r *oauthProxy) createReverseProxy() error { zap.String("ammended", strings.TrimRight(x.URL, "/"))) } } + for _, x := range r.config.Resources { r.log.Info("protecting resource", zap.String("resource", x.String())) + e := engine.With( + r.authenticationMiddleware(x), + r.admissionMiddleware(x), + r.headersMiddleware(r.config.AddClaims)) + e.MethodNotAllowed(emptyHandler) switch x.WhiteListed { case false: - engine.Match(x.Methods, x.URL, emptyHandler, r.authenticationMiddleware(x), r.admissionMiddleware(x), r.headersMiddleware(r.config.AddClaims)) + for _, m := range x.Methods { + e.MethodFunc(m, x.URL, emptyHandler) + } default: - engine.Match(x.Methods, x.URL, emptyHandler) + for _, m := range x.Methods { + engine.MethodFunc(m, x.URL, emptyHandler) + } } } for name, value := range r.config.MatchClaims { - r.log.Info("the token must contain", zap.String("claim", name), zap.String("value", value)) + r.log.Info("token must contain", zap.String("claim", name), zap.String("value", value)) } if r.config.RedirectionURL == "" { r.log.Warn("no redirection url has been set, will use host headers") @@ -214,8 +245,6 @@ func (r *oauthProxy) createReverseProxy() error { r.log.Info("session access tokens will be encrypted") } - engine.Use(r.proxyMiddleware()) - return nil } @@ -224,7 +253,7 @@ func (r *oauthProxy) createForwardingProxy() error { r.log.Info("enabling forward signing mode, listening on", zap.String("interface", r.config.Listen)) if r.config.SkipUpstreamTLSVerify { - r.log.Warn("TLS verification switched off. In forward signing mode it's recommended you verify! (--skip-upstream-tls-verify=false)") + r.log.Warn("tls verification switched off. In forward signing mode it's recommended you verify! (--skip-upstream-tls-verify=false)") } if err := r.createUpstreamProxy(nil); err != nil { return err @@ -297,14 +326,15 @@ func (r *oauthProxy) Run() error { } // step: create the http server server := &http.Server{ - Addr: r.config.Listen, - Handler: r.router, + Addr: r.config.Listen, + Handler: r.router, + IdleTimeout: 120 * time.Second, } r.server = server r.listener = listener go func() { - r.log.Info("keycloak proxy service starting on", zap.String("interface", r.config.Listen)) + r.log.Info("keycloak proxy service starting", zap.String("interface", r.config.Listen)) if err = server.Serve(listener); err != nil { if err != http.ErrServerClosed { r.log.Fatal("failed to start the http service", zap.Error(err)) @@ -314,7 +344,7 @@ func (r *oauthProxy) Run() error { // step: are we running http service as well? if r.config.ListenHTTP != "" { - r.log.Info("keycloak proxy http service starting on", zap.String("interface", r.config.ListenHTTP)) + r.log.Info("keycloak proxy http service starting", zap.String("interface", r.config.ListenHTTP)) httpListener, err := r.createHTTPListener(listenerConfig{ listen: r.config.ListenHTTP, proxyProtocol: r.config.EnableProxyProtocol, @@ -351,7 +381,7 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er var listener net.Listener var err error - // step: are we create a unix socket or tcp listener? + // are we create a unix socket or tcp listener? if strings.HasPrefix(config.listen, "unix://") { socket := strings.Trim(config.listen, "unix://") if exists := fileExists(socket); exists { @@ -369,22 +399,22 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er } } - // step: does it require proxy protocol? + // does it require proxy protocol? if config.proxyProtocol { r.log.Info("enabling the proxy protocol on listener", zap.String("interface", config.listen)) listener = &proxyproto.Listener{Listener: listener} } - // step: does the socket require TLS? + // does the socket require TLS? if config.certificate != "" && config.privateKey != "" { r.log.Info("tls support enabled", zap.String("certificate", config.certificate), zap.String("private_key", config.privateKey)) - // step: creating a certificate rotation + // creating a certificate rotation rotate, err := newCertificateRotator(config.certificate, config.privateKey, r.log) if err != nil { return nil, err } - // step: start watching the files for changes + // start watching the files for changes if err := rotate.watch(); err != nil { return nil, err } @@ -394,7 +424,7 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er } listener = tls.NewListener(listener, tlsConfig) - // step: are we doing mutual tls? + // are we doing mutual tls? if config.clientCert != "" { caCert, err := ioutil.ReadFile(config.clientCert) if err != nil { @@ -429,11 +459,8 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { upstream.Host = "domain-sock" upstream.Scheme = "http" } - // create the upstream tls configure - tlsConfig := &tls.Config{ - InsecureSkipVerify: r.config.SkipUpstreamTLSVerify, - } + tlsConfig := &tls.Config{InsecureSkipVerify: r.config.SkipUpstreamTLSVerify} // are we using a client certificate // @TODO provide a means of reload on the client certificate when it expires. I'm not sure if it's just a @@ -482,17 +509,11 @@ func (r *oauthProxy) createTemplates() error { if len(list) > 0 { r.log.Info("loading the custom templates", zap.String("templates", strings.Join(list, ","))) r.templates = template.Must(template.ParseFiles(list...)) - r.router.(*echo.Echo).Renderer = r } return nil } -// Render implements the echo Render interface -func (r *oauthProxy) Render(w io.Writer, name string, data interface{}, c echo.Context) error { - return r.templates.ExecuteTemplate(w, name, data) -} - // newOpenIDClient initializes the openID configuration, note: the redirection url is deliberately left blank // in order to retrieve it from the host header on request func (r *oauthProxy) newOpenIDClient() (*oidc.Client, oidc.ProviderConfig, *http.Client, error) { @@ -554,3 +575,8 @@ func (r *oauthProxy) newOpenIDClient() (*oidc.Client, oidc.ProviderConfig, *http return client, config, hc, nil } + +// Render implements the echo Render interface +func (r *oauthProxy) Render(w io.Writer, name string, data interface{}) error { + return r.templates.ExecuteTemplate(w, name, data) +} diff --git a/server_test.go b/server_test.go index b37e7fbc456c03efce7c05a8c3743dc1edcfecae..36ff79d69076f57039589ce342bb1fd7e3a8e4d9 100644 --- a/server_test.go +++ b/server_test.go @@ -97,7 +97,6 @@ func TestReverseProxyHeaders(t *testing.T) { "X-Auth-Token": signed.Encode(), "X-Auth-Userid": "rjayawardene", "X-Auth-Username": "rjayawardene", - "X-Forwarded-For": "127.0.0.1", }, ExpectedCode: http.StatusOK, }, @@ -138,7 +137,7 @@ func TestForbiddenTemplate(t *testing.T) { } requests := []fakeRequest{ { - URI: "/", + URI: "/test", Redirects: false, HasToken: true, ExpectedCode: http.StatusForbidden, diff --git a/utils.go b/utils.go index ec6ea218309f059f0243062d590a6e4893b699da..9d30d06ca622aaebd7ae609e14700dc09702b7d8 100644 --- a/utils.go +++ b/utils.go @@ -41,21 +41,20 @@ import ( "unicode/utf8" "github.com/gambol99/go-oidc/jose" - "github.com/labstack/echo" "github.com/urfave/cli" "gopkg.in/yaml.v2" ) var ( allHTTPMethods = []string{ - echo.DELETE, - echo.GET, - echo.HEAD, - echo.OPTIONS, - echo.PATCH, - echo.POST, - echo.PUT, - echo.TRACE, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + http.MethodPut, + http.MethodTrace, } ) @@ -377,3 +376,16 @@ func getHashKey(token *jose.JWT) string { func printError(message string, args ...interface{}) *cli.ExitError { return cli.NewExitError(fmt.Sprintf("[error] "+message, args...), 1) } + +// realIP retrieves the client ip address from a http request +func realIP(req *http.Request) string { + ra := req.RemoteAddr + if ip := req.Header.Get(headerXForwardedFor); ip != "" { + ra = strings.Split(ip, ", ")[0] + } else if ip := req.Header.Get(headerXRealIP); ip != "" { + ra = ip + } else { + ra, _, _ = net.SplitHostPort(ra) + } + return ra +}