Skip to content
Snippets Groups Projects
Commit cde9b381 authored by Rohith's avatar Rohith
Browse files

- adding the cors handler as a middle handler for all /oauth (#49)

parent bf5686bf
No related branches found
No related tags found
No related merge requests found
......@@ -274,6 +274,37 @@ Refresh tokens are either be stored as an encrypted cookie or placed (encrypted)
Assuming access response responds with a refresh token and the --enable-refresh-token is true, the proxy will automatically refresh the access token for you. The tokens themselves are kept either as an encrypted (--encryption-key=KEY) cookie (cookie name: kc-state). Alternatively you can place the refresh token (still requires encryption key) in a local boltdb file or shared redis. Naturally the encryption key has to be the same on all instances and boltdb is for single instance only developments.
#### **Cross Origin Resource Sharing (CORS)**
You are permitted to add CORS following headers into the /oauth uri namespace
* Access-Control-Allow-Origin
* Access-Control-Allow-Methods
* Access-Control-Allow-Headers
* Access-Control-Expose-Headers
* Access-Control-Allow-Credentials
* Access-Control-Max-Age
Either from the config file:
```YAML
cors:
origins:
- '*'
methods:
- GET
- POST
```
or via the command line arguments
```shell
--cors-origins [--cors-origins option] a set of origins to add to the CORS access control (Access-Control-Allow-Origin)
--cors-methods [--cors-methods option] the method permitted in the access control (Access-Control-Allow-Methods)
--cors-headers [--cors-headers option] a set of headers to add to the CORS access control (Access-Control-Allow-Headers)
--cors-exposes-headers [--cors-exposes-headers option] set the expose cors headers access control (Access-Control-Expose-Headers)
```
#### **Endpoints**
* **/oauth/authorize** is authentication endpoint which will generate the openid redirect to the provider
......
......@@ -34,13 +34,13 @@ const (
authorizationHeader = "Authorization"
oauthURL = "/oauth"
authorizationURL = oauthURL + "/authorize"
callbackURL = oauthURL + "/callback"
healthURL = oauthURL + "/health"
tokenURL = oauthURL + "/token"
expiredURL = oauthURL + "/expired"
logoutURL = oauthURL + "/logout"
loginURL = oauthURL + "/login"
authorizationURL = "/authorize"
callbackURL = "/callback"
healthURL = "/health"
tokenURL = "/token"
expiredURL = "/expired"
logoutURL = "/logout"
loginURL = "/login"
claimPreferredName = "preferred_username"
claimAudience = "aud"
......
......@@ -72,6 +72,33 @@ func (r *oauthProxy) securityHandler() gin.HandlerFunc {
}
}
//
// crossSiteHandler injects the CORS headers, if set, for request made to /oauth
//
func (r *oauthProxy) crossSiteHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
c := r.config.CORS
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
}
}
//
// proxyHandler is responsible to proxy the requests on to the upstream endpoint
//
......
......@@ -82,6 +82,54 @@ func TestExpirationHandler(t *testing.T) {
}
}
func TestCrossSiteHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.crossSiteHandler()
cases := []struct {
Cors *CORS
Headers map[string]string
}{
{
Cors: &CORS{
Origins: []string{"*"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*",
},
},
{
Cors: &CORS{
Origins: []string{"*", "https://examples.com"},
Methods: []string{"GET"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*,https://examples.com",
"Access-Control-Allow-Methods": "GET",
},
},
}
for i, c := range cases {
// step: get the config
kc.config.CORS = c.Cors
// call the handler and check the responses
context := newFakeGinContext("GET", "/oauth/test")
handler(context)
// step: check the headers
for k, v := range c.Headers {
value := context.Writer.Header().Get(k)
if value == "" {
t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header())
continue
}
if value != v {
t.Errorf("case %d, expected: %s but got %s", i, k, value)
}
}
}
}
func TestSecurityHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.securityHandler()
......
......@@ -166,7 +166,9 @@ func (r *oauthProxy) Run() error {
err = server.ListenAndServeTLS(r.config.TLSCertificate, r.config.TLSPrivateKey)
}
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Fatalf("failed to start the service")
log.WithFields(log.Fields{
"error": err.Error(),
}).Fatalf("failed to start the service")
}
}()
......@@ -177,9 +179,6 @@ func (r *oauthProxy) Run() error {
// redirectToURL redirects the user and aborts the context
//
func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) {
// step: add the cors headers
r.injectCORSHeaders(cx)
cx.Redirect(http.StatusTemporaryRedirect, url)
cx.Abort()
}
......@@ -217,32 +216,7 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) {
return
}
r.redirectToURL(authorizationURL+authQuery, cx)
}
//
// injectCORSHeaders adds the cors access controls to the oauth responses
//
func (r *oauthProxy) injectCORSHeaders(cx *gin.Context) {
c := r.config.CORS
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
r.redirectToURL(oauthURL+authorizationURL+authQuery, cx)
}
//
......@@ -279,13 +253,16 @@ func (r oauthProxy) setupRouter() error {
r.router.Use(r.securityHandler())
}
// step: add the routing
r.router.GET(authorizationURL, r.oauthAuthorizationHandler)
r.router.GET(callbackURL, r.oauthCallbackHandler)
r.router.GET(healthURL, r.healthHandler)
r.router.GET(tokenURL, r.tokenHandler)
r.router.GET(expiredURL, r.expirationHandler)
r.router.GET(logoutURL, r.logoutHandler)
r.router.POST(loginURL, r.loginHandler)
oauth := r.router.Group(oauthURL).Use(r.crossSiteHandler())
{
oauth.GET(authorizationURL, r.oauthAuthorizationHandler)
oauth.GET(callbackURL, r.oauthCallbackHandler)
oauth.GET(healthURL, r.healthHandler)
oauth.GET(tokenURL, r.tokenHandler)
oauth.GET(expiredURL, r.expirationHandler)
oauth.GET(logoutURL, r.logoutHandler)
oauth.POST(loginURL, r.loginHandler)
}
r.router.Use(r.entryPointHandler(), r.authenticationHandler(), r.admissionHandler())
......@@ -302,6 +279,7 @@ func (r *oauthProxy) setupTemplates() error {
log.Debugf("loading the custom sign in page: %s", r.config.SignInPage)
list = append(list, r.config.SignInPage)
}
if r.config.ForbiddenPage != "" {
log.Debugf("loading the custom sign forbidden page: %s", r.config.ForbiddenPage)
list = append(list, r.config.ForbiddenPage)
......
......@@ -30,6 +30,7 @@ const (
)
var (
// ErrNoBoltdbBucket means the bucket does not exist
ErrNoBoltdbBucket = errors.New("the boltdb bucket does not exists")
)
......@@ -54,8 +55,8 @@ func newBoltDBStore(location *url.URL) (Store, error) {
// step: create the bucket
err = db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(dbName))
return err
_, e := tx.CreateBucketIfNotExists([]byte(dbName))
return e
})
return &boltdbStore{
......
......@@ -108,7 +108,7 @@ func TestDecodeText(t *testing.T) {
}
assert.NotEmpty(t, encrypted)
decoded, err := decodeText(encrypted, fakeKey)
decoded, _ := decodeText(encrypted, fakeKey)
assert.NotNil(t, decoded, "the session should not have been nil")
assert.Equal(t, decoded, fakeText, "the decoded text is not the same")
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment