diff --git a/cookies.go b/cookies.go index 74fdf721b3fc584574c1de1ebba3233de8d1ebb0..d1db5d4d9d29100eb1f5565d16d3e720131235f3 100644 --- a/cookies.go +++ b/cookies.go @@ -16,6 +16,7 @@ limitations under the License. package main import ( + "encoding/base64" "net/http" "strconv" "strings" @@ -89,6 +90,8 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr // dropStateParameterCookie drops a state parameter cookie into the response func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string { uuid := uuid.NewV4().String() + requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI())) + r.dropCookie(w, req.Host, "request_uri", requestURI, 0) r.dropCookie(w, req.Host, "OAuth_Token_Request_State", uuid, 0) return uuid } diff --git a/handlers.go b/handlers.go index 99541d9c7997e070b55ec738cc86c8c8eafdcd83..b8761e632597ee288db78005669135f3fea6dc59 100644 --- a/handlers.go +++ b/handlers.go @@ -201,24 +201,20 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque r.dropAccessTokenCookie(req, w, accessToken, time.Until(identity.ExpiresAt)) } - // step: decode the state variable - state := "/" + // step: decode the request variable + redirectURI := "/" if req.URL.Query().Get("state") != "" { - decoded, err := base64.StdEncoding.DecodeString(req.URL.Query().Get("state")) - if err != nil { - r.log.Warn("unable to decode the state parameter", - zap.String("state", req.URL.Query().Get("state")), - zap.Error(err)) - } else { - state = string(decoded) + if encodedRequestURI, _ := req.Cookie("request_uri"); encodedRequestURI != nil { + decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value) + redirectURI = string(decoded) } } if r.config.BaseURI != "" { // assuming state starts with slash - state = r.config.BaseURI + state + redirectURI = r.config.BaseURI + redirectURI } - r.redirectToURL(state, w, req, http.StatusTemporaryRedirect) + r.redirectToURL(redirectURI, w, req, http.StatusTemporaryRedirect) } // loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider diff --git a/handlers_test.go b/handlers_test.go index a796a061b7b01ae178a67a878d98d1173e872bff..1d827fe21032b76cbe0c72d278c2c290fe8f83da 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -304,7 +304,7 @@ func TestCallbackURL(t *testing.T) { { URI: cfg.WithOAuthURI(callbackURL) + "?code=fake&state=L2FkbWlu", ExpectedCookies: map[string]string{cfg.CookieAccessName: ""}, - ExpectedLocation: "/admin", + ExpectedLocation: "/", ExpectedCode: http.StatusTemporaryRedirect, }, }