diff --git a/doc.go b/doc.go index cb939d3f0461a1af4075a502e334b6ede32f32e2..9c64ace45d968762c10231cdfba55fbac4d8bd2a 100644 --- a/doc.go +++ b/doc.go @@ -387,7 +387,7 @@ type userContext struct { // the id of the user id string // the audience for the token - audience string + audiences []string // whether the context is from a session cookie or authorization header bearerToken bool // the claims associated to the token diff --git a/middleware.go b/middleware.go index d3612586fbd40c1613b46c4b102b9a97d200d150..8d77eacfbbe9cc1ff3476a937cd6b41c84689216 100644 --- a/middleware.go +++ b/middleware.go @@ -373,7 +373,7 @@ func (r *oauthProxy) identityHeadersMiddleware(custom []string) func(http.Handle scope := req.Context().Value(contextScopeName).(*RequestScope) if scope.Identity != nil { user := scope.Identity - req.Header.Set("X-Auth-Audience", user.audience) + req.Header.Set("X-Auth-Audience", strings.Join(user.audiences, ",")) req.Header.Set("X-Auth-Email", user.email) req.Header.Set("X-Auth-ExpiresIn", user.expiresAt.String()) req.Header.Set("X-Auth-Groups", strings.Join(user.groups, ",")) diff --git a/user_context.go b/user_context.go index 6ff75e089612561acca2153b82746dda83e6c797..4ec8dd3b848241ad647859fa5ecae245ab15ddad 100644 --- a/user_context.go +++ b/user_context.go @@ -40,9 +40,17 @@ func extractIdentity(token jose.JWT) (*userContext, error) { if err != nil || !found { preferredName = identity.Email } - audience, found, err := claims.StringClaim(claimAudience) - if err != nil || !found { - return nil, ErrNoTokenAudience + + var audiences []string + aud, found, err := claims.StringClaim(claimAudience) + if err == nil && found { + audiences = append(audiences, aud) + } else { + if aud, found, err := claims.StringsClaim(claimAudience); err != nil || !found { + return nil, ErrNoTokenAudience + } else { + audiences = aud + } } // @step: extract the realm roles @@ -74,7 +82,7 @@ func extractIdentity(token jose.JWT) (*userContext, error) { } return &userContext{ - audience: audience, + audiences: audiences, claims: claims, email: identity.Email, expiresAt: identity.ExpiresAt, @@ -87,9 +95,20 @@ func extractIdentity(token jose.JWT) (*userContext, error) { }, nil } +// backported from https://github.com/gambol99/go-oidc/blob/master/oidc/verification.go#L28-L37 +// I'll raise another PR to make it public in the go-oidc package so we can just use `oidc.ContainsString()` +func containsString(needle string, haystack []string) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + return false +} + // isAudience checks the audience func (r *userContext) isAudience(aud string) bool { - return r.audience == aud + return containsString(aud, r.audiences) } // getRoles returns a list of roles diff --git a/user_context_test.go b/user_context_test.go index b28c45f50791be86c90b2d3a65ae5fd098fd5402..f85103c1c5afffc9d6944f70bdc519a31164e513 100644 --- a/user_context_test.go +++ b/user_context_test.go @@ -24,7 +24,7 @@ import ( func TestIsAudience(t *testing.T) { user := &userContext{ - audience: "test", + audiences: []string{"test", "test2"}, } if !user.isAudience("test") { t.Error("return should not have been false") @@ -32,6 +32,9 @@ func TestIsAudience(t *testing.T) { if user.isAudience("test1") { t.Error("return should not have been true") } + if !user.isAudience("test2") { + t.Error("return should not have been false") + } } func TestGetUserRoles(t *testing.T) {