diff --git a/handlers_test.go b/handlers_test.go index b4aa6c9adf1893881cacc77d5b605fbbe6b4a20d..be9799691b1dbb4813fe455d2e619d39f26b1963 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -138,7 +138,10 @@ func TestLoginHandler(t *testing.T) { func TestLogoutHandlerBadRequest(t *testing.T) { requests := []fakeRequest{ - {URI: newFakeKeycloakConfig().WithOAuthURI(logoutURL), ExpectedCode: http.StatusBadRequest}, + { + URI: newFakeKeycloakConfig().WithOAuthURI(logoutURL), + ExpectedCode: http.StatusUnauthorized, + }, } newFakeProxy(nil).RunTests(t, requests) } @@ -148,18 +151,18 @@ func TestLogoutHandlerBadToken(t *testing.T) { requests := []fakeRequest{ { URI: c.WithOAuthURI(logoutURL), - ExpectedCode: http.StatusBadRequest, + ExpectedCode: http.StatusUnauthorized, }, { URI: c.WithOAuthURI(logoutURL), HasCookieToken: true, RawToken: "this.is.a.bad.token", - ExpectedCode: http.StatusBadRequest, + ExpectedCode: http.StatusUnauthorized, }, { URI: c.WithOAuthURI(logoutURL), RawToken: "this.is.a.bad.token", - ExpectedCode: http.StatusBadRequest, + ExpectedCode: http.StatusUnauthorized, }, } newFakeProxy(nil).RunTests(t, requests) @@ -185,20 +188,22 @@ func TestLogoutHandlerGood(t *testing.T) { func TestTokenHandler(t *testing.T) { uri := newFakeKeycloakConfig().WithOAuthURI(tokenURL) + goodToken := newTestToken("example").getToken() requests := []fakeRequest{ { URI: uri, HasToken: true, + RawToken: (&goodToken).Encode(), ExpectedCode: http.StatusOK, }, { URI: uri, - ExpectedCode: http.StatusBadRequest, + ExpectedCode: http.StatusUnauthorized, }, { URI: uri, RawToken: "niothing", - ExpectedCode: http.StatusBadRequest, + ExpectedCode: http.StatusUnauthorized, }, { URI: uri, diff --git a/middleware.go b/middleware.go index 74be12c7a0394aa9659448315158c9bf5f7a825d..7adcde95d9930c465d62b44227847dc6ab87d768 100644 --- a/middleware.go +++ b/middleware.go @@ -98,7 +98,7 @@ func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { } // authenticationMiddleware is responsible for verifying the access token -func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Handler) http.Handler { +func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { clientIP := req.RemoteAddr diff --git a/server.go b/server.go index b169579ff0b5c4074c9204bc1eb6afb9e4125628..fa6d79ebccfb69d1348d5b176ad4e5c1e449180d 100644 --- a/server.go +++ b/server.go @@ -204,8 +204,8 @@ func (r *oauthProxy) createReverseProxy() error { e.Get(callbackURL, r.oauthCallbackHandler) e.Get(expiredURL, r.expirationHandler) e.Get(healthURL, r.healthHandler) - e.Get(logoutURL, r.logoutHandler) - e.Get(tokenURL, r.tokenHandler) + e.With(r.authenticationMiddleware()).Get(logoutURL, r.logoutHandler) + e.With(r.authenticationMiddleware()).Get(tokenURL, r.tokenHandler) e.Post(loginURL, r.loginHandler) if r.config.EnableMetrics { r.log.Info("enabled the service metrics middleware", zap.String("path", r.config.WithOAuthURI(metricsURL))) @@ -260,7 +260,7 @@ func (r *oauthProxy) createReverseProxy() error { for _, x := range r.config.Resources { r.log.Info("protecting resource", zap.String("resource", x.String())) e := engine.With( - r.authenticationMiddleware(x), + r.authenticationMiddleware(), r.admissionMiddleware(x), r.identityHeadersMiddleware(r.config.AddClaims))