From 0a55c3d7e59f58df3178a049ff65f1ef90bf042e Mon Sep 17 00:00:00 2001 From: Stephane Tang <hi@stang.sh> Date: Fri, 20 Jul 2018 17:58:10 +0100 Subject: [PATCH] allow multiple audiences Signed-off-by: Stephane Tang <hi@stang.sh> --- doc.go | 2 +- middleware.go | 2 +- user_context.go | 29 ++++++++++++++++++++++++----- user_context_test.go | 5 ++++- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/doc.go b/doc.go index cb939d3..9c64ace 100644 --- a/doc.go +++ b/doc.go @@ -387,7 +387,7 @@ type userContext struct { // the id of the user id string // the audience for the token - audience string + audiences []string // whether the context is from a session cookie or authorization header bearerToken bool // the claims associated to the token diff --git a/middleware.go b/middleware.go index d361258..8d77eac 100644 --- a/middleware.go +++ b/middleware.go @@ -373,7 +373,7 @@ func (r *oauthProxy) identityHeadersMiddleware(custom []string) func(http.Handle scope := req.Context().Value(contextScopeName).(*RequestScope) if scope.Identity != nil { user := scope.Identity - req.Header.Set("X-Auth-Audience", user.audience) + req.Header.Set("X-Auth-Audience", strings.Join(user.audiences, ",")) req.Header.Set("X-Auth-Email", user.email) req.Header.Set("X-Auth-ExpiresIn", user.expiresAt.String()) req.Header.Set("X-Auth-Groups", strings.Join(user.groups, ",")) diff --git a/user_context.go b/user_context.go index 6ff75e0..4ec8dd3 100644 --- a/user_context.go +++ b/user_context.go @@ -40,9 +40,17 @@ func extractIdentity(token jose.JWT) (*userContext, error) { if err != nil || !found { preferredName = identity.Email } - audience, found, err := claims.StringClaim(claimAudience) - if err != nil || !found { - return nil, ErrNoTokenAudience + + var audiences []string + aud, found, err := claims.StringClaim(claimAudience) + if err == nil && found { + audiences = append(audiences, aud) + } else { + if aud, found, err := claims.StringsClaim(claimAudience); err != nil || !found { + return nil, ErrNoTokenAudience + } else { + audiences = aud + } } // @step: extract the realm roles @@ -74,7 +82,7 @@ func extractIdentity(token jose.JWT) (*userContext, error) { } return &userContext{ - audience: audience, + audiences: audiences, claims: claims, email: identity.Email, expiresAt: identity.ExpiresAt, @@ -87,9 +95,20 @@ func extractIdentity(token jose.JWT) (*userContext, error) { }, nil } +// backported from https://github.com/gambol99/go-oidc/blob/master/oidc/verification.go#L28-L37 +// I'll raise another PR to make it public in the go-oidc package so we can just use `oidc.ContainsString()` +func containsString(needle string, haystack []string) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + return false +} + // isAudience checks the audience func (r *userContext) isAudience(aud string) bool { - return r.audience == aud + return containsString(aud, r.audiences) } // getRoles returns a list of roles diff --git a/user_context_test.go b/user_context_test.go index b28c45f..f85103c 100644 --- a/user_context_test.go +++ b/user_context_test.go @@ -24,7 +24,7 @@ import ( func TestIsAudience(t *testing.T) { user := &userContext{ - audience: "test", + audiences: []string{"test", "test2"}, } if !user.isAudience("test") { t.Error("return should not have been false") @@ -32,6 +32,9 @@ func TestIsAudience(t *testing.T) { if user.isAudience("test1") { t.Error("return should not have been true") } + if !user.isAudience("test2") { + t.Error("return should not have been false") + } } func TestGetUserRoles(t *testing.T) { -- GitLab