diff --git a/forwarding.go b/forwarding.go index 92b61b3345e95b7c5ab093abff15c296b986c325..4c25eda8aebc72a2c90a76af6b20a6b197403a34 100644 --- a/forwarding.go +++ b/forwarding.go @@ -150,7 +150,7 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { zap.String("expires", state.expiration.Format(time.RFC3339))) // step: attempt to refresh the access - token, expiration, err := getRefreshedToken(r.client, state.refresh) + token, newRefreshToken, expiration, _, err := getRefreshedToken(r.client, state.refresh) if err != nil { state.login = true switch err { @@ -169,6 +169,9 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { state.expiration = expiration state.wait = true state.login = false + if newRefreshToken != "" { + state.refresh = newRefreshToken + } // step: add some debugging r.log.Info("successfully refreshed the access token", @@ -193,7 +196,7 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { duration := getWithin(state.expiration, 0.85) r.log.Info("waiting for expiration of access token", zap.String("token_expiration", state.expiration.Format(time.RFC3339)), - zap.String("renewel_duration", duration.String())) + zap.String("renewal_duration", duration.String())) <-time.After(duration) } diff --git a/handlers.go b/handlers.go index 64a5c55754793f69eb74d6ef9b8f8a422b09f909..bc7989a0417e7475bc094858a9bfc22f7d4419d6 100644 --- a/handlers.go +++ b/handlers.go @@ -489,7 +489,7 @@ func (r *oauthProxy) proxyMetricsHandler(w http.ResponseWriter, req *http.Reques } // retrieveRefreshToken retrieves the refresh token from store or cookie -func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (token, ecrypted string, err error) { +func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (token, encrypted string, err error) { switch r.useStore() { case true: token, err = r.GetRefreshToken(user.token) @@ -500,7 +500,7 @@ func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) return } - ecrypted = token // returns encryped, avoid encoding twice + encrypted = token // returns encrypted, avoids encoding twice token, err = decodeText(token, r.config.EncryptionKey) return } diff --git a/middleware.go b/middleware.go index f1ea568aaf192ecfe8184bd59a6128c91bc9334d..fff6864e93f092bcedc3f663bf508833f6796d25 100644 --- a/middleware.go +++ b/middleware.go @@ -167,8 +167,15 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand return } - // attempt to refresh the access token - token, exp, err := getRefreshedToken(r.client, refresh) + // attempt to refresh the access token, possibly with a renewed refresh token + // + // NOTE: atm, this does not retrieve explicit refresh token expiry from oauth2, + // and take identity expiry instead: with keycloak, they are the same and equal to + // "SSO session idle" keycloak setting. + // + // exp: expiration of the access token + // expiresIn: expiration of the ID token + token, newRefreshToken, accessExpiresAt, refreshExpiresIn, err := getRefreshedToken(r.client, refresh) if err != nil { switch err { case ErrRefreshTokenExpired: @@ -184,14 +191,24 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand return } - // get the expiration of the new access token - expiresIn := r.getAccessCookieExpiration(token, refresh) + + accessExpiresIn := time.Until(accessExpiresAt) + + // get the expiration of the new refresh token + if newRefreshToken != "" { + refresh = newRefreshToken + } + if refreshExpiresIn == 0 { + // refresh token expiry claims not available: try to parse refresh token + refreshExpiresIn = r.getAccessCookieExpiration(token, refresh) + } r.log.Info("injecting the refreshed access token cookie", zap.String("client_ip", clientIP), zap.String("cookie_name", r.config.CookieAccessName), zap.String("email", user.email), - zap.Duration("expires_in", time.Until(exp))) + zap.Duration("refresh_expires_in", refreshExpiresIn), + zap.Duration("expires_in", accessExpiresIn)) accessToken := token.Encode() if r.config.EnableEncryptedToken || r.config.ForceEncryptedCookie { @@ -202,7 +219,20 @@ func (r *oauthProxy) authenticationMiddleware(resource *Resource) func(http.Hand } } // step: inject the refreshed access token - r.dropAccessTokenCookie(req.WithContext(ctx), w, accessToken, expiresIn) + r.dropAccessTokenCookie(req.WithContext(ctx), w, accessToken, accessExpiresIn) + + // step: inject the renewed refresh token + if newRefreshToken != "" { + r.log.Debug("renew refresh cookie with new refresh token", + zap.Duration("refresh_expires_in", refreshExpiresIn)) + encryptedRefreshToken, err := encodeText(newRefreshToken, r.config.EncryptionKey) + if err != nil { + r.log.Error("failed to encrypt the refresh token", zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + r.dropRefreshTokenCookie(req.WithContext(ctx), w, encryptedRefreshToken, refreshExpiresIn) + } if r.useStore() { go func(old, new jose.JWT, encrypted string) { diff --git a/misc.go b/misc.go index 98c249493ca6bde600d5b8dedd04081a264f4ab7..3f3f369fd11c23b2bbda00bf741eb15b56d57050 100644 --- a/misc.go +++ b/misc.go @@ -115,10 +115,10 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re return r.revokeProxy(w, req) } -// getAccessCookieExpiration calucates the expiration of the access token cookie +// getAccessCookieExpiration calculates the expiration of the access token cookie func (r *oauthProxy) getAccessCookieExpiration(token jose.JWT, refresh string) time.Duration { // notes: by default the duration of the access token will be the configuration option, if - // however we can decode the refresh token, we will set the duration to the duraction of the + // however we can decode the refresh token, we will set the duration to the duration of the // refresh token duration := r.config.AccessTokenDuration if _, ident, err := parseToken(refresh); err == nil { @@ -126,6 +126,9 @@ func (r *oauthProxy) getAccessCookieExpiration(token jose.JWT, refresh string) t if delta > 0 { duration = delta } + r.log.Debug("parsed refresh token with new duration", zap.Duration("new duration", delta)) + } else { + r.log.Debug("refresh token is opaque and cannot be used to extend calculated duration") } return duration diff --git a/oauth.go b/oauth.go index 5e1a53b651ff14bba2473b2a20f8451f049cd23c..87ab8ced7a70e63fc81d4d7d58318a6e91a773c2 100644 --- a/oauth.go +++ b/oauth.go @@ -56,26 +56,44 @@ func verifyToken(client *oidc.Client, token jose.JWT) error { return nil } -// getRefreshedToken attempts to refresh the access token, returning the parsed token and the time it expires or a error -func getRefreshedToken(client *oidc.Client, t string) (jose.JWT, time.Time, error) { +// getRefreshedToken attempts to refresh the access token, returning the parsed token, optionally with a renewed +// refresh token and the time the access and refresh tokens expire +// +// NOTE: we may be able to extract the specific (non-standard) claim refresh_expires_in and refresh_expires +// from response.RawBody. +// When not available, keycloak provides us with the same (for now) expiry value for ID token. +func getRefreshedToken(client *oidc.Client, t string) (jose.JWT, string, time.Time, time.Duration, error) { cl, err := client.OAuthClient() if err != nil { - return jose.JWT{}, time.Time{}, err + return jose.JWT{}, "", time.Time{}, time.Duration(0), err } response, err := getToken(cl, oauth2.GrantTypeRefreshToken, t) if err != nil { - if strings.Contains(err.Error(), "token expired") { - return jose.JWT{}, time.Time{}, ErrRefreshTokenExpired + if strings.Contains(err.Error(), "refresh token has expired") { + return jose.JWT{}, "", time.Time{}, time.Duration(0), ErrRefreshTokenExpired } - return jose.JWT{}, time.Time{}, err + return jose.JWT{}, "", time.Time{}, time.Duration(0), err } + // extracts non-standard claims about refresh token, to get refresh token expiry + var ( + refreshExpiresIn time.Duration + extraClaims struct { + RefreshExpiresIn json.Number `json:"refresh_expires_in"` + } + ) + _ = json.Unmarshal(response.RawBody, &extraClaims) + if extraClaims.RefreshExpiresIn != "" { + if asInt, erj := extraClaims.RefreshExpiresIn.Int64(); erj == nil { + refreshExpiresIn = time.Duration(asInt) * time.Second + } + } token, identity, err := parseToken(response.AccessToken) if err != nil { - return jose.JWT{}, time.Time{}, err + return jose.JWT{}, "", time.Time{}, time.Duration(0), err } - return token, identity.ExpiresAt, nil + return token, response.RefreshToken, identity.ExpiresAt, refreshExpiresIn, nil } // exchangeAuthenticationCode exchanges the authentication code with the oauth server for a access token @@ -130,7 +148,7 @@ func getToken(client *oauth2.Client, grantType, code string) (oauth2.TokenRespon return token, err } -// parseToken retrieve the user identity from the token +// parseToken retrieves the user identity from the token func parseToken(t string) (jose.JWT, *oidc.Identity, error) { token, err := jose.ParseJWT(t) if err != nil {