From cdfb2a93085c65f7ad6781bc11136a018347305d Mon Sep 17 00:00:00 2001
From: Philippe Gagnon <philgagnon1@gmail.com>
Date: Sat, 16 Feb 2019 18:15:03 -0500
Subject: [PATCH] Add base URI to OAuth routes and set cookies path to Base URI

---
 cookies.go         |  6 +++++-
 cookies_test.go    | 39 +++++++++++++++++++++++++++++++++++++++
 handlers.go        |  4 ----
 middleware_test.go | 38 ++++++++++++++++++++++++++++++++++++++
 server.go          |  2 +-
 5 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/cookies.go b/cookies.go
index 8bc82a4..64dddc4 100644
--- a/cookies.go
+++ b/cookies.go
@@ -32,11 +32,15 @@ func (r *oauthProxy) dropCookie(w http.ResponseWriter, host, name, value string,
 	if r.config.CookieDomain != "" {
 		domain = r.config.CookieDomain
 	}
+	path := r.config.BaseURI
+	if path == "" {
+		path = "/"
+	}
 	cookie := &http.Cookie{
 		Domain:   domain,
 		HttpOnly: r.config.HTTPOnlyCookie,
 		Name:     name,
-		Path:     "/",
+		Path:     path,
 		Secure:   r.config.SecureCookie,
 		Value:    value,
 	}
diff --git a/cookies_test.go b/cookies_test.go
index 23d4e67..6fb5f0a 100644
--- a/cookies_test.go
+++ b/cookies_test.go
@@ -40,6 +40,45 @@ func TestCookieDomainHostHeader(t *testing.T) {
 	assert.Equal(t, cookie.Domain, "127.0.0.1")
 }
 
+func TestCookieBasePath(t *testing.T) {
+	cfg := newFakeKeycloakConfig()
+	cfg.BaseURI = "/base-uri"
+
+	_, _, svc := newTestProxyService(cfg)
+
+	resp, err := makeTestCodeFlowLogin(svc + "/admin")
+	assert.NoError(t, err)
+	assert.NotNil(t, resp)
+
+	var cookie *http.Cookie
+	for _, c := range resp.Cookies() {
+		if c.Name == "kc-access" {
+			cookie = c
+		}
+	}
+	assert.NotNil(t, cookie)
+	assert.Equal(t, "/base-uri", cookie.Path)
+}
+
+func TestCookieWithoutBasePath(t *testing.T) {
+	cfg := newFakeKeycloakConfig()
+
+	_, _, svc := newTestProxyService(cfg)
+
+	resp, err := makeTestCodeFlowLogin(svc + "/admin")
+	assert.NoError(t, err)
+	assert.NotNil(t, resp)
+
+	var cookie *http.Cookie
+	for _, c := range resp.Cookies() {
+		if c.Name == "kc-access" {
+			cookie = c
+		}
+	}
+	assert.NotNil(t, cookie)
+	assert.Equal(t, "/", cookie.Path)
+}
+
 func TestCookieDomain(t *testing.T) {
 	p, _, svc := newTestProxyService(nil)
 	p.config.CookieDomain = "domain.com"
diff --git a/handlers.go b/handlers.go
index 5245900..0c5520e 100644
--- a/handlers.go
+++ b/handlers.go
@@ -210,10 +210,6 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque
 			redirectURI = string(decoded)
 		}
 	}
-	if r.config.BaseURI != "" {
-		// assuming state starts with slash
-		redirectURI = r.config.BaseURI + redirectURI
-	}
 
 	r.redirectToURL(redirectURI, w, req, http.StatusTemporaryRedirect)
 }
diff --git a/middleware_test.go b/middleware_test.go
index 7cf62e9..9d7e1a9 100644
--- a/middleware_test.go
+++ b/middleware_test.go
@@ -354,6 +354,44 @@ func TestOauthRequests(t *testing.T) {
 	newFakeProxy(cfg).RunTests(t, requests)
 }
 
+func TestOauthRequestsWithBaseURI(t *testing.T) {
+	cfg := newFakeKeycloakConfig()
+	cfg.BaseURI = "/base-uri"
+	requests := []fakeRequest{
+		{
+			URI:          "/base-uri/oauth/authorize",
+			Redirects:    true,
+			ExpectedCode: http.StatusTemporaryRedirect,
+		},
+		{
+			URI:          "/base-uri/oauth/callback",
+			Redirects:    true,
+			ExpectedCode: http.StatusBadRequest,
+		},
+		{
+			URI:          "/base-uri/oauth/health",
+			Redirects:    true,
+			ExpectedCode: http.StatusOK,
+		},
+		{
+			URI:           "/oauth/authorize",
+			ExpectedProxy: true,
+			ExpectedCode:  http.StatusOK,
+		},
+		{
+			URI:           "/oauth/callback",
+			ExpectedProxy: true,
+			ExpectedCode:  http.StatusOK,
+		},
+		{
+			URI:           "/oauth/health",
+			ExpectedProxy: true,
+			ExpectedCode:  http.StatusOK,
+		},
+	}
+	newFakeProxy(cfg).RunTests(t, requests)
+}
+
 func TestMethodExclusions(t *testing.T) {
 	cfg := newFakeKeycloakConfig()
 	cfg.Resources = []*Resource{
diff --git a/server.go b/server.go
index fc308e6..02ad29c 100644
--- a/server.go
+++ b/server.go
@@ -197,7 +197,7 @@ func (r *oauthProxy) createReverseProxy() error {
 	}
 
 	// step: add the routing for oauth
-	engine.With(proxyDenyMiddleware).Route(r.config.OAuthURI, func(e chi.Router) {
+	engine.With(proxyDenyMiddleware).Route(r.config.BaseURI+r.config.OAuthURI, func(e chi.Router) {
 		e.MethodNotAllowed(methodNotAllowHandlder)
 		e.HandleFunc(authorizationURL, r.oauthAuthorizationHandler)
 		e.Get(callbackURL, r.oauthCallbackHandler)
-- 
GitLab