From 0a55c3d7e59f58df3178a049ff65f1ef90bf042e Mon Sep 17 00:00:00 2001
From: Stephane Tang <hi@stang.sh>
Date: Fri, 20 Jul 2018 17:58:10 +0100
Subject: [PATCH] allow multiple audiences

Signed-off-by: Stephane Tang <hi@stang.sh>
---
 doc.go               |  2 +-
 middleware.go        |  2 +-
 user_context.go      | 29 ++++++++++++++++++++++++-----
 user_context_test.go |  5 ++++-
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/doc.go b/doc.go
index cb939d3..9c64ace 100644
--- a/doc.go
+++ b/doc.go
@@ -387,7 +387,7 @@ type userContext struct {
 	// the id of the user
 	id string
 	// the audience for the token
-	audience string
+	audiences []string
 	// whether the context is from a session cookie or authorization header
 	bearerToken bool
 	// the claims associated to the token
diff --git a/middleware.go b/middleware.go
index d361258..8d77eac 100644
--- a/middleware.go
+++ b/middleware.go
@@ -373,7 +373,7 @@ func (r *oauthProxy) identityHeadersMiddleware(custom []string) func(http.Handle
 			scope := req.Context().Value(contextScopeName).(*RequestScope)
 			if scope.Identity != nil {
 				user := scope.Identity
-				req.Header.Set("X-Auth-Audience", user.audience)
+				req.Header.Set("X-Auth-Audience", strings.Join(user.audiences, ","))
 				req.Header.Set("X-Auth-Email", user.email)
 				req.Header.Set("X-Auth-ExpiresIn", user.expiresAt.String())
 				req.Header.Set("X-Auth-Groups", strings.Join(user.groups, ","))
diff --git a/user_context.go b/user_context.go
index 6ff75e0..4ec8dd3 100644
--- a/user_context.go
+++ b/user_context.go
@@ -40,9 +40,17 @@ func extractIdentity(token jose.JWT) (*userContext, error) {
 	if err != nil || !found {
 		preferredName = identity.Email
 	}
-	audience, found, err := claims.StringClaim(claimAudience)
-	if err != nil || !found {
-		return nil, ErrNoTokenAudience
+
+	var audiences []string
+	aud, found, err := claims.StringClaim(claimAudience)
+	if err == nil && found {
+		audiences = append(audiences, aud)
+	} else {
+		if aud, found, err := claims.StringsClaim(claimAudience); err != nil || !found {
+			return nil, ErrNoTokenAudience
+		} else {
+			audiences = aud
+		}
 	}
 
 	// @step: extract the realm roles
@@ -74,7 +82,7 @@ func extractIdentity(token jose.JWT) (*userContext, error) {
 	}
 
 	return &userContext{
-		audience:      audience,
+		audiences:     audiences,
 		claims:        claims,
 		email:         identity.Email,
 		expiresAt:     identity.ExpiresAt,
@@ -87,9 +95,20 @@ func extractIdentity(token jose.JWT) (*userContext, error) {
 	}, nil
 }
 
+// backported from https://github.com/gambol99/go-oidc/blob/master/oidc/verification.go#L28-L37
+// I'll raise another PR to make it public in the go-oidc package so we can just use `oidc.ContainsString()`
+func containsString(needle string, haystack []string) bool {
+	for _, v := range haystack {
+		if v == needle {
+			return true
+		}
+	}
+	return false
+}
+
 // isAudience checks the audience
 func (r *userContext) isAudience(aud string) bool {
-	return r.audience == aud
+	return containsString(aud, r.audiences)
 }
 
 // getRoles returns a list of roles
diff --git a/user_context_test.go b/user_context_test.go
index b28c45f..f85103c 100644
--- a/user_context_test.go
+++ b/user_context_test.go
@@ -24,7 +24,7 @@ import (
 
 func TestIsAudience(t *testing.T) {
 	user := &userContext{
-		audience: "test",
+		audiences: []string{"test", "test2"},
 	}
 	if !user.isAudience("test") {
 		t.Error("return should not have been false")
@@ -32,6 +32,9 @@ func TestIsAudience(t *testing.T) {
 	if user.isAudience("test1") {
 		t.Error("return should not have been true")
 	}
+	if !user.isAudience("test2") {
+		t.Error("return should not have been false")
+	}
 }
 
 func TestGetUserRoles(t *testing.T) {
-- 
GitLab