diff --git a/.gitignore b/.gitignore index 953cc9940f52f35590236584f0ab2bcc16da2e7f..e3d5ef7c57181d5bed8823dbe994ba849d936e0a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ bin/ release/ cover.html cover.out +tests/db.bolt +test.sock +tests/redis.conf *.iml config.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 474537eb58d9d12c95506019922ba661f2b06b09..cdf0165bc4e9e291483f62fa0c0d85b79e9bb826 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,18 @@ + +#### **1.0.3 (April 30th, 2016)** + +FIXES: + * Fixes the cookie sessions expiraton + +FEATURES: + * Adding a idle duration configuration option which controls the expiration of access token cookie and thus session. + If the session is not used within that period, the session is removed. + * The upstream endpoint has also be a unix socket + +BREAKING CHANGES: + * Change the client id in json/yaml config file from clientid -> client-id + #### **1.0.2 (April 22th, 2016)** FIXES: diff --git a/Makefile b/Makefile index b5d93df8aa49f24280edb2b57968ac80503d4b53..38232ad9132c1e4a240be994a81ca44e9fcd2a24 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ NAME=keycloak-proxy AUTHOR=gambol99 -HARDWARE=$(shell uname -m) REGISTRY=docker.io GOVERSION=1.6.0 SUDO=sudo -GIT_COMMIT=$(shell git log --pretty=format:'%h' -n 1) ROOT_DIR=${PWD} +HARDWARE=$(shell uname -m) +GIT_SHA=$(shell git --no-pager describe --tags --always --dirty) VERSION=$(shell awk '/version.*=/ { print $$3 }' doc.go | sed 's/"//g') DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...) PACKAGES=$(shell go list ./...) @@ -20,12 +20,15 @@ golang: @echo "--> Go Version" @go version -build: +version: + @sed -i "s/const gitSHA =.*/const gitSHA = \"${GIT_SHA}\"/" doc.go + +build: version @echo "--> Compiling the project" mkdir -p bin godep go build -o bin/${NAME} -static: golang deps +static: version golang deps @echo "--> Compiling the static binary" mkdir -p bin CGO_ENABLED=0 GOOS=linux godep go build -a -tags netgo -ldflags '-w' -o bin/${NAME} diff --git a/config.go b/config.go index e043e0d2e98d87da56e6ed6a3e30d9ea0eddeb62..755435061f6b97da4be336577321fcc5ca5880dc 100644 --- a/config.go +++ b/config.go @@ -31,13 +31,13 @@ import ( // newDefaultConfig returns a initialized config func newDefaultConfig() *Config { return &Config{ - Listen: "127.0.0.1:3000", - RedirectionURL: "http://127.0.0.1:3000", - Upstream: "http://127.0.0.1:8081", - TagData: make(map[string]string, 0), - ClaimsMatch: make(map[string]string, 0), - Header: make(map[string]string, 0), - CORS: &CORS{}, + Listen: "127.0.0.1:3000", + RedirectionURL: "http://127.0.0.1:3000", + Upstream: "http://127.0.0.1:8081", + TagData: make(map[string]string, 0), + ClaimsMatch: make(map[string]string, 0), + Header: make(map[string]string, 0), + CrossOrigin: CORS{}, SkipUpstreamTLSVerify: true, } } @@ -155,12 +155,18 @@ func readOptions(cx *cli.Context, config *Config) (err error) { if cx.IsSet("upstream-keepalives") { config.Keepalives = cx.Bool("upstream-keepalives") } + if cx.IsSet("idle-duration") { + config.IdleDuration = cx.Duration("idle-duration") + } if cx.IsSet("skip-token-verification") { config.SkipTokenVerification = cx.Bool("skip-token-verification") } if cx.IsSet("skip-upstream-tls-verify") { config.SkipUpstreamTLSVerify = cx.Bool("skip-upstream-tls-verify") } + if cx.IsSet("enable-refresh-tokens") { + config.EnableRefreshTokens = cx.Bool("enable-refresh-tokens") + } if cx.IsSet("encryption-key") { config.EncryptionKey = cx.String("encryption-key") } @@ -210,22 +216,22 @@ func readOptions(cx *cli.Context, config *Config) (err error) { config.Hostnames = cx.StringSlice("hostname") } if cx.IsSet("cors-origins") { - config.CORS.Origins = cx.StringSlice("cors-origins") + config.CrossOrigin.Origins = cx.StringSlice("cors-origins") } if cx.IsSet("cors-methods") { - config.CORS.Methods = cx.StringSlice("cors-methods") + config.CrossOrigin.Methods = cx.StringSlice("cors-methods") } if cx.IsSet("cors-headers") { - config.CORS.Headers = cx.StringSlice("cors-headers") + config.CrossOrigin.Headers = cx.StringSlice("cors-headers") } if cx.IsSet("cors-exposed-headers") { - config.CORS.ExposedHeaders = cx.StringSlice("cors-exposed-headers") + config.CrossOrigin.ExposedHeaders = cx.StringSlice("cors-exposed-headers") } if cx.IsSet("cors-max-age") { - config.CORS.MaxAge = cx.Duration("cors-max-age") + config.CrossOrigin.MaxAge = cx.Duration("cors-max-age") } if cx.IsSet("cors-credentials") { - config.CORS.Credentials = cx.BoolT("cors-credentials") + config.CrossOrigin.Credentials = cx.BoolT("cors-credentials") } if cx.IsSet("tag") { config.TagData, err = decodeKeyPairs(cx.StringSlice("tag")) @@ -302,6 +308,10 @@ func getOptions() []cli.Flag { Name: "discovery-url", Usage: "the discovery url to retrieve the openid configuration", }, + cli.DurationFlag{ + Name: "idle-duration", + Usage: "the expiration of the access token cookie, if not used within this time its removed", + }, cli.StringFlag{ Name: "upstream-url", Usage: "the url for the upstream endpoint you wish to proxy to", @@ -309,12 +319,16 @@ func getOptions() []cli.Flag { }, cli.StringFlag{ Name: "revocation-url", - Usage: "the url for the revocation endpoint to revoke refresh token, not all providers support the revocation_endpoint", + Usage: "the url for the revocation endpoint to revoke refresh token", Value: "/oauth2/revoke", }, cli.BoolTFlag{ Name: "upstream-keepalives", - Usage: "enables or disables the keepalive connections for upstream endpoint (defaults true)", + Usage: "enables or disables the keepalive connections for upstream endpoint", + }, + cli.BoolFlag{ + Name: "enable-refresh-tokens", + Usage: "enables the handling of the refresh tokens", }, cli.StringFlag{ Name: "encryption-key", @@ -322,15 +336,15 @@ func getOptions() []cli.Flag { }, cli.StringFlag{ Name: "store-url", - Usage: "the store url to use for storing the refresh tokens, i.e. redis://127.0.0.1:6379, file:///etc/tokens.file", + Usage: "url for the storage subsystem, e.g redis://127.0.0.1:6379, file:///etc/tokens.file", }, cli.BoolFlag{ Name: "no-redirects", - Usage: "do not have back redirects when no authentication is present, simple reply with 401 code", + Usage: "do not have back redirects when no authentication is present, 401 them", }, cli.StringFlag{ Name: "redirection-url", - Usage: fmt.Sprintf("the redirection url, namely the site url, note: %s will be added to it", oauthURL), + Usage: fmt.Sprintf("redirection url for the oauth callback url (%s is added)", oauthURL), }, cli.StringSliceFlag{ Name: "hostname", @@ -358,7 +372,7 @@ func getOptions() []cli.Flag { }, cli.StringSliceFlag{ Name: "claim", - Usage: "a series of key pair values which must match the claims in the token present e.g. aud=myapp, iss=http://example.com etcd", + Usage: "keypair values for matching access token claims e.g. aud=myapp, iss=http://example.*", }, cli.StringSliceFlag{ Name: "resource", @@ -374,11 +388,11 @@ func getOptions() []cli.Flag { }, cli.StringSliceFlag{ Name: "tag", - Usage: "a keypair tag which is passed to the templates when render, i.e. title='My Page',site='my name' etc", + Usage: "keypair's passed to the templates at render,e.g title='My Page'", }, cli.StringSliceFlag{ Name: "cors-origins", - Usage: "a set of origins to add to the CORS access control (Access-Control-Allow-Origin)", + Usage: "list of origins to add to the CORE origins control (Access-Control-Allow-Origin)", }, cli.StringSliceFlag{ Name: "cors-methods", @@ -406,11 +420,7 @@ func getOptions() []cli.Flag { }, cli.BoolFlag{ Name: "skip-token-verification", - Usage: "testing purposes ONLY, the option allows you to bypass the token verification, expiration and roles are still enforced", - }, - cli.BoolFlag{ - Name: "proxy-protocol", - Usage: "switches on proxy protocol support on the listen (not supported yet)", + Usage: "TESTING ONLY; bypass's token verification, expiration and roles enforced", }, cli.BoolFlag{ Name: "offline-session", diff --git a/config_sample.yml b/config_sample.yml index 409d63d37344357251351334b6e440b73a8c67a5..1471da3235f2ee66c171d702fc0a150e842d444e 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -3,13 +3,15 @@ # is the url for retrieve the openid configuration - normally the <server>/auth/realm/<realm_name> discovery-url: https://keycloak.example.com/auth/realms/commons # the client id for the 'client' application -clientid: <CLIENT_ID> +client-id: <CLIENT_ID> # the secret associated to the 'client' application client-secret: <CLIENT_SECRET> # the interface definition you wish the proxy to listen, all interfaces is specified as ':<port>' listen: 127.0.0.1:3000 # whether to request offline access and use a refresh token enable-refresh-tokens: true +# the max amount of time a session can stay alive without being used +idle-duration: 24h # log all incoming requests log-requests: true # log in json format @@ -31,8 +33,7 @@ upstream: http://127.0.0.1:80 # upstream-keepalives specified wheather you want keepalive on the upstream endpoint upstream-keepalives: true # additional scopes to add to add to the default (openid+email+profile) -scopes: - - vpn-user +scopes: [] # enables a more extra secuirty features enable-security-filter: true # a map of claims that MUST exist in the token presented and the value is it MUST match diff --git a/config_test.go b/config_test.go index 450b97c1c7febe7fde60c228a54fe8ab81d3925f..06c33d6d6a6bf634faceaa9f6f2bd5ce4f47a615 100644 --- a/config_test.go +++ b/config_test.go @@ -37,14 +37,14 @@ func TestReadConfiguration(t *testing.T) { { Content: ` discovery_url: https://keyclock.domain.com/ -clientid: <client_id> +client-id: <client_id> secret: <secret> `, }, { Content: ` discovery_url: https://keyclock.domain.com -clientid: <client_id> +client-id: <client_id> secret: <secret> upstream: http://127.0.0.1:8080 redirection_url: http://127.0.0.1:3000 diff --git a/cookies.go b/cookies.go index 643d8ba45b33921cea6c44ef040f1f4506375fcf..a28cb12fe38c599efec4bc5bd38eab77662143da 100644 --- a/cookies.go +++ b/cookies.go @@ -27,17 +27,16 @@ import ( // // dropCookie drops a cookie into the response // -func dropCookie(cx *gin.Context, name, value string, expires time.Time) { +func dropCookie(cx *gin.Context, name, value string, duration time.Duration) { cookie := &http.Cookie{ - Name: name, - Domain: strings.Split(cx.Request.Host, ":")[0], - Path: "/", - HttpOnly: true, - Secure: true, - Value: value, + Name: name, + Domain: strings.Split(cx.Request.Host, ":")[0], + Path: "/", + Secure: true, + Value: value, } - if !expires.IsZero() { - cookie.Expires = expires + if duration != 0 { + cookie.Expires = time.Now().Add(duration) } http.SetCookie(cx.Writer, cookie) @@ -46,15 +45,15 @@ func dropCookie(cx *gin.Context, name, value string, expires time.Time) { // // dropAccessTokenCookie drops a access token cookie into the response // -func dropAccessTokenCookie(cx *gin.Context, token jose.JWT) { - dropCookie(cx, cookieAccessToken, token.Encode(), time.Time{}) +func dropAccessTokenCookie(cx *gin.Context, token jose.JWT, duration time.Duration) { + dropCookie(cx, cookieAccessToken, token.Encode(), duration) } // // dropRefreshTokenCookie drops a refresh token cookie into the response // -func dropRefreshTokenCookie(cx *gin.Context, token string, expires time.Time) { - dropCookie(cx, cookieRefreshToken, token, expires) +func dropRefreshTokenCookie(cx *gin.Context, token string, duration time.Duration) { + dropCookie(cx, cookieRefreshToken, token, duration) } // @@ -69,12 +68,12 @@ func clearAllCookies(cx *gin.Context) { // clearRefreshSessionCookie clears the session cookie // func clearRefreshTokenCookie(cx *gin.Context) { - dropCookie(cx, cookieRefreshToken, "", time.Now().Add(-1*time.Hour)) + dropCookie(cx, cookieRefreshToken, "", time.Duration(-10*time.Hour)) } // // clearAccessTokenCookie clears the session cookie // func clearAccessTokenCookie(cx *gin.Context) { - dropCookie(cx, cookieAccessToken, "", time.Now().Add(-1*time.Hour)) + dropCookie(cx, cookieAccessToken, "", time.Duration(-10*time.Hour)) } diff --git a/cookies_test.go b/cookies_test.go index d087cb87a8f5c1cfc99ffae674f7098462e6438b..852bfecb3f6480da6707e0511941a759f9d5c5b8 100644 --- a/cookies_test.go +++ b/cookies_test.go @@ -14,3 +14,48 @@ limitations under the License. */ package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDropCookie(t *testing.T) { + context := newFakeGinContext("GET", "/admin") + dropCookie(context, "test-cookie", "test-value", 0) + + assert.Equal(t, context.Writer.Header().Get("Set-Cookie"), + "test-cookie=test-value; Path=/; Domain=127.0.0.1; Secure", + "we have not set the cookie, headers: %v", context.Writer.Header()) + + context = newFakeGinContext("GET", "/admin") + dropCookie(context, "test-cookie", "test-value", 0) + assert.NotEqual(t, context.Writer.Header().Get("Set-Cookie"), + "test-cookie=test-value; Path=/; Domain=127.0.0.2; HttpOnly; Secure", + "we have not set the cookie, headers: %v", context.Writer.Header()) +} + +func TestClearAccessTokenCookie(t *testing.T) { + context := newFakeGinContext("GET", "/admin") + clearAccessTokenCookie(context) + assert.Contains(t, context.Writer.Header().Get("Set-Cookie"), + "kc-access=; Path=/; Domain=127.0.0.1; Expires=", + "we have not cleared the, headers: %v", context.Writer.Header()) +} + +func TestClearRefreshAccessTokenCookie(t *testing.T) { + context := newFakeGinContext("GET", "/admin") + clearRefreshTokenCookie(context) + assert.Contains(t, context.Writer.Header().Get("Set-Cookie"), + "kc-state=; Path=/; Domain=127.0.0.1; Expires=", + "we have not cleared the, headers: %v", context.Writer.Header()) +} + +func TestClearAllCookies(t *testing.T) { + context := newFakeGinContext("GET", "/admin") + clearAllCookies(context) + assert.Contains(t, context.Writer.Header().Get("Set-Cookie"), + "kc-access=; Path=/; Domain=127.0.0.1; Expires=", + "we have not cleared the, headers: %v", context.Writer.Header()) +} diff --git a/doc.go b/doc.go index 692e565c53cac50d7bf85380b2275c86ca28f17c..9224ba54d01ffa27bad86f72ecaa3d5545537cf9 100644 --- a/doc.go +++ b/doc.go @@ -20,9 +20,11 @@ import ( "time" ) +const gitSHA = "v1.0.3-2-g0082034-dirty" + const ( prog = "keycloak-proxy" - version = "v1.0.3" + version = "v1.0.3" + " (git+sha: " + gitSHA + ")" author = "Rohith" email = "gambol99@gmail.com" description = "is a proxy using the keycloak service for auth and authorization" @@ -101,7 +103,7 @@ type Config struct { // DiscoveryURL is the url for the keycloak server DiscoveryURL string `json:"discovery-url" yaml:"discovery-url"` // ClientID is the client id - ClientID string `json:"clientid" yaml:"clientid"` + ClientID string `json:"client-id" yaml:"client-id"` // ClientSecret is the secret for AS ClientSecret string `json:"client-secret" yaml:"client-secret"` // RevocationEndpoint is the token revocation endpoint to revoke refresh tokens @@ -114,6 +116,8 @@ type Config struct { EnableSecurityFilter bool `json:"enable-security-filter" yaml:"enable-security-filter"` // EnableRefreshTokens indicate's you wish to ignore using refresh tokens and re-auth on expireation of access token EnableRefreshTokens bool `json:"enable-refresh-tokens" yaml:"enable-refresh-tokens"` + // IdleDuration is the max amount of time a session can last without being used + IdleDuration time.Duration `json:"idle-duration" yaml:"idle-duration"` // EncryptionKey is the encryption key used to encrypt the refresh token EncryptionKey string `json:"encryption-key" yaml:"encryption-key"` // ClaimsMatch is a series of checks, the claims in the token must match those here @@ -136,8 +140,8 @@ type Config struct { Upstream string `json:"upstream" yaml:"upstream"` // TagData is passed to the templates TagData map[string]string `json:"tag-data" yaml:"tag-data"` - // CORS permits adding headers to the /oauth handlers - CORS *CORS `json:"cors" yaml:"cors"` + // CrossOrigin permits adding headers to the /oauth handlers + CrossOrigin CORS `json:"cors" yaml:"cors"` // Header permits adding customs headers across the board Header map[string]string `json:"headers" yaml:"headers"` // Scopes is a list of scope we should request @@ -149,7 +153,7 @@ type Config struct { // ForbiddenPage is a access forbidden page ForbiddenPage string `json:"forbidden-page" yaml:"forbidden-page"` // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes - SkipTokenVerification bool + SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification"` // Verbose switches on debug logging Verbose bool `json:"verbose" yaml:"verbose"` // Hostname is a list of hostname's the service should response to @@ -158,9 +162,9 @@ type Config struct { StoreURL string `json:"store-url" yaml:"store-url"` } -// Store is used to hold the offline refresh token, assuming you don't want to use +// store is used to hold the offline refresh token, assuming you don't want to use // the default practice of a encrypted cookie -type Store interface { +type storage interface { // Add the token to the store Set(string, string) error // Get retrieves a token from the store diff --git a/handlers_auth.go b/handlers.go similarity index 66% rename from handlers_auth.go rename to handlers.go index 3ea0f53c226c45c9ffbb9e96adfd0c2e3cc13a9f..bd4e00ad3d98848122df4d09a69d206924a7d5eb 100644 --- a/handlers_auth.go +++ b/handlers.go @@ -21,191 +21,13 @@ import ( "net/http" "net/url" "path" + "strings" "time" log "github.com/Sirupsen/logrus" - "github.com/coreos/go-oidc/jose" "github.com/gin-gonic/gin" ) -// -// authenticationHandler is responsible for verifying the access token -// -func (r *oauthProxy) authenticationHandler() gin.HandlerFunc { - return func(cx *gin.Context) { - // step: is authentication required on this uri? - if _, found := cx.Get(cxEnforce); !found { - log.Debugf("skipping the authentication handler, resource not protected") - cx.Next() - return - } - - // step: grab the user identity from the request - user, err := getIdentity(cx) - if err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Errorf("failed to get session, redirecting for authorization") - - r.redirectToAuthorization(cx) - return - } - - // step: inject the user into the context - cx.Set(userContextName, user) - - // step: verify the access token - if r.config.SkipTokenVerification { - log.Warnf("skip token verification enabled, skipping verification process - FOR TESTING ONLY") - - if user.isExpired() { - log.WithFields(log.Fields{ - "username": user.name, - "expired_on": user.expiresAt.String(), - }).Errorf("the session has expired and verification switch off") - - r.redirectToAuthorization(cx) - } - - return - } - - // step: verify the access token - if err := verifyToken(r.client, user.token); err != nil { - - // step: if the error post verification is anything other than a token expired error - // we immediately throw an access forbidden - as there is something messed up in the token - if err != ErrAccessTokenExpired { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Errorf("verification of the access token failed") - - r.accessForbidden(cx) - return - } - - // step: are we refreshing the access tokens? - if !r.config.EnableRefreshTokens { - log.WithFields(log.Fields{ - "email": user.name, - "expired_on": user.expiresAt.String(), - }).Errorf("the session has expired and access token refreshing is disabled") - - r.redirectToAuthorization(cx) - return - } - - // step: we do not refresh bearer token requests - if user.isBearer() { - log.WithFields(log.Fields{ - "email": user.name, - "expired_on": user.expiresAt.String(), - }).Errorf("the session has expired and we are using bearer tokens") - - r.redirectToAuthorization(cx) - return - } - - log.WithFields(log.Fields{ - "email": user.email, - "client_ip": cx.ClientIP(), - }).Infof("the accces token for user: %s has expired, attemping to refresh the token", user.email) - - // step: check if the user has refresh token - rToken, err := r.retrieveRefreshToken(cx, user) - if err != nil { - log.WithFields(log.Fields{ - "email": user.email, - "error": err.Error(), - }).Errorf("unable to find a refresh token for the client: %s", user.email) - - r.redirectToAuthorization(cx) - return - } - - log.WithFields(log.Fields{ - "email": user.email, - }).Infof("found a refresh token, attempting to refresh access token for user: %s", user.email) - - // step: attempts to refresh the access token - token, expires, err := refreshToken(r.client, rToken) - if err != nil { - // step: has the refresh token expired - switch err { - case ErrRefreshTokenExpired: - log.WithFields(log.Fields{"token": token}).Warningf("the refresh token has expired") - clearAllCookies(cx) - default: - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to refresh the access token") - } - - r.redirectToAuthorization(cx) - return - } - - // step: inject the refreshed access token - log.WithFields(log.Fields{ - "email": user.email, - "access_expires_in": expires.Sub(time.Now()).String(), - }).Infof("injecting refreshed access token, expires on: %s", expires.Format(time.RFC1123)) - - // step: clear the cookie up - dropAccessTokenCookie(cx, token) - - if r.useStore() { - go func(t jose.JWT, rt string) { - // step: the access token has been updated, we need to delete old reference and update the store - if err := r.DeleteRefreshToken(t); err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Errorf("unable to delete the old refresh tokem from store") - } - - // step: store the new refresh token reference place the session in the store - if err := r.StoreRefreshToken(t, rt); err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Errorf("failed to place the refresh token in the store") - - return - } - }(user.token, rToken) - } - - // step: update the with the new access token - user.token = token - - // step: inject the user into the context - cx.Set(userContextName, user) - } - - cx.Next() - } -} - -// -// retrieveRefreshToken retrieves the refresh token from store or c -// -func (r oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) { - var token string - var err error - - // step: get the refresh token from the store or cookie - switch r.useStore() { - case true: - token, err = r.GetRefreshToken(user.token) - default: - token, err = getRefreshTokenFromCookie(cx) - } - - // step: decode the cookie - if err != nil { - return token, err - } - - return decodeText(token, r.config.EncryptionKey) -} - // // oauthAuthorizationHandler is responsible for performing the redirection to oauth provider // @@ -325,10 +147,11 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { "email": identity.Email, "expires": identity.ExpiresAt.Format(time.RFC822Z), "duration": identity.ExpiresAt.Sub(time.Now()).String(), + "idle": r.config.IdleDuration.String(), }).Infof("issuing a new access token for user, email: %s", identity.Email) // step: drop's a session cookie with the access token - dropAccessTokenCookie(cx, session) + dropAccessTokenCookie(cx, session, r.config.IdleDuration) // step: does the response has a refresh token and we are NOT ignore refresh tokens? if r.config.EnableRefreshTokens && response.RefreshToken != "" { @@ -352,7 +175,7 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { }).Warnf("failed to save the refresh token in the store") } default: - dropRefreshTokenCookie(cx, encrypted, time.Time{}) + dropRefreshTokenCookie(cx, encrypted, r.config.IdleDuration*2) } } @@ -501,3 +324,115 @@ func (r oauthProxy) logoutHandler(cx *gin.Context) { cx.AbortWithStatus(http.StatusOK) } + +// +// proxyHandler is responsible to proxy the requests on to the upstream endpoint +// +func (r *oauthProxy) proxyHandler(cx *gin.Context) { + // step: double check, if enforce is true and no user context it's a internal error + if _, found := cx.Get(cxEnforce); found { + if _, found := cx.Get(userContextName); !found { + log.Errorf("no user context found for a secure request") + cx.AbortWithStatus(http.StatusInternalServerError) + return + } + } + + // step: retrieve the user context if any + if user, found := cx.Get(userContextName); found { + id := user.(*userContext) + cx.Request.Header.Add("X-Auth-UserId", id.id) + cx.Request.Header.Add("X-Auth-Subject", id.preferredName) + cx.Request.Header.Add("X-Auth-Username", id.name) + cx.Request.Header.Add("X-Auth-Email", id.email) + cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) + cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) + cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) + } + + // step: add the default headers + cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr) + cx.Request.Header.Set("X-Forwarded-Agent", prog) + cx.Request.Header.Set("X-Forwarded-Agent-Version", version) + + // step: is this connection upgrading? + if isUpgradedConnection(cx.Request) { + log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) + if err := tryUpdateConnection(cx, r.endpoint); err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") + + cx.AbortWithStatus(http.StatusInternalServerError) + return + } + cx.Abort() + + return + } + + r.upstream.ServeHTTP(cx.Writer, cx.Request) +} + +// +// expirationHandler checks if the token has expired +// +func (r *oauthProxy) expirationHandler(cx *gin.Context) { + // step: get the access token from the request + user, err := getIdentity(cx) + if err != nil { + cx.AbortWithError(http.StatusUnauthorized, err) + return + } + // step: check the access is not expired + if user.isExpired() { + cx.AbortWithError(http.StatusUnauthorized, err) + return + } + + cx.AbortWithStatus(http.StatusOK) +} + +// +// tokenHandler display access token to screen +// +func (r *oauthProxy) tokenHandler(cx *gin.Context) { + // step: extract the access token from the request + user, err := getIdentity(cx) + if err != nil { + cx.AbortWithError(http.StatusBadRequest, fmt.Errorf("unable to retrieve session, error: %s", err)) + return + } + + // step: write the json content + cx.Writer.Header().Set("Content-Type", "application/json") + cx.String(http.StatusOK, fmt.Sprintf("%s", user.token.Payload)) +} + +// +// healthHandler is a health check handler for the service +// +func (r *oauthProxy) healthHandler(cx *gin.Context) { + cx.String(http.StatusOK, "OK") +} + +// +// retrieveRefreshToken retrieves the refresh token from store or c +// +func (r oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) { + var token string + var err error + + // step: get the refresh token from the store or cookie + switch r.useStore() { + case true: + token, err = r.GetRefreshToken(user.token) + default: + token, err = getRefreshTokenFromCookie(cx) + } + + // step: decode the cookie + if err != nil { + return token, err + } + + return decodeText(token, r.config.EncryptionKey) +} diff --git a/handlers_admission.go b/handlers_admission.go deleted file mode 100644 index 127dda54221a934adbba056399202e5e4c06aed9..0000000000000000000000000000000000000000 --- a/handlers_admission.go +++ /dev/null @@ -1,131 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "regexp" - "time" - - log "github.com/Sirupsen/logrus" - "github.com/gin-gonic/gin" -) - -// -// admissionHandler is responsible checking the access token against the protected resource -// -func (r *oauthProxy) admissionHandler() gin.HandlerFunc { - // step: compile the regex's for the claims - claimMatches := make(map[string]*regexp.Regexp, 0) - for k, v := range r.config.ClaimsMatch { - claimMatches[k] = regexp.MustCompile(v) - } - - return func(cx *gin.Context) { - // step: if authentication is required on this, grab the resource spec - ur, found := cx.Get(cxEnforce) - if !found { - return - } - - // step: grab the identity from the context - uc, found := cx.Get(userContextName) - if !found { - panic("there is no identity in the request context") - } - - resource := ur.(*Resource) - user := uc.(*userContext) - - // step: check the audience for the token is us - if !user.isAudience(r.config.ClientID) { - log.WithFields(log.Fields{ - "username": user.name, - "expired_on": user.expiresAt.String(), - "issued": user.audience, - "clientid": r.config.ClientID, - }).Warnf("the access token audience is not us, redirecting back for authentication") - - r.accessForbidden(cx) - return - } - - // step: we need to check the roles - if roles := len(resource.Roles); roles > 0 { - if !hasRoles(resource.Roles, user.roles) { - log.WithFields(log.Fields{ - "access": "denied", - "username": user.name, - "resource": resource.URL, - "required": resource.GetRoles(), - }).Warnf("access denied, invalid roles") - - r.accessForbidden(cx) - return - } - } - - // step: if we have any claim matching, validate the tokens has the claims - for claimName, match := range claimMatches { - // step: if the claim is NOT in the token, we access deny - value, found, err := user.claims.StringClaim(claimName) - if err != nil { - log.WithFields(log.Fields{ - "access": "denied", - "username": user.name, - "resource": resource.URL, - "error": err.Error(), - }).Errorf("unable to extract the claim from token") - - r.accessForbidden(cx) - return - } - - if !found { - log.WithFields(log.Fields{ - "access": "denied", - "username": user.name, - "resource": resource.URL, - "claim": claimName, - }).Warnf("the token does not have the claim") - - r.accessForbidden(cx) - return - } - - // step: check the claim is the same - if !match.MatchString(value) { - log.WithFields(log.Fields{ - "access": "denied", - "username": user.name, - "resource": resource.URL, - "claim": claimName, - "issued": value, - "required": match, - }).Warnf("the token claims does not match claim requirement") - - r.accessForbidden(cx) - return - } - } - - log.WithFields(log.Fields{ - "access": "permitted", - "username": user.name, - "resource": resource.URL, - "expires": user.expiresAt.Sub(time.Now()).String(), - }).Debugf("resource access permitted: %s", cx.Request.RequestURI) - } -} diff --git a/handlers_admission_test.go b/handlers_admission_test.go deleted file mode 100644 index c5a1aeb401d3cacc4485ec2dd7ea7abea5b790d7..0000000000000000000000000000000000000000 --- a/handlers_admission_test.go +++ /dev/null @@ -1,237 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "net/http" - "strings" - "testing" - - "github.com/coreos/go-oidc/jose" - "github.com/gin-gonic/gin" -) - -func TestAdmissionHandlerRoles(t *testing.T) { - proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ - { - URL: "/admin", - Methods: []string{"ANY"}, - Roles: []string{"admin"}, - }, - { - URL: "/test", - Methods: []string{"GET"}, - Roles: []string{"test"}, - }, - { - URL: "/either", - Methods: []string{"ANY"}, - Roles: []string{"admin", "test"}, - }, - { - URL: "/", - Methods: []string{"ANY"}, - }, - }) - handler := proxy.admissionHandler() - - tests := []struct { - Context *gin.Context - UserContext *userContext - HTTPCode int - }{ - { - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - }, - HTTPCode: http.StatusForbidden, - }, - { - Context: newFakeGinContext("GET", "/admin"), - HTTPCode: http.StatusOK, - UserContext: &userContext{ - audience: "test", - roles: []string{"admin"}, - }, - }, - { - Context: newFakeGinContext("GET", "/test"), - HTTPCode: http.StatusOK, - UserContext: &userContext{ - audience: "test", - roles: []string{"test"}, - }, - }, - { - Context: newFakeGinContext("GET", "/either"), - HTTPCode: http.StatusOK, - UserContext: &userContext{ - audience: "test", - roles: []string{"test", "admin"}, - }, - }, - { - Context: newFakeGinContext("GET", "/either"), - HTTPCode: http.StatusForbidden, - UserContext: &userContext{ - audience: "test", - roles: []string{"no_roles"}, - }, - }, - { - Context: newFakeGinContext("GET", "/"), - HTTPCode: http.StatusOK, - UserContext: &userContext{ - audience: "test", - }, - }, - } - - for i, c := range tests { - // step: find the resource and inject into the context - for _, r := range proxy.config.Resources { - if strings.HasPrefix(c.Context.Request.URL.Path, r.URL) { - c.Context.Set(cxEnforce, r) - break - } - } - if _, found := c.Context.Get(cxEnforce); !found { - t.Errorf("test case %d unable to find a resource for context", i) - continue - } - - c.Context.Set(userContextName, c.UserContext) - - handler(c.Context) - if c.Context.Writer.Status() != c.HTTPCode { - t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) - } - } -} - -func TestAdmissionHandlerClaims(t *testing.T) { - // allow any fake authd users - proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ - { - URL: "/admin", - Methods: []string{"ANY"}, - }, - }) - - tests := []struct { - Matches map[string]string - Context *gin.Context - UserContext *userContext - HTTPCode int - }{ - { - Matches: map[string]string{"iss": "test"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{}, - }, - HTTPCode: http.StatusForbidden, - }, - { - Matches: map[string]string{"iss": "^tes$"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{ - "aud": "test", - "iss": 1, - }, - }, - HTTPCode: http.StatusForbidden, - }, - { - Matches: map[string]string{"iss": "^tes$"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{ - "aud": "test", - "iss": "bad_match", - }, - }, - HTTPCode: http.StatusForbidden, - }, - { - Matches: map[string]string{"iss": "^test", "notfound": "someting"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{ - "aud": "test", - "iss": "test", - }, - }, - HTTPCode: http.StatusForbidden, - }, - { - Matches: map[string]string{"iss": "^test", "notfound": "someting"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{ - "aud": "test", - "iss": "test", - }, - }, - HTTPCode: http.StatusForbidden, - }, - { - Matches: map[string]string{"iss": ".*"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{ - "aud": "test", - "iss": "test", - }, - }, - HTTPCode: http.StatusOK, - }, - { - Matches: map[string]string{"iss": "^t.*$"}, - Context: newFakeGinContext("GET", "/admin"), - UserContext: &userContext{ - audience: "test", - claims: jose.Claims{"iss": "test"}, - }, - HTTPCode: http.StatusOK, - }, - } - - for i, c := range tests { - // step: if closure so we need to get the handler each time - proxy.config.ClaimsMatch = c.Matches - handler := proxy.admissionHandler() - // step: inject a resource - - c.Context.Set(cxEnforce, proxy.config.Resources[0]) - c.Context.Set(userContextName, c.UserContext) - - handler(c.Context) - c.Context.Writer.WriteHeaderNow() - - if c.Context.Writer.Status() != c.HTTPCode { - t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) - } - } -} diff --git a/handlers_auth_test.go b/handlers_auth_test.go deleted file mode 100644 index d087cb87a8f5c1cfc99ffae674f7098462e6438b..0000000000000000000000000000000000000000 --- a/handlers_auth_test.go +++ /dev/null @@ -1,16 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main diff --git a/handlers_entry.go b/handlers_entry.go deleted file mode 100644 index f9e25218e0c2227658394ecd2365d3f3749ed6d1..0000000000000000000000000000000000000000 --- a/handlers_entry.go +++ /dev/null @@ -1,64 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "strings" - - "github.com/gin-gonic/gin" -) - -const ( - // cxEnforce is the tag name for a request requiring - cxEnforce = "Enforcing" -) - -// -// entryPointHandler checks to see if the request requires authentication -// -func (r oauthProxy) entryPointHandler() gin.HandlerFunc { - return func(cx *gin.Context) { - if strings.HasPrefix(cx.Request.URL.Path, oauthURL) { - cx.Next() - return - } - - // step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes - for _, resource := range r.config.Resources { - if strings.HasPrefix(cx.Request.URL.Path, resource.URL) { - if resource.WhiteListed { - break - } - // step: inject the resource into the context, saves us from doing this again - if containedIn("ANY", resource.Methods) || containedIn(cx.Request.Method, resource.Methods) { - cx.Set(cxEnforce, resource) - } - break - } - } - // step: pass into the authentication and admission handlers - cx.Next() - - // step: add a custom headers to the request - for k, v := range r.config.Header { - cx.Request.Header.Set(k, v) - } - // step: check the request has not been aborted and if not, proxy request - if !cx.IsAborted() { - r.proxyHandler(cx) - } - } -} diff --git a/handlers_entry_test.go b/handlers_entry_test.go deleted file mode 100644 index 1a34eeb37e29c18a4562a237939e5238d9885f11..0000000000000000000000000000000000000000 --- a/handlers_entry_test.go +++ /dev/null @@ -1,174 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "testing" - - "github.com/gin-gonic/gin" -) - -func TestEntrypointHandlerSecure(t *testing.T) { - proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ - { - URL: "/admin/white_listed", - WhiteListed: true, - }, - { - URL: "/admin", - Methods: []string{"ANY"}, - }, - { - URL: "/", - Methods: []string{"POST"}, - Roles: []string{"test"}, - }, - }) - - handler := proxy.entryPointHandler() - - tests := []struct { - Context *gin.Context - Secure bool - }{ - {Context: newFakeGinContext("GET", "/")}, - {Context: newFakeGinContext("GET", "/admin"), Secure: true}, - {Context: newFakeGinContext("GET", "/admin/white_listed")}, - {Context: newFakeGinContext("GET", "/admin/white"), Secure: true}, - {Context: newFakeGinContext("GET", "/not_secure")}, - {Context: newFakeGinContext("POST", "/"), Secure: true}, - } - - for i, c := range tests { - handler(c.Context) - _, found := c.Context.Get(cxEnforce) - if c.Secure && !found { - t.Errorf("test case %d should have been set secure", i) - } - if !c.Secure && found { - t.Errorf("test case %d should not have been set secure", i) - } - } -} - -func TestEntrypointMethods(t *testing.T) { - proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ - { - URL: "/u0", - Methods: []string{"GET", "POST"}, - }, - { - URL: "/u1", - Methods: []string{"ANY"}, - }, - { - URL: "/u2", - Methods: []string{"POST", "PUT"}, - }, - }) - - handler := proxy.entryPointHandler() - - tests := []struct { - Context *gin.Context - Secure bool - }{ - {Context: newFakeGinContext("GET", "/u0"), Secure: true}, - {Context: newFakeGinContext("POST", "/u0"), Secure: true}, - {Context: newFakeGinContext("PUT", "/u0"), Secure: false}, - {Context: newFakeGinContext("GET", "/u1"), Secure: true}, - {Context: newFakeGinContext("POST", "/u1"), Secure: true}, - {Context: newFakeGinContext("PATCH", "/u1"), Secure: true}, - {Context: newFakeGinContext("POST", "/u2"), Secure: true}, - {Context: newFakeGinContext("PUT", "/u2"), Secure: true}, - {Context: newFakeGinContext("DELETE", "/u2"), Secure: false}, - } - - for i, c := range tests { - handler(c.Context) - _, found := c.Context.Get(cxEnforce) - if c.Secure && !found { - t.Errorf("test case %d should have been set secure", i) - } - if !c.Secure && found { - t.Errorf("test case %d should not have been set secure", i) - } - } -} - -func TestEntrypointWhiteListing(t *testing.T) { - proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ - { - URL: "/admin/white_listed", - WhiteListed: true, - }, - { - URL: "/admin", - Methods: []string{"ANY"}, - }, - }) - handler := proxy.entryPointHandler() - - tests := []struct { - Context *gin.Context - Secure bool - }{ - {Context: newFakeGinContext("GET", "/")}, - {Context: newFakeGinContext("GET", "/admin"), Secure: true}, - {Context: newFakeGinContext("GET", "/admin/white_listed")}, - } - - for i, c := range tests { - handler(c.Context) - _, found := c.Context.Get(cxEnforce) - if c.Secure && !found { - t.Errorf("test case %d should have been set secure", i) - } - if !c.Secure && found { - t.Errorf("test case %d should not have been set secure", i) - } - } - -} - -func TestEntrypointHandler(t *testing.T) { - proxy := newFakeKeycloakProxy(t) - handler := proxy.entryPointHandler() - - tests := []struct { - Context *gin.Context - Secure bool - }{ - {Context: newFakeGinContext("GET", fakeAdminRoleURL), Secure: true}, - {Context: newFakeGinContext("GET", fakeAdminRoleURL+"/sso"), Secure: true}, - {Context: newFakeGinContext("GET", fakeAdminRoleURL+"/../sso"), Secure: true}, - {Context: newFakeGinContext("GET", "/not_secure")}, - {Context: newFakeGinContext("GET", fakeTestWhitelistedURL)}, - {Context: newFakeGinContext("GET", oauthURL)}, - {Context: newFakeGinContext("GET", fakeTestListenOrdered), Secure: true}, - } - - for i, c := range tests { - handler(c.Context) - _, found := c.Context.Get(cxEnforce) - if c.Secure && !found { - t.Errorf("test case %d should have been set secure", i) - } - if !c.Secure && found { - t.Errorf("test case %d should not have been set secure", i) - } - } -} diff --git a/handlers_utils_test.go b/handlers_test.go similarity index 59% rename from handlers_utils_test.go rename to handlers_test.go index 10abbb3b07065259e642d2fba3448cbe1afced8e..14d124eca7559e3235c9850c2644413aa76e8940 100644 --- a/handlers_utils_test.go +++ b/handlers_test.go @@ -82,80 +82,6 @@ func TestExpirationHandler(t *testing.T) { } } -func TestCrossSiteHandler(t *testing.T) { - kc := newFakeKeycloakProxy(t) - handler := kc.crossSiteHandler() - - cases := []struct { - Cors *CORS - Headers map[string]string - }{ - { - Cors: &CORS{ - Origins: []string{"*"}, - }, - Headers: map[string]string{ - "Access-Control-Allow-Origin": "*", - }, - }, - { - Cors: &CORS{ - Origins: []string{"*", "https://examples.com"}, - Methods: []string{"GET"}, - }, - Headers: map[string]string{ - "Access-Control-Allow-Origin": "*,https://examples.com", - "Access-Control-Allow-Methods": "GET", - }, - }, - } - - for i, c := range cases { - // step: get the config - kc.config.CORS = c.Cors - // call the handler and check the responses - context := newFakeGinContext("GET", "/oauth/test") - handler(context) - // step: check the headers - for k, v := range c.Headers { - value := context.Writer.Header().Get(k) - if value == "" { - t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header()) - continue - } - if value != v { - t.Errorf("case %d, expected: %s but got %s", i, k, value) - } - } - } -} - -func TestSecurityHandler(t *testing.T) { - kc := newFakeKeycloakProxy(t) - handler := kc.securityHandler() - context := newFakeGinContext("GET", "/") - handler(context) - if context.Writer.Status() != http.StatusOK { - t.Errorf("we should have received a 200") - } - - kc = newFakeKeycloakProxy(t) - kc.config.Hostnames = []string{"127.0.0.1"} - handler = kc.securityHandler() - handler(context) - if context.Writer.Status() != http.StatusOK { - t.Errorf("we should have received a 200 not %d", context.Writer.Status()) - } - - kc = newFakeKeycloakProxy(t) - kc.config.Hostnames = []string{"127.0.0.2"} - handler = kc.securityHandler() - handler(context) - if context.Writer.Status() != http.StatusInternalServerError { - t.Errorf("we should have received a 500 not %d", context.Writer.Status()) - } -} - func TestHealthHandler(t *testing.T) { proxy := newFakeKeycloakProxy(t) context := newFakeGinContext("GET", healthURL) diff --git a/handlers_utils.go b/handlers_utils.go deleted file mode 100644 index d62e37f4b618b69fbfde4565606b76008397c4b9..0000000000000000000000000000000000000000 --- a/handlers_utils.go +++ /dev/null @@ -1,189 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "fmt" - "net/http" - "strings" - "time" - - log "github.com/Sirupsen/logrus" - "github.com/gin-gonic/gin" - "github.com/unrolled/secure" -) - -// -// loggingHandler is a custom http logger -// -func (r *oauthProxy) loggingHandler() gin.HandlerFunc { - return func(cx *gin.Context) { - start := time.Now() - cx.Next() - latency := time.Now().Sub(start) - - log.WithFields(log.Fields{ - "client_ip": cx.ClientIP(), - "method": cx.Request.Method, - "status": cx.Writer.Status(), - "bytes": cx.Writer.Size(), - "path": cx.Request.URL.Path, - "latency": latency.String(), - }).Infof("[%d] |%s| |%10v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path) - } -} - -// -// securityHandler performs numerous security checks on the request -// -func (r *oauthProxy) securityHandler() gin.HandlerFunc { - // step: create the security options - secure := secure.New(secure.Options{ - AllowedHosts: r.config.Hostnames, - BrowserXssFilter: true, - ContentTypeNosniff: true, - FrameDeny: true, - STSIncludeSubdomains: true, - STSSeconds: 31536000, - }) - - return func(cx *gin.Context) { - // step: pass through the security middleware - if err := secure.Process(cx.Writer, cx.Request); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed security middleware") - cx.Abort() - return - } - // step: permit the request to continue - cx.Next() - } -} - -// -// crossSiteHandler injects the CORS headers, if set, for request made to /oauth -// -func (r *oauthProxy) crossSiteHandler() gin.HandlerFunc { - return func(cx *gin.Context) { - c := r.config.CORS - if len(c.Origins) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ",")) - } - if len(c.Methods) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ",")) - } - if len(c.Headers) > 0 { - cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ",")) - } - if len(c.ExposedHeaders) > 0 { - cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ",")) - } - if c.Credentials { - cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - } - if c.MaxAge > 0 { - cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds()))) - } - } -} - -// -// proxyHandler is responsible to proxy the requests on to the upstream endpoint -// -func (r *oauthProxy) proxyHandler(cx *gin.Context) { - // step: double check, if enforce is true and no user context it's a internal error - if _, found := cx.Get(cxEnforce); found { - if _, found := cx.Get(userContextName); !found { - log.Errorf("no user context found for a secure request") - cx.AbortWithStatus(http.StatusInternalServerError) - return - } - } - - // step: retrieve the user context if any - if user, found := cx.Get(userContextName); found { - id := user.(*userContext) - cx.Request.Header.Add("X-Auth-UserId", id.id) - cx.Request.Header.Add("X-Auth-Subject", id.preferredName) - cx.Request.Header.Add("X-Auth-Username", id.name) - cx.Request.Header.Add("X-Auth-Email", id.email) - cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) - cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) - cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) - } - - // step: add the default headers - cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr) - cx.Request.Header.Set("X-Forwarded-Agent", prog) - cx.Request.Header.Set("X-Forwarded-Agent-Version", version) - - // step: is this connection upgrading? - if isUpgradedConnection(cx.Request) { - log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) - if err := tryUpdateConnection(cx, r.endpoint); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") - - cx.AbortWithStatus(http.StatusInternalServerError) - return - } - cx.Abort() - - return - } - - r.upstream.ServeHTTP(cx.Writer, cx.Request) -} - -// -// expirationHandler checks if the token has expired -// -func (r *oauthProxy) expirationHandler(cx *gin.Context) { - // step: get the access token from the request - user, err := getIdentity(cx) - if err != nil { - cx.AbortWithError(http.StatusUnauthorized, err) - return - } - // step: check the access is not expired - if user.isExpired() { - cx.AbortWithError(http.StatusUnauthorized, err) - return - } - - cx.AbortWithStatus(http.StatusOK) -} - -// -// tokenHandler display access token to screen -// -func (r *oauthProxy) tokenHandler(cx *gin.Context) { - // step: extract the access token from the request - user, err := getIdentity(cx) - if err != nil { - cx.AbortWithError(http.StatusBadRequest, fmt.Errorf("unable to retrieve session, error: %s", err)) - return - } - - // step: write the json content - cx.Writer.Header().Set("Content-Type", "application/json") - cx.String(http.StatusOK, fmt.Sprintf("%s", user.token.Payload)) -} - -// -// healthHandler is a health check handler for the service -// -func (r *oauthProxy) healthHandler(cx *gin.Context) { - cx.String(http.StatusOK, "OK") -} diff --git a/main.go b/main.go index 988f01c1daf08787ca07e086b27776c7155ada77..e0d47bca993214895977c5d46ab0015f06ce1190 100644 --- a/main.go +++ b/main.go @@ -36,8 +36,8 @@ func main() { kc.Action = func(cx *cli.Context) { // step: do we have a configuration file? if filename := cx.String("config"); filename != "" { - if err := readConfigFile(cx.String("config"), config); err != nil { - printUsage(err.Error()) + if err := readConfigFile(filename, config); err != nil { + printUsage(fmt.Sprintf("unable to read the configuration file: %s, error: %s", filename, err.Error())) } } // step: parse the command line options diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..b020cfd223749830e77d901641cd67680e74aab0 --- /dev/null +++ b/middleware.go @@ -0,0 +1,407 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "regexp" + "strings" + "time" + + log "github.com/Sirupsen/logrus" + "github.com/coreos/go-oidc/jose" + "github.com/gin-gonic/gin" + "github.com/unrolled/secure" +) + +const ( + // cxEnforce is the tag name for a request requiring + cxEnforce = "Enforcing" +) + +// +// loggingHandler is a custom http logger +// +func (r *oauthProxy) loggingHandler() gin.HandlerFunc { + return func(cx *gin.Context) { + start := time.Now() + cx.Next() + latency := time.Now().Sub(start) + + log.WithFields(log.Fields{ + "client_ip": cx.ClientIP(), + "method": cx.Request.Method, + "status": cx.Writer.Status(), + "bytes": cx.Writer.Size(), + "path": cx.Request.URL.Path, + "latency": latency.String(), + }).Infof("[%d] |%s| |%10v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path) + } +} + +// +// entryPointHandler checks to see if the request requires authentication +// +func (r oauthProxy) entryPointHandler() gin.HandlerFunc { + return func(cx *gin.Context) { + if strings.HasPrefix(cx.Request.URL.Path, oauthURL) { + cx.Next() + return + } + + // step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes + for _, resource := range r.config.Resources { + if strings.HasPrefix(cx.Request.URL.Path, resource.URL) { + if resource.WhiteListed { + break + } + // step: inject the resource into the context, saves us from doing this again + if containedIn("ANY", resource.Methods) || containedIn(cx.Request.Method, resource.Methods) { + cx.Set(cxEnforce, resource) + } + break + } + } + // step: pass into the authentication and admission handlers + cx.Next() + + // step: add a custom headers to the request + for k, v := range r.config.Header { + cx.Request.Header.Set(k, v) + } + // step: check the request has not been aborted and if not, proxy request + if !cx.IsAborted() { + r.proxyHandler(cx) + } + } +} + +// +// crossOriginResourceHandler injects the CORS headers, if set, for request made to /oauth +// +func (r *oauthProxy) crossOriginResourceHandler(c CORS) gin.HandlerFunc { + return func(cx *gin.Context) { + if len(c.Origins) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ",")) + } + if len(c.Methods) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ",")) + } + if len(c.Headers) > 0 { + cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ",")) + } + if len(c.ExposedHeaders) > 0 { + cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ",")) + } + if c.Credentials { + cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + } + if c.MaxAge > 0 { + cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds()))) + } + } +} + +// +// securityHandler performs numerous security checks on the request +// +func (r *oauthProxy) securityHandler() gin.HandlerFunc { + // step: create the security options + secure := secure.New(secure.Options{ + AllowedHosts: r.config.Hostnames, + BrowserXssFilter: true, + ContentTypeNosniff: true, + FrameDeny: true, + STSIncludeSubdomains: true, + STSSeconds: 31536000, + }) + + return func(cx *gin.Context) { + // step: pass through the security middleware + if err := secure.Process(cx.Writer, cx.Request); err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed security middleware") + cx.Abort() + return + } + // step: permit the request to continue + cx.Next() + } +} + +// +// authenticationHandler is responsible for verifying the access token +// +func (r *oauthProxy) authenticationHandler() gin.HandlerFunc { + return func(cx *gin.Context) { + // step: is authentication required on this uri? + if _, found := cx.Get(cxEnforce); !found { + log.Debugf("skipping the authentication handler, resource not protected") + cx.Next() + return + } + + // step: grab the user identity from the request + user, err := getIdentity(cx) + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to get session, redirecting for authorization") + + r.redirectToAuthorization(cx) + return + } + + // step: inject the user into the context + cx.Set(userContextName, user) + + // step: verify the access token + if r.config.SkipTokenVerification { + log.Warnf("skip token verification enabled, skipping verification process - FOR TESTING ONLY") + + if user.isExpired() { + log.WithFields(log.Fields{ + "username": user.name, + "expired_on": user.expiresAt.String(), + }).Errorf("the session has expired and verification switch off") + + r.redirectToAuthorization(cx) + } + + return + } + + // step: verify the access token + if err := verifyToken(r.client, user.token); err != nil { + + // step: if the error post verification is anything other than a token expired error + // we immediately throw an access forbidden - as there is something messed up in the token + if err != ErrAccessTokenExpired { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("verification of the access token failed") + + r.accessForbidden(cx) + return + } + + // step: are we refreshing the access tokens? + if !r.config.EnableRefreshTokens { + log.WithFields(log.Fields{ + "email": user.name, + "expired_on": user.expiresAt.String(), + }).Errorf("the session has expired and access token refreshing is disabled") + + r.redirectToAuthorization(cx) + return + } + + // step: we do not refresh bearer token requests + if user.isBearer() { + log.WithFields(log.Fields{ + "email": user.name, + "expired_on": user.expiresAt.String(), + }).Errorf("the session has expired and we are using bearer tokens") + + r.redirectToAuthorization(cx) + return + } + + log.WithFields(log.Fields{ + "email": user.email, + "client_ip": cx.ClientIP(), + }).Infof("the accces token for user: %s has expired, attemping to refresh the token", user.email) + + // step: check if the user has refresh token + rToken, err := r.retrieveRefreshToken(cx, user) + if err != nil { + log.WithFields(log.Fields{ + "email": user.email, + "error": err.Error(), + }).Errorf("unable to find a refresh token for the client: %s", user.email) + + r.redirectToAuthorization(cx) + return + } + + log.WithFields(log.Fields{ + "email": user.email, + }).Infof("found a refresh token, attempting to refresh access token for user: %s", user.email) + + // step: attempts to refresh the access token + token, expires, err := refreshToken(r.client, rToken) + if err != nil { + // step: has the refresh token expired + switch err { + case ErrRefreshTokenExpired: + log.WithFields(log.Fields{"token": token}).Warningf("the refresh token has expired") + clearAllCookies(cx) + default: + log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to refresh the access token") + } + + r.redirectToAuthorization(cx) + return + } + + // step: inject the refreshed access token + log.WithFields(log.Fields{ + "email": user.email, + "access_expires_in": expires.Sub(time.Now()).String(), + }).Infof("injecting refreshed access token, expires on: %s", expires.Format(time.RFC1123)) + + // step: clear the cookie up + dropAccessTokenCookie(cx, token, r.config.IdleDuration) + + if r.useStore() { + go func(t jose.JWT, rt string) { + // step: the access token has been updated, we need to delete old reference and update the store + if err := r.DeleteRefreshToken(t); err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("unable to delete the old refresh tokem from store") + } + + // step: store the new refresh token reference place the session in the store + if err := r.StoreRefreshToken(t, rt); err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to place the refresh token in the store") + + return + } + }(user.token, rToken) + } else { + // step: update the expiration on the refresh token + dropRefreshTokenCookie(cx, rToken, r.config.IdleDuration*2) + } + + // step: update the with the new access token + user.token = token + + // step: inject the user into the context + cx.Set(userContextName, user) + } + + cx.Next() + } +} + +// +// admissionHandler is responsible checking the access token against the protected resource +// +func (r *oauthProxy) admissionHandler() gin.HandlerFunc { + // step: compile the regex's for the claims + claimMatches := make(map[string]*regexp.Regexp, 0) + for k, v := range r.config.ClaimsMatch { + claimMatches[k] = regexp.MustCompile(v) + } + + return func(cx *gin.Context) { + // step: if authentication is required on this, grab the resource spec + ur, found := cx.Get(cxEnforce) + if !found { + return + } + + // step: grab the identity from the context + uc, found := cx.Get(userContextName) + if !found { + panic("there is no identity in the request context") + } + + resource := ur.(*Resource) + user := uc.(*userContext) + + // step: check the audience for the token is us + if !user.isAudience(r.config.ClientID) { + log.WithFields(log.Fields{ + "username": user.name, + "expired_on": user.expiresAt.String(), + "issued": user.audience, + "clientid": r.config.ClientID, + }).Warnf("the access token audience is not us, redirecting back for authentication") + + r.accessForbidden(cx) + return + } + + // step: we need to check the roles + if roles := len(resource.Roles); roles > 0 { + if !hasRoles(resource.Roles, user.roles) { + log.WithFields(log.Fields{ + "access": "denied", + "username": user.name, + "resource": resource.URL, + "required": resource.GetRoles(), + }).Warnf("access denied, invalid roles") + + r.accessForbidden(cx) + return + } + } + + // step: if we have any claim matching, validate the tokens has the claims + for claimName, match := range claimMatches { + // step: if the claim is NOT in the token, we access deny + value, found, err := user.claims.StringClaim(claimName) + if err != nil { + log.WithFields(log.Fields{ + "access": "denied", + "username": user.name, + "resource": resource.URL, + "error": err.Error(), + }).Errorf("unable to extract the claim from token") + + r.accessForbidden(cx) + return + } + + if !found { + log.WithFields(log.Fields{ + "access": "denied", + "username": user.name, + "resource": resource.URL, + "claim": claimName, + }).Warnf("the token does not have the claim") + + r.accessForbidden(cx) + return + } + + // step: check the claim is the same + if !match.MatchString(value) { + log.WithFields(log.Fields{ + "access": "denied", + "username": user.name, + "resource": resource.URL, + "claim": claimName, + "issued": value, + "required": match, + }).Warnf("the token claims does not match claim requirement") + + r.accessForbidden(cx) + return + } + } + + log.WithFields(log.Fields{ + "access": "permitted", + "username": user.name, + "resource": resource.URL, + "expires": user.expiresAt.Sub(time.Now()).String(), + }).Debugf("resource access permitted: %s", cx.Request.RequestURI) + } +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ab23abf1fcbc10f594f5d9b07c955e49ef7c20b3 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,461 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "net/http" + "strings" + "testing" + + "github.com/coreos/go-oidc/jose" + "github.com/gin-gonic/gin" +) + +func TestEntrypointHandlerSecure(t *testing.T) { + proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ + { + URL: "/admin/white_listed", + WhiteListed: true, + }, + { + URL: "/admin", + Methods: []string{"ANY"}, + }, + { + URL: "/", + Methods: []string{"POST"}, + Roles: []string{"test"}, + }, + }) + + handler := proxy.entryPointHandler() + + tests := []struct { + Context *gin.Context + Secure bool + }{ + {Context: newFakeGinContext("GET", "/")}, + {Context: newFakeGinContext("GET", "/admin"), Secure: true}, + {Context: newFakeGinContext("GET", "/admin/white_listed")}, + {Context: newFakeGinContext("GET", "/admin/white"), Secure: true}, + {Context: newFakeGinContext("GET", "/not_secure")}, + {Context: newFakeGinContext("POST", "/"), Secure: true}, + } + + for i, c := range tests { + handler(c.Context) + _, found := c.Context.Get(cxEnforce) + if c.Secure && !found { + t.Errorf("test case %d should have been set secure", i) + } + if !c.Secure && found { + t.Errorf("test case %d should not have been set secure", i) + } + } +} + +func TestEntrypointMethods(t *testing.T) { + proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ + { + URL: "/u0", + Methods: []string{"GET", "POST"}, + }, + { + URL: "/u1", + Methods: []string{"ANY"}, + }, + { + URL: "/u2", + Methods: []string{"POST", "PUT"}, + }, + }) + + handler := proxy.entryPointHandler() + + tests := []struct { + Context *gin.Context + Secure bool + }{ + {Context: newFakeGinContext("GET", "/u0"), Secure: true}, + {Context: newFakeGinContext("POST", "/u0"), Secure: true}, + {Context: newFakeGinContext("PUT", "/u0"), Secure: false}, + {Context: newFakeGinContext("GET", "/u1"), Secure: true}, + {Context: newFakeGinContext("POST", "/u1"), Secure: true}, + {Context: newFakeGinContext("PATCH", "/u1"), Secure: true}, + {Context: newFakeGinContext("POST", "/u2"), Secure: true}, + {Context: newFakeGinContext("PUT", "/u2"), Secure: true}, + {Context: newFakeGinContext("DELETE", "/u2"), Secure: false}, + } + + for i, c := range tests { + handler(c.Context) + _, found := c.Context.Get(cxEnforce) + if c.Secure && !found { + t.Errorf("test case %d should have been set secure", i) + } + if !c.Secure && found { + t.Errorf("test case %d should not have been set secure", i) + } + } +} + +func TestEntrypointWhiteListing(t *testing.T) { + proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ + { + URL: "/admin/white_listed", + WhiteListed: true, + }, + { + URL: "/admin", + Methods: []string{"ANY"}, + }, + }) + handler := proxy.entryPointHandler() + + tests := []struct { + Context *gin.Context + Secure bool + }{ + {Context: newFakeGinContext("GET", "/")}, + {Context: newFakeGinContext("GET", "/admin"), Secure: true}, + {Context: newFakeGinContext("GET", "/admin/white_listed")}, + } + + for i, c := range tests { + handler(c.Context) + _, found := c.Context.Get(cxEnforce) + if c.Secure && !found { + t.Errorf("test case %d should have been set secure", i) + } + if !c.Secure && found { + t.Errorf("test case %d should not have been set secure", i) + } + } + +} + +func TestEntrypointHandler(t *testing.T) { + proxy := newFakeKeycloakProxy(t) + handler := proxy.entryPointHandler() + + tests := []struct { + Context *gin.Context + Secure bool + }{ + {Context: newFakeGinContext("GET", fakeAdminRoleURL), Secure: true}, + {Context: newFakeGinContext("GET", fakeAdminRoleURL+"/sso"), Secure: true}, + {Context: newFakeGinContext("GET", fakeAdminRoleURL+"/../sso"), Secure: true}, + {Context: newFakeGinContext("GET", "/not_secure")}, + {Context: newFakeGinContext("GET", fakeTestWhitelistedURL)}, + {Context: newFakeGinContext("GET", oauthURL)}, + {Context: newFakeGinContext("GET", fakeTestListenOrdered), Secure: true}, + } + + for i, c := range tests { + handler(c.Context) + _, found := c.Context.Get(cxEnforce) + if c.Secure && !found { + t.Errorf("test case %d should have been set secure", i) + } + if !c.Secure && found { + t.Errorf("test case %d should not have been set secure", i) + } + } +} + +func TestSecurityHandler(t *testing.T) { + kc := newFakeKeycloakProxy(t) + handler := kc.securityHandler() + context := newFakeGinContext("GET", "/") + handler(context) + if context.Writer.Status() != http.StatusOK { + t.Errorf("we should have received a 200") + } + + kc = newFakeKeycloakProxy(t) + kc.config.Hostnames = []string{"127.0.0.1"} + handler = kc.securityHandler() + handler(context) + if context.Writer.Status() != http.StatusOK { + t.Errorf("we should have received a 200 not %d", context.Writer.Status()) + } + + kc = newFakeKeycloakProxy(t) + kc.config.Hostnames = []string{"127.0.0.2"} + handler = kc.securityHandler() + handler(context) + if context.Writer.Status() != http.StatusInternalServerError { + t.Errorf("we should have received a 500 not %d", context.Writer.Status()) + } +} + +func TestCrossSiteHandler(t *testing.T) { + proxy := newFakeKeycloakProxy(t) + + cases := []struct { + Cors CORS + Headers map[string]string + }{ + { + Cors: CORS{ + Origins: []string{"*"}, + }, + Headers: map[string]string{ + "Access-Control-Allow-Origin": "*", + }, + }, + { + Cors: CORS{ + Origins: []string{"*", "https://examples.com"}, + Methods: []string{"GET"}, + }, + Headers: map[string]string{ + "Access-Control-Allow-Origin": "*,https://examples.com", + "Access-Control-Allow-Methods": "GET", + }, + }, + } + + for i, c := range cases { + handler := proxy.crossOriginResourceHandler(c.Cors) + // call the handler and check the responses + context := newFakeGinContext("GET", "/oauth/test") + handler(context) + // step: check the headers + for k, v := range c.Headers { + value := context.Writer.Header().Get(k) + if value == "" { + t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header()) + continue + } + if value != v { + t.Errorf("case %d, expected: %s but got %s", i, k, value) + } + } + } +} + +func TestAdmissionHandlerRoles(t *testing.T) { + proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ + { + URL: "/admin", + Methods: []string{"ANY"}, + Roles: []string{"admin"}, + }, + { + URL: "/test", + Methods: []string{"GET"}, + Roles: []string{"test"}, + }, + { + URL: "/either", + Methods: []string{"ANY"}, + Roles: []string{"admin", "test"}, + }, + { + URL: "/", + Methods: []string{"ANY"}, + }, + }) + handler := proxy.admissionHandler() + + tests := []struct { + Context *gin.Context + UserContext *userContext + HTTPCode int + }{ + { + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + }, + HTTPCode: http.StatusForbidden, + }, + { + Context: newFakeGinContext("GET", "/admin"), + HTTPCode: http.StatusOK, + UserContext: &userContext{ + audience: "test", + roles: []string{"admin"}, + }, + }, + { + Context: newFakeGinContext("GET", "/test"), + HTTPCode: http.StatusOK, + UserContext: &userContext{ + audience: "test", + roles: []string{"test"}, + }, + }, + { + Context: newFakeGinContext("GET", "/either"), + HTTPCode: http.StatusOK, + UserContext: &userContext{ + audience: "test", + roles: []string{"test", "admin"}, + }, + }, + { + Context: newFakeGinContext("GET", "/either"), + HTTPCode: http.StatusForbidden, + UserContext: &userContext{ + audience: "test", + roles: []string{"no_roles"}, + }, + }, + { + Context: newFakeGinContext("GET", "/"), + HTTPCode: http.StatusOK, + UserContext: &userContext{ + audience: "test", + }, + }, + } + + for i, c := range tests { + // step: find the resource and inject into the context + for _, r := range proxy.config.Resources { + if strings.HasPrefix(c.Context.Request.URL.Path, r.URL) { + c.Context.Set(cxEnforce, r) + break + } + } + if _, found := c.Context.Get(cxEnforce); !found { + t.Errorf("test case %d unable to find a resource for context", i) + continue + } + + c.Context.Set(userContextName, c.UserContext) + + handler(c.Context) + if c.Context.Writer.Status() != c.HTTPCode { + t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) + } + } +} + +func TestAdmissionHandlerClaims(t *testing.T) { + // allow any fake authd users + proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ + { + URL: "/admin", + Methods: []string{"ANY"}, + }, + }) + + tests := []struct { + Matches map[string]string + Context *gin.Context + UserContext *userContext + HTTPCode int + }{ + { + Matches: map[string]string{"iss": "test"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{}, + }, + HTTPCode: http.StatusForbidden, + }, + { + Matches: map[string]string{"iss": "^tes$"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{ + "aud": "test", + "iss": 1, + }, + }, + HTTPCode: http.StatusForbidden, + }, + { + Matches: map[string]string{"iss": "^tes$"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{ + "aud": "test", + "iss": "bad_match", + }, + }, + HTTPCode: http.StatusForbidden, + }, + { + Matches: map[string]string{"iss": "^test", "notfound": "someting"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{ + "aud": "test", + "iss": "test", + }, + }, + HTTPCode: http.StatusForbidden, + }, + { + Matches: map[string]string{"iss": "^test", "notfound": "someting"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{ + "aud": "test", + "iss": "test", + }, + }, + HTTPCode: http.StatusForbidden, + }, + { + Matches: map[string]string{"iss": ".*"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{ + "aud": "test", + "iss": "test", + }, + }, + HTTPCode: http.StatusOK, + }, + { + Matches: map[string]string{"iss": "^t.*$"}, + Context: newFakeGinContext("GET", "/admin"), + UserContext: &userContext{ + audience: "test", + claims: jose.Claims{"iss": "test"}, + }, + HTTPCode: http.StatusOK, + }, + } + + for i, c := range tests { + // step: if closure so we need to get the handler each time + proxy.config.ClaimsMatch = c.Matches + handler := proxy.admissionHandler() + // step: inject a resource + + c.Context.Set(cxEnforce, proxy.config.Resources[0]) + c.Context.Set(userContextName, c.UserContext) + + handler(c.Context) + c.Context.Writer.WriteHeaderNow() + + if c.Context.Writer.Status() != c.HTTPCode { + t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) + } + } +} diff --git a/oauth_test.go b/oauth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0825256205fd22e70d1cc2e32d88572539be2333 --- /dev/null +++ b/oauth_test.go @@ -0,0 +1,117 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "net/http" + "testing" + + "github.com/gin-gonic/gin" +) + +type fakeOAuthServer struct { +} + +type fakeDiscoveryResponse struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + EndSessionEndpoint string `json:"end_session_endpoint"` + GrantTypesSupported []string `json:"grant_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + Issuer string `json:"issuer"` + JwksURI string `json:"jwks_uri"` + RegistrationEndpoint string `json:"registration_endpoint"` + ResponseModesSupported []string `json:"response_modes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + TokenEndpoint string `json:"token_endpoint"` + TokenIntrospectionEndpoint string `json:"token_introspection_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` +} + +type fakeKeysResponse struct { + Keys []fakeKeyResponse `json:"keys"` +} + +type fakeKeyResponse struct { + Alg string `json:"alg"` + E string `json:"e"` + Kid string `json:"kid"` + Kty string `json:"kty"` + N string `json:"n"` + Use string `json:"use"` +} + +const ( + fakePublicKey = "ibGNjo_opyEGbeDP3cctILhSW-sGKtG67hCZXxvHx-wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5-FMbHth-TKZiEhm-3EBadc1qgkfnpinfpxCVqHHaF8mFLC5-k3JsINIR0FAmPN9trxryI_npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh_rrLKAs0AdUYwXGAslnYDBACiR8GNrb7Q" + + oauthPublicKey = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQAB" + oauthPrivateKey = "MIIEowIBAAKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABAoIBAGtfMlSmMbUKErpiIZX+uFkYgti8p92CGLOF7CN3RU3H+PgfF1m4xHGqt4xw+2JyhgQFgTY4IiIN1QuPFzI82+6jDvMqBwEi2e0TGj4RKiOX9D8b/qSL9eUSfqQKPqnPZfBymM3sqe5yddY7KVZiMXEEBu1efhhTADluIraKQjYJKgQd0P3CgfqhuUWgCqGjPwIg0BkzXofR0bjdrq8d0ul8JLnT+9ho/x8rahEN/LTHHLIwb6IYUj8X10tDZWPDk2NE5wRIy18peSXYNTeGhY1ThF75ZOAH5c1qgi0ObE+dUSqzwcDWqNDPxFvg2x67KbcMaTO6u87/mGJfuO2ekz0CgYEAwpR+tZdafTzR+MLGg55mxsfVjAWGNxp0AMwWZVTpPx1I+VgdLsMkUY8LpY2Zt8l2yInIGEzYRBNFYPrM73bW5v0bleGl60I6j3KA/Ic6RUaweycbQgMxob5PCWrMm94Jib1bGAxNU1m0Jp9rzxGUzWw3TpSw6LHNLqokwMCKG/cCgYEAtSg1oqeCvvCrIdA6AulzzWR6x2Re/Iv8MYJ5X0fNPRBHSVhwsdb2nLfjMPmLesBOPm55O/LZDFtL8unpOUc+qT8QWKAjvI0/HtYf2sec3sP/dxCYYK18grK1cvD/UAUfiljM0gAsxZRT77VbpOIMCOi9YjHoyeRgCQtxB9CuZjsCgYEAsLNfehLvpwmjeK+QzRf9J4l0AQtHPiU0sUClGfKJOrqieWUuYzftdG9d2UMFFGTNDQIqhv7J6tBBUfeQQep+8BdshKj9Hu7u9TO7tRgsr5qpS71QwJrb6JFFfzzQgL+bk800u1r4obe1pNljcxD5O6+JbkATg81rknQKmkx/XzMCgYARnyqwesjuF+0dqeqqs9jO5vJGiQ3wVRGgI0f5K7vcL8Qvb0nvErEEh6Ky9eNKeoBh9E8YtMPGPu9BXt2P8801m2vUoyc2xSqZrkyE9Jve04P7KgMYjGerMwURfD3po8XwqDisSNYSFh6gF60ledOf3jvl3GL/mJZ66sEA+JyuVwKBgGwef1FWkDTeft6VFo2obHCh8Fc8rsV2rQ0twgmA00nmuckKr5MQgyMiz2YYWanmOS18xLgl7FzvyX56clj1MvRl9xnwhSudtE4fxg6R4rzwf3jaWtAkXEHet+mqVRJgI9m5Bn8E7nVVmjgRlogZsgYq2pF3nL1sgl3ti7gOVVL6" + oauthCertificate = "MIICnzCCAYcCBgFUPZAJhjANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDDAhob2QtdGVzdDAeFw0xNjA0MjIxMDQwMzBaFw0yNjA0MjIxMDQyMTBaMBMxETAPBgNVBAMMCGhvZC10ZXN0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBFd9T/1s769tGhOMtUspP2tChKy5OWF50HkRVLny1nt12JeQUvuVSD3l7vN17hFRpMm1ktjVCTxBk5PRfPtpOcMCG2zgYbB73hIRYZaKG5X6/r2y3TllZ2UkZh0ndL+jrn1L4I2zxB5OAi3CDTxiFtjcEShAC9smjp04Omxwat53k8IxJLRgnpuC/TMbxUPHLNjuOHLLFeSN7095SuD+qzx0H7fT4sqW3+mAr7Q/kl2yq4vMXfLHt5KkOm7O5px5mRoGS4Asbkw5MQMgP618uQ9k7EQZx37jF2ol4Z7uLQWscePdWA66ajbxAtybCesNPa4uUrb1YVdx6MikWyZ0i7" +) + +func newFakeOAuthServer(t *testing.T) { + s := new(fakeOAuthServer) + r := gin.New() + r.GET("/auth/realms/hod-test/.well-known/openid-configuration", s.discoveryHandler) + r.GET("/auth/realms/hod-test/protocol/openid-connect/certs", s.keysHandler) + r.POST("/auth/realms/hod-test/protocol/openid-connect/token", s.tokenHandler) + r.POST("/auth/realms/hod-test/protocol/openid-connect/auth", s.authHandler) + + if err := r.Run("127.0.0.1:8080"); err != nil { + t.Fatalf("failed to start the fake oauth service, error: %s", err) + } +} + +func (r fakeOAuthServer) discoveryHandler(cx *gin.Context) { + cx.JSON(http.StatusOK, fakeDiscoveryResponse{ + IDTokenSigningAlgValuesSupported: []string{"RS256"}, + Issuer: "http://127.0.0.1:8080/auth/realms/hod-test", + AuthorizationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/auth", + TokenEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token", + RegistrationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/clients-registrations/openid-connect", + TokenIntrospectionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token/introspect", + UserinfoEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/userinfo", + EndSessionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/logout", + JwksURI: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/certs", + GrantTypesSupported: []string{"authorization_code", "implicit", "refresh_token", "password", "client_credentials"}, + ResponseModesSupported: []string{"query", "fragment", "form_post"}, + ResponseTypesSupported: []string{"code", "none", "id_token", "token", "id_token token", "code id_token", "code token", "code id_token token"}, + SubjectTypesSupported: []string{"public"}, + }) +} + +func (r fakeOAuthServer) keysHandler(cx *gin.Context) { + cx.JSON(http.StatusOK, fakeKeysResponse{ + Keys: []fakeKeyResponse{ + { + Kid: "ing3Hnuj0ciqrHCOxt__-B53jzXcdD1n1iKbX3GsD9s", + Kty: "RSA", + Alg: "RS256", + Use: "sig", + N: fakePublicKey, + E: "AQAB", + }, + }, + }) +} + +func (r fakeOAuthServer) authHandler(cx *gin.Context) { + +} + +func (r fakeOAuthServer) tokenHandler(cx *gin.Context) { + +} diff --git a/server.go b/server.go index 9395dd901623aba2c43099c45d1f3dfacdd63f46..57743135bf4d9cd456e11e4f9b2e29f129473419 100644 --- a/server.go +++ b/server.go @@ -50,7 +50,7 @@ type oauthProxy struct { // the upstream endpoint url endpoint *url.URL // the store interface - store Store + store storage } type reverseProxy interface { @@ -75,7 +75,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) { log.SetLevel(log.DebugLevel) } - log.Infof("starting %s, version: %s, author: %s", prog, version, author) + log.Infof("starting %s, author: %s, version: %s, ", prog, author, version) service := &oauthProxy{config: cfg} @@ -87,7 +87,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) { // step: initialize the store if any if cfg.StoreURL != "" { - if service.store, err = newStore(cfg.StoreURL); err != nil { + if service.store, err = newStorage(cfg.StoreURL); err != nil { return nil, err } } @@ -223,17 +223,33 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { // setupReverseProxy create a reverse http proxy from the upstream // func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) { + // step: create the default dialer + dialer := (&net.Dialer{ + KeepAlive: 10 * time.Second, + Timeout: 10 * time.Second, + }).Dial + + // step: are we using a unix socket? + if upstream.Scheme == "unix" { + log.Infof("using the unix domain socket: %s for upstream", upstream.Host) + socketPath := upstream.Host + dialer = func(network, address string) (net.Conn, error) { + return net.Dial("unix", socketPath) + } + upstream.Path = "" + upstream.Host = "domain-sock" + upstream.Scheme = "http" + } + // step: create the reverse proxy proxy := httputil.NewSingleHostReverseProxy(upstream) + + // step: customize the http transport proxy.Transport = &http.Transport{ - Dial: (&net.Dialer{ - KeepAlive: 10 * time.Second, - Timeout: 10 * time.Second, - }).Dial, + Dial: dialer, TLSClientConfig: &tls.Config{ InsecureSkipVerify: r.config.SkipUpstreamTLSVerify, }, - DisableKeepAlives: !r.config.Keepalives, - TLSHandshakeTimeout: 10 * time.Second, + DisableKeepAlives: !r.config.Keepalives, } return proxy, nil @@ -253,7 +269,7 @@ func (r oauthProxy) setupRouter() error { r.router.Use(r.securityHandler()) } // step: add the routing - oauth := r.router.Group(oauthURL).Use(r.crossSiteHandler()) + oauth := r.router.Group(oauthURL).Use(r.crossOriginResourceHandler(r.config.CrossOrigin)) { oauth.GET(authorizationURL, r.oauthAuthorizationHandler) oauth.GET(callbackURL, r.oauthCallbackHandler) diff --git a/server_test.go b/server_test.go index 56364000423f2f4acec2f3c88ea2d81d190b0ee4..78c5d58769d4d75d9e30bd913688e58b64dd1f26 100644 --- a/server_test.go +++ b/server_test.go @@ -94,7 +94,7 @@ func newFakeKeycloakConfig(t *testing.T) *Config { Roles: []string{}, }, }, - CORS: &CORS{}, + CrossOrigin: CORS{}, } } @@ -212,6 +212,15 @@ func newFakeGinContext(method, uri string) *gin.Context { } } +func newFakeGinContextWithCookies(method, url string, cookies []*http.Cookie) *gin.Context { + cx := newFakeGinContext(method, url) + for _, x := range cookies { + cx.Request.AddCookie(x) + } + + return cx +} + func newFakeJWTToken(t *testing.T, claims jose.Claims) *jose.JWT { token, err := jose.NewJWT( jose.JOSEHeader{"alg": "RS256"}, claims, diff --git a/session_test.go b/session_test.go index 77ba019769173ac15eef6c8ba87b639e5efee6ac..92aa02c731c9fd45b244ed512d7ef5528cb4f43c 100644 --- a/session_test.go +++ b/session_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" ) func TestGetSessionToken(t *testing.T) { @@ -65,3 +66,51 @@ func TestGetSessionToken(t *testing.T) { } } } + +func TestGetRefreshTokenFromCookie(t *testing.T) { + cases := []struct { + Cookies []*http.Cookie + Expected string + Ok bool + }{ + { + Cookies: []*http.Cookie{}, + }, + { + Cookies: []*http.Cookie{ + { + Name: "not_a_session_cookie", + Path: "/", + Domain: "127.0.0.1", + }, + }, + }, + { + Cookies: []*http.Cookie{ + { + Name: cookieRefreshToken, + Path: "/", + Domain: "127.0.0.1", + Value: "refresh_token", + }, + }, + Expected: "refresh_token", + Ok: true, + }, + } + + for i, x := range cases { + context := newFakeGinContextWithCookies("GET", "/", x.Cookies) + + token, err := getRefreshTokenFromCookie(context) + if err != nil && x.Ok { + t.Errorf("case %d, should not have thrown an error: %s, headers: %v", i, err, context.Writer.Header()) + continue + } + if err == nil && !x.Ok { + t.Errorf("case %d, should have thrown an error", i) + continue + } + assert.Equal(t, x.Expected, token, "case %d, expected token: %v, got: %v", x.Expected, token) + } +} diff --git a/store_boltdb.go b/store_boltdb.go index d3333bcda89476dfb495a71a3fbe4ce1ca5bbc96..a6fbf3fd53fa626ec07bf4b68295819bac3f9c9e 100644 --- a/store_boltdb.go +++ b/store_boltdb.go @@ -41,7 +41,7 @@ type boltdbStore struct { client *bolt.DB } -func newBoltDBStore(location *url.URL) (Store, error) { +func newBoltDBStore(location *url.URL) (storage, error) { // step: drop the initial slash path := strings.TrimPrefix(location.Path, "/") diff --git a/store_redis.go b/store_redis.go index 650998af0cef954042190521a81579fbf876dd1f..aeb92f621e1e6b9a68ca1aeb39daa17b21d7aaf6 100644 --- a/store_redis.go +++ b/store_redis.go @@ -28,16 +28,20 @@ type redisStore struct { } // newRedisStore creates a new redis store -func newRedisStore(location *url.URL) (Store, error) { +func newRedisStore(location *url.URL) (storage, error) { log.Infof("creating a redis client for store: %s", location.Host) + // step: get any password + password := "" + if location.User != nil { + password, _ = location.User.Password() + } + // step: parse the url notation client := redis.NewClient(&redis.Options{ - Addr: location.Host, - DB: 0, - DialTimeout: 10 * time.Second, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + Addr: location.Host, + DB: 0, + Password: password, }) return redisStore{ diff --git a/stores.go b/stores.go index 815575c5b8498190df08ee92e2fab680b9f9993a..616253ae0498bfba78ab4827d5be1127ccb45fab 100644 --- a/stores.go +++ b/stores.go @@ -20,9 +20,9 @@ import ( "net/url" ) -// newStore creates the store client for use -func newStore(location string) (Store, error) { - var store Store +// newStorage creates the store client for use +func newStorage(location string) (storage, error) { + var store storage var err error u, err := url.Parse(location)