diff --git a/README.md b/README.md index 37f950bd08b09fee5888ca9f380657aa6f60fcb6..29c409893cd11615490a5acc59492c62b5d98488 100644 --- a/README.md +++ b/README.md @@ -274,6 +274,37 @@ Refresh tokens are either be stored as an encrypted cookie or placed (encrypted) Assuming access response responds with a refresh token and the --enable-refresh-token is true, the proxy will automatically refresh the access token for you. The tokens themselves are kept either as an encrypted (--encryption-key=KEY) cookie (cookie name: kc-state). Alternatively you can place the refresh token (still requires encryption key) in a local boltdb file or shared redis. Naturally the encryption key has to be the same on all instances and boltdb is for single instance only developments. +#### **Cross Origin Resource Sharing (CORS)** + +You are permitted to add CORS following headers into the /oauth uri namespace + + * Access-Control-Allow-Origin + * Access-Control-Allow-Methods + * Access-Control-Allow-Headers + * Access-Control-Expose-Headers + * Access-Control-Allow-Credentials + * Access-Control-Max-Age + +Either from the config file: + +```YAML +cors: + origins: + - '*' + methods: + - GET + - POST +``` + +or via the command line arguments + +```shell +--cors-origins [--cors-origins option] a set of origins to add to the CORS access control (Access-Control-Allow-Origin) +--cors-methods [--cors-methods option] the method permitted in the access control (Access-Control-Allow-Methods) +--cors-headers [--cors-headers option] a set of headers to add to the CORS access control (Access-Control-Allow-Headers) +--cors-exposes-headers [--cors-exposes-headers option] set the expose cors headers access control (Access-Control-Expose-Headers) +``` + #### **Endpoints** * **/oauth/authorize** is authentication endpoint which will generate the openid redirect to the provider diff --git a/doc.go b/doc.go index 7b1daf3926cd55e90af6323e96006e72c47d1644..c8eca14a8bd5d7d7e654de8e84147813b3f0a00c 100644 --- a/doc.go +++ b/doc.go @@ -34,13 +34,13 @@ const ( authorizationHeader = "Authorization" oauthURL = "/oauth" - authorizationURL = oauthURL + "/authorize" - callbackURL = oauthURL + "/callback" - healthURL = oauthURL + "/health" - tokenURL = oauthURL + "/token" - expiredURL = oauthURL + "/expired" - logoutURL = oauthURL + "/logout" - loginURL = oauthURL + "/login" + authorizationURL = "/authorize" + callbackURL = "/callback" + healthURL = "/health" + tokenURL = "/token" + expiredURL = "/expired" + logoutURL = "/logout" + loginURL = "/login" claimPreferredName = "preferred_username" claimAudience = "aud" diff --git a/handlers_utils.go b/handlers_utils.go index f360237ea08938de1841a1b7d61debdcea2a467c..d62e37f4b618b69fbfde4565606b76008397c4b9 100644 --- a/handlers_utils.go +++ b/handlers_utils.go @@ -72,6 +72,33 @@ func (r *oauthProxy) securityHandler() gin.HandlerFunc { } } +// +// crossSiteHandler injects the CORS headers, if set, for request made to /oauth +// +func (r *oauthProxy) crossSiteHandler() gin.HandlerFunc { + return func(cx *gin.Context) { + c := r.config.CORS + if len(c.Origins) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ",")) + } + if len(c.Methods) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ",")) + } + if len(c.Headers) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ",")) + } + if len(c.ExposedHeaders) > 0 { + cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ",")) + } + if c.Credentials { + cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + } + if c.MaxAge > 0 { + cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds()))) + } + } +} + // // proxyHandler is responsible to proxy the requests on to the upstream endpoint // diff --git a/handlers_utils_test.go b/handlers_utils_test.go index 8fef188fd2a984720f14a67a428800c1dece650b..10abbb3b07065259e642d2fba3448cbe1afced8e 100644 --- a/handlers_utils_test.go +++ b/handlers_utils_test.go @@ -82,6 +82,54 @@ func TestExpirationHandler(t *testing.T) { } } +func TestCrossSiteHandler(t *testing.T) { + kc := newFakeKeycloakProxy(t) + handler := kc.crossSiteHandler() + + cases := []struct { + Cors *CORS + Headers map[string]string + }{ + { + Cors: &CORS{ + Origins: []string{"*"}, + }, + Headers: map[string]string{ + "Access-Control-Allow-Origin": "*", + }, + }, + { + Cors: &CORS{ + Origins: []string{"*", "https://examples.com"}, + Methods: []string{"GET"}, + }, + Headers: map[string]string{ + "Access-Control-Allow-Origin": "*,https://examples.com", + "Access-Control-Allow-Methods": "GET", + }, + }, + } + + for i, c := range cases { + // step: get the config + kc.config.CORS = c.Cors + // call the handler and check the responses + context := newFakeGinContext("GET", "/oauth/test") + handler(context) + // step: check the headers + for k, v := range c.Headers { + value := context.Writer.Header().Get(k) + if value == "" { + t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header()) + continue + } + if value != v { + t.Errorf("case %d, expected: %s but got %s", i, k, value) + } + } + } +} + func TestSecurityHandler(t *testing.T) { kc := newFakeKeycloakProxy(t) handler := kc.securityHandler() diff --git a/server.go b/server.go index 0141ad14332e08a8e66b9bbcec64d4692250d15a..9395dd901623aba2c43099c45d1f3dfacdd63f46 100644 --- a/server.go +++ b/server.go @@ -166,7 +166,9 @@ func (r *oauthProxy) Run() error { err = server.ListenAndServeTLS(r.config.TLSCertificate, r.config.TLSPrivateKey) } if err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Fatalf("failed to start the service") + log.WithFields(log.Fields{ + "error": err.Error(), + }).Fatalf("failed to start the service") } }() @@ -177,9 +179,6 @@ func (r *oauthProxy) Run() error { // redirectToURL redirects the user and aborts the context // func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) { - // step: add the cors headers - r.injectCORSHeaders(cx) - cx.Redirect(http.StatusTemporaryRedirect, url) cx.Abort() } @@ -217,32 +216,7 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { return } - r.redirectToURL(authorizationURL+authQuery, cx) -} - -// -// injectCORSHeaders adds the cors access controls to the oauth responses -// -func (r *oauthProxy) injectCORSHeaders(cx *gin.Context) { - c := r.config.CORS - if len(c.Origins) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ",")) - } - if len(c.Methods) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ",")) - } - if len(c.Headers) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ",")) - } - if len(c.ExposedHeaders) > 0 { - cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ",")) - } - if c.Credentials { - cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - } - if c.MaxAge > 0 { - cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds()))) - } + r.redirectToURL(oauthURL+authorizationURL+authQuery, cx) } // @@ -279,13 +253,16 @@ func (r oauthProxy) setupRouter() error { r.router.Use(r.securityHandler()) } // step: add the routing - r.router.GET(authorizationURL, r.oauthAuthorizationHandler) - r.router.GET(callbackURL, r.oauthCallbackHandler) - r.router.GET(healthURL, r.healthHandler) - r.router.GET(tokenURL, r.tokenHandler) - r.router.GET(expiredURL, r.expirationHandler) - r.router.GET(logoutURL, r.logoutHandler) - r.router.POST(loginURL, r.loginHandler) + oauth := r.router.Group(oauthURL).Use(r.crossSiteHandler()) + { + oauth.GET(authorizationURL, r.oauthAuthorizationHandler) + oauth.GET(callbackURL, r.oauthCallbackHandler) + oauth.GET(healthURL, r.healthHandler) + oauth.GET(tokenURL, r.tokenHandler) + oauth.GET(expiredURL, r.expirationHandler) + oauth.GET(logoutURL, r.logoutHandler) + oauth.POST(loginURL, r.loginHandler) + } r.router.Use(r.entryPointHandler(), r.authenticationHandler(), r.admissionHandler()) @@ -302,6 +279,7 @@ func (r *oauthProxy) setupTemplates() error { log.Debugf("loading the custom sign in page: %s", r.config.SignInPage) list = append(list, r.config.SignInPage) } + if r.config.ForbiddenPage != "" { log.Debugf("loading the custom sign forbidden page: %s", r.config.ForbiddenPage) list = append(list, r.config.ForbiddenPage) diff --git a/store_boltdb.go b/store_boltdb.go index f2e813ff542cf0bd50e9c9dec0797073a36b9230..d3333bcda89476dfb495a71a3fbe4ce1ca5bbc96 100644 --- a/store_boltdb.go +++ b/store_boltdb.go @@ -30,6 +30,7 @@ const ( ) var ( + // ErrNoBoltdbBucket means the bucket does not exist ErrNoBoltdbBucket = errors.New("the boltdb bucket does not exists") ) @@ -54,8 +55,8 @@ func newBoltDBStore(location *url.URL) (Store, error) { // step: create the bucket err = db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucketIfNotExists([]byte(dbName)) - return err + _, e := tx.CreateBucketIfNotExists([]byte(dbName)) + return e }) return &boltdbStore{ diff --git a/util_test.go b/util_test.go index 628faf9ae833b135d2a64b29942976c42473a4b6..30927a54306107d1b0f2f55a45a7aab4a081d4a5 100644 --- a/util_test.go +++ b/util_test.go @@ -108,7 +108,7 @@ func TestDecodeText(t *testing.T) { } assert.NotEmpty(t, encrypted) - decoded, err := decodeText(encrypted, fakeKey) + decoded, _ := decodeText(encrypted, fakeKey) assert.NotNil(t, decoded, "the session should not have been nil") assert.Equal(t, decoded, fakeText, "the decoded text is not the same") }