Skip to content
Snippets Groups Projects
Commit cde9b381 authored by Rohith's avatar Rohith
Browse files

- adding the cors handler as a middle handler for all /oauth (#49)

parent bf5686bf
No related branches found
No related tags found
No related merge requests found
...@@ -274,6 +274,37 @@ Refresh tokens are either be stored as an encrypted cookie or placed (encrypted) ...@@ -274,6 +274,37 @@ Refresh tokens are either be stored as an encrypted cookie or placed (encrypted)
Assuming access response responds with a refresh token and the --enable-refresh-token is true, the proxy will automatically refresh the access token for you. The tokens themselves are kept either as an encrypted (--encryption-key=KEY) cookie (cookie name: kc-state). Alternatively you can place the refresh token (still requires encryption key) in a local boltdb file or shared redis. Naturally the encryption key has to be the same on all instances and boltdb is for single instance only developments. Assuming access response responds with a refresh token and the --enable-refresh-token is true, the proxy will automatically refresh the access token for you. The tokens themselves are kept either as an encrypted (--encryption-key=KEY) cookie (cookie name: kc-state). Alternatively you can place the refresh token (still requires encryption key) in a local boltdb file or shared redis. Naturally the encryption key has to be the same on all instances and boltdb is for single instance only developments.
#### **Cross Origin Resource Sharing (CORS)**
You are permitted to add CORS following headers into the /oauth uri namespace
* Access-Control-Allow-Origin
* Access-Control-Allow-Methods
* Access-Control-Allow-Headers
* Access-Control-Expose-Headers
* Access-Control-Allow-Credentials
* Access-Control-Max-Age
Either from the config file:
```YAML
cors:
origins:
- '*'
methods:
- GET
- POST
```
or via the command line arguments
```shell
--cors-origins [--cors-origins option] a set of origins to add to the CORS access control (Access-Control-Allow-Origin)
--cors-methods [--cors-methods option] the method permitted in the access control (Access-Control-Allow-Methods)
--cors-headers [--cors-headers option] a set of headers to add to the CORS access control (Access-Control-Allow-Headers)
--cors-exposes-headers [--cors-exposes-headers option] set the expose cors headers access control (Access-Control-Expose-Headers)
```
#### **Endpoints** #### **Endpoints**
* **/oauth/authorize** is authentication endpoint which will generate the openid redirect to the provider * **/oauth/authorize** is authentication endpoint which will generate the openid redirect to the provider
......
...@@ -34,13 +34,13 @@ const ( ...@@ -34,13 +34,13 @@ const (
authorizationHeader = "Authorization" authorizationHeader = "Authorization"
oauthURL = "/oauth" oauthURL = "/oauth"
authorizationURL = oauthURL + "/authorize" authorizationURL = "/authorize"
callbackURL = oauthURL + "/callback" callbackURL = "/callback"
healthURL = oauthURL + "/health" healthURL = "/health"
tokenURL = oauthURL + "/token" tokenURL = "/token"
expiredURL = oauthURL + "/expired" expiredURL = "/expired"
logoutURL = oauthURL + "/logout" logoutURL = "/logout"
loginURL = oauthURL + "/login" loginURL = "/login"
claimPreferredName = "preferred_username" claimPreferredName = "preferred_username"
claimAudience = "aud" claimAudience = "aud"
......
...@@ -72,6 +72,33 @@ func (r *oauthProxy) securityHandler() gin.HandlerFunc { ...@@ -72,6 +72,33 @@ func (r *oauthProxy) securityHandler() gin.HandlerFunc {
} }
} }
//
// crossSiteHandler injects the CORS headers, if set, for request made to /oauth
//
func (r *oauthProxy) crossSiteHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
c := r.config.CORS
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
}
}
// //
// proxyHandler is responsible to proxy the requests on to the upstream endpoint // proxyHandler is responsible to proxy the requests on to the upstream endpoint
// //
......
...@@ -82,6 +82,54 @@ func TestExpirationHandler(t *testing.T) { ...@@ -82,6 +82,54 @@ func TestExpirationHandler(t *testing.T) {
} }
} }
func TestCrossSiteHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.crossSiteHandler()
cases := []struct {
Cors *CORS
Headers map[string]string
}{
{
Cors: &CORS{
Origins: []string{"*"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*",
},
},
{
Cors: &CORS{
Origins: []string{"*", "https://examples.com"},
Methods: []string{"GET"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*,https://examples.com",
"Access-Control-Allow-Methods": "GET",
},
},
}
for i, c := range cases {
// step: get the config
kc.config.CORS = c.Cors
// call the handler and check the responses
context := newFakeGinContext("GET", "/oauth/test")
handler(context)
// step: check the headers
for k, v := range c.Headers {
value := context.Writer.Header().Get(k)
if value == "" {
t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header())
continue
}
if value != v {
t.Errorf("case %d, expected: %s but got %s", i, k, value)
}
}
}
}
func TestSecurityHandler(t *testing.T) { func TestSecurityHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t) kc := newFakeKeycloakProxy(t)
handler := kc.securityHandler() handler := kc.securityHandler()
......
...@@ -166,7 +166,9 @@ func (r *oauthProxy) Run() error { ...@@ -166,7 +166,9 @@ func (r *oauthProxy) Run() error {
err = server.ListenAndServeTLS(r.config.TLSCertificate, r.config.TLSPrivateKey) err = server.ListenAndServeTLS(r.config.TLSCertificate, r.config.TLSPrivateKey)
} }
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Fatalf("failed to start the service") log.WithFields(log.Fields{
"error": err.Error(),
}).Fatalf("failed to start the service")
} }
}() }()
...@@ -177,9 +179,6 @@ func (r *oauthProxy) Run() error { ...@@ -177,9 +179,6 @@ func (r *oauthProxy) Run() error {
// redirectToURL redirects the user and aborts the context // redirectToURL redirects the user and aborts the context
// //
func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) { func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) {
// step: add the cors headers
r.injectCORSHeaders(cx)
cx.Redirect(http.StatusTemporaryRedirect, url) cx.Redirect(http.StatusTemporaryRedirect, url)
cx.Abort() cx.Abort()
} }
...@@ -217,32 +216,7 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { ...@@ -217,32 +216,7 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) {
return return
} }
r.redirectToURL(authorizationURL+authQuery, cx) r.redirectToURL(oauthURL+authorizationURL+authQuery, cx)
}
//
// injectCORSHeaders adds the cors access controls to the oauth responses
//
func (r *oauthProxy) injectCORSHeaders(cx *gin.Context) {
c := r.config.CORS
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
} }
// //
...@@ -279,13 +253,16 @@ func (r oauthProxy) setupRouter() error { ...@@ -279,13 +253,16 @@ func (r oauthProxy) setupRouter() error {
r.router.Use(r.securityHandler()) r.router.Use(r.securityHandler())
} }
// step: add the routing // step: add the routing
r.router.GET(authorizationURL, r.oauthAuthorizationHandler) oauth := r.router.Group(oauthURL).Use(r.crossSiteHandler())
r.router.GET(callbackURL, r.oauthCallbackHandler) {
r.router.GET(healthURL, r.healthHandler) oauth.GET(authorizationURL, r.oauthAuthorizationHandler)
r.router.GET(tokenURL, r.tokenHandler) oauth.GET(callbackURL, r.oauthCallbackHandler)
r.router.GET(expiredURL, r.expirationHandler) oauth.GET(healthURL, r.healthHandler)
r.router.GET(logoutURL, r.logoutHandler) oauth.GET(tokenURL, r.tokenHandler)
r.router.POST(loginURL, r.loginHandler) oauth.GET(expiredURL, r.expirationHandler)
oauth.GET(logoutURL, r.logoutHandler)
oauth.POST(loginURL, r.loginHandler)
}
r.router.Use(r.entryPointHandler(), r.authenticationHandler(), r.admissionHandler()) r.router.Use(r.entryPointHandler(), r.authenticationHandler(), r.admissionHandler())
...@@ -302,6 +279,7 @@ func (r *oauthProxy) setupTemplates() error { ...@@ -302,6 +279,7 @@ func (r *oauthProxy) setupTemplates() error {
log.Debugf("loading the custom sign in page: %s", r.config.SignInPage) log.Debugf("loading the custom sign in page: %s", r.config.SignInPage)
list = append(list, r.config.SignInPage) list = append(list, r.config.SignInPage)
} }
if r.config.ForbiddenPage != "" { if r.config.ForbiddenPage != "" {
log.Debugf("loading the custom sign forbidden page: %s", r.config.ForbiddenPage) log.Debugf("loading the custom sign forbidden page: %s", r.config.ForbiddenPage)
list = append(list, r.config.ForbiddenPage) list = append(list, r.config.ForbiddenPage)
......
...@@ -30,6 +30,7 @@ const ( ...@@ -30,6 +30,7 @@ const (
) )
var ( var (
// ErrNoBoltdbBucket means the bucket does not exist
ErrNoBoltdbBucket = errors.New("the boltdb bucket does not exists") ErrNoBoltdbBucket = errors.New("the boltdb bucket does not exists")
) )
...@@ -54,8 +55,8 @@ func newBoltDBStore(location *url.URL) (Store, error) { ...@@ -54,8 +55,8 @@ func newBoltDBStore(location *url.URL) (Store, error) {
// step: create the bucket // step: create the bucket
err = db.Update(func(tx *bolt.Tx) error { err = db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(dbName)) _, e := tx.CreateBucketIfNotExists([]byte(dbName))
return err return e
}) })
return &boltdbStore{ return &boltdbStore{
......
...@@ -108,7 +108,7 @@ func TestDecodeText(t *testing.T) { ...@@ -108,7 +108,7 @@ func TestDecodeText(t *testing.T) {
} }
assert.NotEmpty(t, encrypted) assert.NotEmpty(t, encrypted)
decoded, err := decodeText(encrypted, fakeKey) decoded, _ := decodeText(encrypted, fakeKey)
assert.NotNil(t, decoded, "the session should not have been nil") assert.NotNil(t, decoded, "the session should not have been nil")
assert.Equal(t, decoded, fakeText, "the decoded text is not the same") assert.Equal(t, decoded, fakeText, "the decoded text is not the same")
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment