Skip to content
Snippets Groups Projects
Commit ab17cf61 authored by Rohith's avatar Rohith
Browse files

- cleaning and reducing the amount of files, perfer to use just handlers and middleware (#53)

- change config option in json/yaml of clientid to client-id
- added additional unit test for the sessions and cookie methods
- added support for password for redis
- added the ability to proxy to a unix socket
- cleaned up the handlers and middleware into respective files, i dont like lots of files laying around
parent 61c6fe90
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,9 @@ bin/ ...@@ -4,6 +4,9 @@ bin/
release/ release/
cover.html cover.html
cover.out cover.out
tests/db.bolt
test.sock
tests/redis.conf
*.iml *.iml
config.yml config.yml
......
#### **1.0.3 (April 30th, 2016)**
FIXES:
* Fixes the cookie sessions expiraton
FEATURES:
* Adding a idle duration configuration option which controls the expiration of access token cookie and thus session.
If the session is not used within that period, the session is removed.
* The upstream endpoint has also be a unix socket
BREAKING CHANGES:
* Change the client id in json/yaml config file from clientid -> client-id
#### **1.0.2 (April 22th, 2016)** #### **1.0.2 (April 22th, 2016)**
FIXES: FIXES:
......
NAME=keycloak-proxy NAME=keycloak-proxy
AUTHOR=gambol99 AUTHOR=gambol99
HARDWARE=$(shell uname -m)
REGISTRY=docker.io REGISTRY=docker.io
GOVERSION=1.6.0 GOVERSION=1.6.0
SUDO=sudo SUDO=sudo
GIT_COMMIT=$(shell git log --pretty=format:'%h' -n 1)
ROOT_DIR=${PWD} ROOT_DIR=${PWD}
HARDWARE=$(shell uname -m)
GIT_SHA=$(shell git --no-pager describe --tags --always --dirty)
VERSION=$(shell awk '/version.*=/ { print $$3 }' doc.go | sed 's/"//g') VERSION=$(shell awk '/version.*=/ { print $$3 }' doc.go | sed 's/"//g')
DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...) DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...)
PACKAGES=$(shell go list ./...) PACKAGES=$(shell go list ./...)
...@@ -20,12 +20,15 @@ golang: ...@@ -20,12 +20,15 @@ golang:
@echo "--> Go Version" @echo "--> Go Version"
@go version @go version
build: version:
@sed -i "s/const gitSHA =.*/const gitSHA = \"${GIT_SHA}\"/" doc.go
build: version
@echo "--> Compiling the project" @echo "--> Compiling the project"
mkdir -p bin mkdir -p bin
godep go build -o bin/${NAME} godep go build -o bin/${NAME}
static: golang deps static: version golang deps
@echo "--> Compiling the static binary" @echo "--> Compiling the static binary"
mkdir -p bin mkdir -p bin
CGO_ENABLED=0 GOOS=linux godep go build -a -tags netgo -ldflags '-w' -o bin/${NAME} CGO_ENABLED=0 GOOS=linux godep go build -a -tags netgo -ldflags '-w' -o bin/${NAME}
......
...@@ -37,7 +37,7 @@ func newDefaultConfig() *Config { ...@@ -37,7 +37,7 @@ func newDefaultConfig() *Config {
TagData: make(map[string]string, 0), TagData: make(map[string]string, 0),
ClaimsMatch: make(map[string]string, 0), ClaimsMatch: make(map[string]string, 0),
Header: make(map[string]string, 0), Header: make(map[string]string, 0),
CORS: &CORS{}, CrossOrigin: CORS{},
SkipUpstreamTLSVerify: true, SkipUpstreamTLSVerify: true,
} }
} }
...@@ -155,12 +155,18 @@ func readOptions(cx *cli.Context, config *Config) (err error) { ...@@ -155,12 +155,18 @@ func readOptions(cx *cli.Context, config *Config) (err error) {
if cx.IsSet("upstream-keepalives") { if cx.IsSet("upstream-keepalives") {
config.Keepalives = cx.Bool("upstream-keepalives") config.Keepalives = cx.Bool("upstream-keepalives")
} }
if cx.IsSet("idle-duration") {
config.IdleDuration = cx.Duration("idle-duration")
}
if cx.IsSet("skip-token-verification") { if cx.IsSet("skip-token-verification") {
config.SkipTokenVerification = cx.Bool("skip-token-verification") config.SkipTokenVerification = cx.Bool("skip-token-verification")
} }
if cx.IsSet("skip-upstream-tls-verify") { if cx.IsSet("skip-upstream-tls-verify") {
config.SkipUpstreamTLSVerify = cx.Bool("skip-upstream-tls-verify") config.SkipUpstreamTLSVerify = cx.Bool("skip-upstream-tls-verify")
} }
if cx.IsSet("enable-refresh-tokens") {
config.EnableRefreshTokens = cx.Bool("enable-refresh-tokens")
}
if cx.IsSet("encryption-key") { if cx.IsSet("encryption-key") {
config.EncryptionKey = cx.String("encryption-key") config.EncryptionKey = cx.String("encryption-key")
} }
...@@ -210,22 +216,22 @@ func readOptions(cx *cli.Context, config *Config) (err error) { ...@@ -210,22 +216,22 @@ func readOptions(cx *cli.Context, config *Config) (err error) {
config.Hostnames = cx.StringSlice("hostname") config.Hostnames = cx.StringSlice("hostname")
} }
if cx.IsSet("cors-origins") { if cx.IsSet("cors-origins") {
config.CORS.Origins = cx.StringSlice("cors-origins") config.CrossOrigin.Origins = cx.StringSlice("cors-origins")
} }
if cx.IsSet("cors-methods") { if cx.IsSet("cors-methods") {
config.CORS.Methods = cx.StringSlice("cors-methods") config.CrossOrigin.Methods = cx.StringSlice("cors-methods")
} }
if cx.IsSet("cors-headers") { if cx.IsSet("cors-headers") {
config.CORS.Headers = cx.StringSlice("cors-headers") config.CrossOrigin.Headers = cx.StringSlice("cors-headers")
} }
if cx.IsSet("cors-exposed-headers") { if cx.IsSet("cors-exposed-headers") {
config.CORS.ExposedHeaders = cx.StringSlice("cors-exposed-headers") config.CrossOrigin.ExposedHeaders = cx.StringSlice("cors-exposed-headers")
} }
if cx.IsSet("cors-max-age") { if cx.IsSet("cors-max-age") {
config.CORS.MaxAge = cx.Duration("cors-max-age") config.CrossOrigin.MaxAge = cx.Duration("cors-max-age")
} }
if cx.IsSet("cors-credentials") { if cx.IsSet("cors-credentials") {
config.CORS.Credentials = cx.BoolT("cors-credentials") config.CrossOrigin.Credentials = cx.BoolT("cors-credentials")
} }
if cx.IsSet("tag") { if cx.IsSet("tag") {
config.TagData, err = decodeKeyPairs(cx.StringSlice("tag")) config.TagData, err = decodeKeyPairs(cx.StringSlice("tag"))
...@@ -302,6 +308,10 @@ func getOptions() []cli.Flag { ...@@ -302,6 +308,10 @@ func getOptions() []cli.Flag {
Name: "discovery-url", Name: "discovery-url",
Usage: "the discovery url to retrieve the openid configuration", Usage: "the discovery url to retrieve the openid configuration",
}, },
cli.DurationFlag{
Name: "idle-duration",
Usage: "the expiration of the access token cookie, if not used within this time its removed",
},
cli.StringFlag{ cli.StringFlag{
Name: "upstream-url", Name: "upstream-url",
Usage: "the url for the upstream endpoint you wish to proxy to", Usage: "the url for the upstream endpoint you wish to proxy to",
...@@ -309,12 +319,16 @@ func getOptions() []cli.Flag { ...@@ -309,12 +319,16 @@ func getOptions() []cli.Flag {
}, },
cli.StringFlag{ cli.StringFlag{
Name: "revocation-url", Name: "revocation-url",
Usage: "the url for the revocation endpoint to revoke refresh token, not all providers support the revocation_endpoint", Usage: "the url for the revocation endpoint to revoke refresh token",
Value: "/oauth2/revoke", Value: "/oauth2/revoke",
}, },
cli.BoolTFlag{ cli.BoolTFlag{
Name: "upstream-keepalives", Name: "upstream-keepalives",
Usage: "enables or disables the keepalive connections for upstream endpoint (defaults true)", Usage: "enables or disables the keepalive connections for upstream endpoint",
},
cli.BoolFlag{
Name: "enable-refresh-tokens",
Usage: "enables the handling of the refresh tokens",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "encryption-key", Name: "encryption-key",
...@@ -322,15 +336,15 @@ func getOptions() []cli.Flag { ...@@ -322,15 +336,15 @@ func getOptions() []cli.Flag {
}, },
cli.StringFlag{ cli.StringFlag{
Name: "store-url", Name: "store-url",
Usage: "the store url to use for storing the refresh tokens, i.e. redis://127.0.0.1:6379, file:///etc/tokens.file", Usage: "url for the storage subsystem, e.g redis://127.0.0.1:6379, file:///etc/tokens.file",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "no-redirects", Name: "no-redirects",
Usage: "do not have back redirects when no authentication is present, simple reply with 401 code", Usage: "do not have back redirects when no authentication is present, 401 them",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "redirection-url", Name: "redirection-url",
Usage: fmt.Sprintf("the redirection url, namely the site url, note: %s will be added to it", oauthURL), Usage: fmt.Sprintf("redirection url for the oauth callback url (%s is added)", oauthURL),
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "hostname", Name: "hostname",
...@@ -358,7 +372,7 @@ func getOptions() []cli.Flag { ...@@ -358,7 +372,7 @@ func getOptions() []cli.Flag {
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "claim", Name: "claim",
Usage: "a series of key pair values which must match the claims in the token present e.g. aud=myapp, iss=http://example.com etcd", Usage: "keypair values for matching access token claims e.g. aud=myapp, iss=http://example.*",
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "resource", Name: "resource",
...@@ -374,11 +388,11 @@ func getOptions() []cli.Flag { ...@@ -374,11 +388,11 @@ func getOptions() []cli.Flag {
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "tag", Name: "tag",
Usage: "a keypair tag which is passed to the templates when render, i.e. title='My Page',site='my name' etc", Usage: "keypair's passed to the templates at render,e.g title='My Page'",
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "cors-origins", Name: "cors-origins",
Usage: "a set of origins to add to the CORS access control (Access-Control-Allow-Origin)", Usage: "list of origins to add to the CORE origins control (Access-Control-Allow-Origin)",
}, },
cli.StringSliceFlag{ cli.StringSliceFlag{
Name: "cors-methods", Name: "cors-methods",
...@@ -406,11 +420,7 @@ func getOptions() []cli.Flag { ...@@ -406,11 +420,7 @@ func getOptions() []cli.Flag {
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "skip-token-verification", Name: "skip-token-verification",
Usage: "testing purposes ONLY, the option allows you to bypass the token verification, expiration and roles are still enforced", Usage: "TESTING ONLY; bypass's token verification, expiration and roles enforced",
},
cli.BoolFlag{
Name: "proxy-protocol",
Usage: "switches on proxy protocol support on the listen (not supported yet)",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "offline-session", Name: "offline-session",
......
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
# is the url for retrieve the openid configuration - normally the <server>/auth/realm/<realm_name> # is the url for retrieve the openid configuration - normally the <server>/auth/realm/<realm_name>
discovery-url: https://keycloak.example.com/auth/realms/commons discovery-url: https://keycloak.example.com/auth/realms/commons
# the client id for the 'client' application # the client id for the 'client' application
clientid: <CLIENT_ID> client-id: <CLIENT_ID>
# the secret associated to the 'client' application # the secret associated to the 'client' application
client-secret: <CLIENT_SECRET> client-secret: <CLIENT_SECRET>
# the interface definition you wish the proxy to listen, all interfaces is specified as ':<port>' # the interface definition you wish the proxy to listen, all interfaces is specified as ':<port>'
listen: 127.0.0.1:3000 listen: 127.0.0.1:3000
# whether to request offline access and use a refresh token # whether to request offline access and use a refresh token
enable-refresh-tokens: true enable-refresh-tokens: true
# the max amount of time a session can stay alive without being used
idle-duration: 24h
# log all incoming requests # log all incoming requests
log-requests: true log-requests: true
# log in json format # log in json format
...@@ -31,8 +33,7 @@ upstream: http://127.0.0.1:80 ...@@ -31,8 +33,7 @@ upstream: http://127.0.0.1:80
# upstream-keepalives specified wheather you want keepalive on the upstream endpoint # upstream-keepalives specified wheather you want keepalive on the upstream endpoint
upstream-keepalives: true upstream-keepalives: true
# additional scopes to add to add to the default (openid+email+profile) # additional scopes to add to add to the default (openid+email+profile)
scopes: scopes: []
- vpn-user
# enables a more extra secuirty features # enables a more extra secuirty features
enable-security-filter: true enable-security-filter: true
# a map of claims that MUST exist in the token presented and the value is it MUST match # a map of claims that MUST exist in the token presented and the value is it MUST match
......
...@@ -37,14 +37,14 @@ func TestReadConfiguration(t *testing.T) { ...@@ -37,14 +37,14 @@ func TestReadConfiguration(t *testing.T) {
{ {
Content: ` Content: `
discovery_url: https://keyclock.domain.com/ discovery_url: https://keyclock.domain.com/
clientid: <client_id> client-id: <client_id>
secret: <secret> secret: <secret>
`, `,
}, },
{ {
Content: ` Content: `
discovery_url: https://keyclock.domain.com discovery_url: https://keyclock.domain.com
clientid: <client_id> client-id: <client_id>
secret: <secret> secret: <secret>
upstream: http://127.0.0.1:8080 upstream: http://127.0.0.1:8080
redirection_url: http://127.0.0.1:3000 redirection_url: http://127.0.0.1:3000
......
...@@ -27,17 +27,16 @@ import ( ...@@ -27,17 +27,16 @@ import (
// //
// dropCookie drops a cookie into the response // dropCookie drops a cookie into the response
// //
func dropCookie(cx *gin.Context, name, value string, expires time.Time) { func dropCookie(cx *gin.Context, name, value string, duration time.Duration) {
cookie := &http.Cookie{ cookie := &http.Cookie{
Name: name, Name: name,
Domain: strings.Split(cx.Request.Host, ":")[0], Domain: strings.Split(cx.Request.Host, ":")[0],
Path: "/", Path: "/",
HttpOnly: true,
Secure: true, Secure: true,
Value: value, Value: value,
} }
if !expires.IsZero() { if duration != 0 {
cookie.Expires = expires cookie.Expires = time.Now().Add(duration)
} }
http.SetCookie(cx.Writer, cookie) http.SetCookie(cx.Writer, cookie)
...@@ -46,15 +45,15 @@ func dropCookie(cx *gin.Context, name, value string, expires time.Time) { ...@@ -46,15 +45,15 @@ func dropCookie(cx *gin.Context, name, value string, expires time.Time) {
// //
// dropAccessTokenCookie drops a access token cookie into the response // dropAccessTokenCookie drops a access token cookie into the response
// //
func dropAccessTokenCookie(cx *gin.Context, token jose.JWT) { func dropAccessTokenCookie(cx *gin.Context, token jose.JWT, duration time.Duration) {
dropCookie(cx, cookieAccessToken, token.Encode(), time.Time{}) dropCookie(cx, cookieAccessToken, token.Encode(), duration)
} }
// //
// dropRefreshTokenCookie drops a refresh token cookie into the response // dropRefreshTokenCookie drops a refresh token cookie into the response
// //
func dropRefreshTokenCookie(cx *gin.Context, token string, expires time.Time) { func dropRefreshTokenCookie(cx *gin.Context, token string, duration time.Duration) {
dropCookie(cx, cookieRefreshToken, token, expires) dropCookie(cx, cookieRefreshToken, token, duration)
} }
// //
...@@ -69,12 +68,12 @@ func clearAllCookies(cx *gin.Context) { ...@@ -69,12 +68,12 @@ func clearAllCookies(cx *gin.Context) {
// clearRefreshSessionCookie clears the session cookie // clearRefreshSessionCookie clears the session cookie
// //
func clearRefreshTokenCookie(cx *gin.Context) { func clearRefreshTokenCookie(cx *gin.Context) {
dropCookie(cx, cookieRefreshToken, "", time.Now().Add(-1*time.Hour)) dropCookie(cx, cookieRefreshToken, "", time.Duration(-10*time.Hour))
} }
// //
// clearAccessTokenCookie clears the session cookie // clearAccessTokenCookie clears the session cookie
// //
func clearAccessTokenCookie(cx *gin.Context) { func clearAccessTokenCookie(cx *gin.Context) {
dropCookie(cx, cookieAccessToken, "", time.Now().Add(-1*time.Hour)) dropCookie(cx, cookieAccessToken, "", time.Duration(-10*time.Hour))
} }
...@@ -14,3 +14,48 @@ limitations under the License. ...@@ -14,3 +14,48 @@ limitations under the License.
*/ */
package main package main
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDropCookie(t *testing.T) {
context := newFakeGinContext("GET", "/admin")
dropCookie(context, "test-cookie", "test-value", 0)
assert.Equal(t, context.Writer.Header().Get("Set-Cookie"),
"test-cookie=test-value; Path=/; Domain=127.0.0.1; Secure",
"we have not set the cookie, headers: %v", context.Writer.Header())
context = newFakeGinContext("GET", "/admin")
dropCookie(context, "test-cookie", "test-value", 0)
assert.NotEqual(t, context.Writer.Header().Get("Set-Cookie"),
"test-cookie=test-value; Path=/; Domain=127.0.0.2; HttpOnly; Secure",
"we have not set the cookie, headers: %v", context.Writer.Header())
}
func TestClearAccessTokenCookie(t *testing.T) {
context := newFakeGinContext("GET", "/admin")
clearAccessTokenCookie(context)
assert.Contains(t, context.Writer.Header().Get("Set-Cookie"),
"kc-access=; Path=/; Domain=127.0.0.1; Expires=",
"we have not cleared the, headers: %v", context.Writer.Header())
}
func TestClearRefreshAccessTokenCookie(t *testing.T) {
context := newFakeGinContext("GET", "/admin")
clearRefreshTokenCookie(context)
assert.Contains(t, context.Writer.Header().Get("Set-Cookie"),
"kc-state=; Path=/; Domain=127.0.0.1; Expires=",
"we have not cleared the, headers: %v", context.Writer.Header())
}
func TestClearAllCookies(t *testing.T) {
context := newFakeGinContext("GET", "/admin")
clearAllCookies(context)
assert.Contains(t, context.Writer.Header().Get("Set-Cookie"),
"kc-access=; Path=/; Domain=127.0.0.1; Expires=",
"we have not cleared the, headers: %v", context.Writer.Header())
}
...@@ -20,9 +20,11 @@ import ( ...@@ -20,9 +20,11 @@ import (
"time" "time"
) )
const gitSHA = "v1.0.3-2-g0082034-dirty"
const ( const (
prog = "keycloak-proxy" prog = "keycloak-proxy"
version = "v1.0.3" version = "v1.0.3" + " (git+sha: " + gitSHA + ")"
author = "Rohith" author = "Rohith"
email = "gambol99@gmail.com" email = "gambol99@gmail.com"
description = "is a proxy using the keycloak service for auth and authorization" description = "is a proxy using the keycloak service for auth and authorization"
...@@ -101,7 +103,7 @@ type Config struct { ...@@ -101,7 +103,7 @@ type Config struct {
// DiscoveryURL is the url for the keycloak server // DiscoveryURL is the url for the keycloak server
DiscoveryURL string `json:"discovery-url" yaml:"discovery-url"` DiscoveryURL string `json:"discovery-url" yaml:"discovery-url"`
// ClientID is the client id // ClientID is the client id
ClientID string `json:"clientid" yaml:"clientid"` ClientID string `json:"client-id" yaml:"client-id"`
// ClientSecret is the secret for AS // ClientSecret is the secret for AS
ClientSecret string `json:"client-secret" yaml:"client-secret"` ClientSecret string `json:"client-secret" yaml:"client-secret"`
// RevocationEndpoint is the token revocation endpoint to revoke refresh tokens // RevocationEndpoint is the token revocation endpoint to revoke refresh tokens
...@@ -114,6 +116,8 @@ type Config struct { ...@@ -114,6 +116,8 @@ type Config struct {
EnableSecurityFilter bool `json:"enable-security-filter" yaml:"enable-security-filter"` EnableSecurityFilter bool `json:"enable-security-filter" yaml:"enable-security-filter"`
// EnableRefreshTokens indicate's you wish to ignore using refresh tokens and re-auth on expireation of access token // EnableRefreshTokens indicate's you wish to ignore using refresh tokens and re-auth on expireation of access token
EnableRefreshTokens bool `json:"enable-refresh-tokens" yaml:"enable-refresh-tokens"` EnableRefreshTokens bool `json:"enable-refresh-tokens" yaml:"enable-refresh-tokens"`
// IdleDuration is the max amount of time a session can last without being used
IdleDuration time.Duration `json:"idle-duration" yaml:"idle-duration"`
// EncryptionKey is the encryption key used to encrypt the refresh token // EncryptionKey is the encryption key used to encrypt the refresh token
EncryptionKey string `json:"encryption-key" yaml:"encryption-key"` EncryptionKey string `json:"encryption-key" yaml:"encryption-key"`
// ClaimsMatch is a series of checks, the claims in the token must match those here // ClaimsMatch is a series of checks, the claims in the token must match those here
...@@ -136,8 +140,8 @@ type Config struct { ...@@ -136,8 +140,8 @@ type Config struct {
Upstream string `json:"upstream" yaml:"upstream"` Upstream string `json:"upstream" yaml:"upstream"`
// TagData is passed to the templates // TagData is passed to the templates
TagData map[string]string `json:"tag-data" yaml:"tag-data"` TagData map[string]string `json:"tag-data" yaml:"tag-data"`
// CORS permits adding headers to the /oauth handlers // CrossOrigin permits adding headers to the /oauth handlers
CORS *CORS `json:"cors" yaml:"cors"` CrossOrigin CORS `json:"cors" yaml:"cors"`
// Header permits adding customs headers across the board // Header permits adding customs headers across the board
Header map[string]string `json:"headers" yaml:"headers"` Header map[string]string `json:"headers" yaml:"headers"`
// Scopes is a list of scope we should request // Scopes is a list of scope we should request
...@@ -149,7 +153,7 @@ type Config struct { ...@@ -149,7 +153,7 @@ type Config struct {
// ForbiddenPage is a access forbidden page // ForbiddenPage is a access forbidden page
ForbiddenPage string `json:"forbidden-page" yaml:"forbidden-page"` ForbiddenPage string `json:"forbidden-page" yaml:"forbidden-page"`
// SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes
SkipTokenVerification bool SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification"`
// Verbose switches on debug logging // Verbose switches on debug logging
Verbose bool `json:"verbose" yaml:"verbose"` Verbose bool `json:"verbose" yaml:"verbose"`
// Hostname is a list of hostname's the service should response to // Hostname is a list of hostname's the service should response to
...@@ -158,9 +162,9 @@ type Config struct { ...@@ -158,9 +162,9 @@ type Config struct {
StoreURL string `json:"store-url" yaml:"store-url"` StoreURL string `json:"store-url" yaml:"store-url"`
} }
// Store is used to hold the offline refresh token, assuming you don't want to use // store is used to hold the offline refresh token, assuming you don't want to use
// the default practice of a encrypted cookie // the default practice of a encrypted cookie
type Store interface { type storage interface {
// Add the token to the store // Add the token to the store
Set(string, string) error Set(string, string) error
// Get retrieves a token from the store // Get retrieves a token from the store
......
...@@ -21,191 +21,13 @@ import ( ...@@ -21,191 +21,13 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strings"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/coreos/go-oidc/jose"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
//
// authenticationHandler is responsible for verifying the access token
//
func (r *oauthProxy) authenticationHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
// step: is authentication required on this uri?
if _, found := cx.Get(cxEnforce); !found {
log.Debugf("skipping the authentication handler, resource not protected")
cx.Next()
return
}
// step: grab the user identity from the request
user, err := getIdentity(cx)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("failed to get session, redirecting for authorization")
r.redirectToAuthorization(cx)
return
}
// step: inject the user into the context
cx.Set(userContextName, user)
// step: verify the access token
if r.config.SkipTokenVerification {
log.Warnf("skip token verification enabled, skipping verification process - FOR TESTING ONLY")
if user.isExpired() {
log.WithFields(log.Fields{
"username": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and verification switch off")
r.redirectToAuthorization(cx)
}
return
}
// step: verify the access token
if err := verifyToken(r.client, user.token); err != nil {
// step: if the error post verification is anything other than a token expired error
// we immediately throw an access forbidden - as there is something messed up in the token
if err != ErrAccessTokenExpired {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("verification of the access token failed")
r.accessForbidden(cx)
return
}
// step: are we refreshing the access tokens?
if !r.config.EnableRefreshTokens {
log.WithFields(log.Fields{
"email": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and access token refreshing is disabled")
r.redirectToAuthorization(cx)
return
}
// step: we do not refresh bearer token requests
if user.isBearer() {
log.WithFields(log.Fields{
"email": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and we are using bearer tokens")
r.redirectToAuthorization(cx)
return
}
log.WithFields(log.Fields{
"email": user.email,
"client_ip": cx.ClientIP(),
}).Infof("the accces token for user: %s has expired, attemping to refresh the token", user.email)
// step: check if the user has refresh token
rToken, err := r.retrieveRefreshToken(cx, user)
if err != nil {
log.WithFields(log.Fields{
"email": user.email,
"error": err.Error(),
}).Errorf("unable to find a refresh token for the client: %s", user.email)
r.redirectToAuthorization(cx)
return
}
log.WithFields(log.Fields{
"email": user.email,
}).Infof("found a refresh token, attempting to refresh access token for user: %s", user.email)
// step: attempts to refresh the access token
token, expires, err := refreshToken(r.client, rToken)
if err != nil {
// step: has the refresh token expired
switch err {
case ErrRefreshTokenExpired:
log.WithFields(log.Fields{"token": token}).Warningf("the refresh token has expired")
clearAllCookies(cx)
default:
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to refresh the access token")
}
r.redirectToAuthorization(cx)
return
}
// step: inject the refreshed access token
log.WithFields(log.Fields{
"email": user.email,
"access_expires_in": expires.Sub(time.Now()).String(),
}).Infof("injecting refreshed access token, expires on: %s", expires.Format(time.RFC1123))
// step: clear the cookie up
dropAccessTokenCookie(cx, token)
if r.useStore() {
go func(t jose.JWT, rt string) {
// step: the access token has been updated, we need to delete old reference and update the store
if err := r.DeleteRefreshToken(t); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to delete the old refresh tokem from store")
}
// step: store the new refresh token reference place the session in the store
if err := r.StoreRefreshToken(t, rt); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("failed to place the refresh token in the store")
return
}
}(user.token, rToken)
}
// step: update the with the new access token
user.token = token
// step: inject the user into the context
cx.Set(userContextName, user)
}
cx.Next()
}
}
//
// retrieveRefreshToken retrieves the refresh token from store or c
//
func (r oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) {
var token string
var err error
// step: get the refresh token from the store or cookie
switch r.useStore() {
case true:
token, err = r.GetRefreshToken(user.token)
default:
token, err = getRefreshTokenFromCookie(cx)
}
// step: decode the cookie
if err != nil {
return token, err
}
return decodeText(token, r.config.EncryptionKey)
}
// //
// oauthAuthorizationHandler is responsible for performing the redirection to oauth provider // oauthAuthorizationHandler is responsible for performing the redirection to oauth provider
// //
...@@ -325,10 +147,11 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { ...@@ -325,10 +147,11 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) {
"email": identity.Email, "email": identity.Email,
"expires": identity.ExpiresAt.Format(time.RFC822Z), "expires": identity.ExpiresAt.Format(time.RFC822Z),
"duration": identity.ExpiresAt.Sub(time.Now()).String(), "duration": identity.ExpiresAt.Sub(time.Now()).String(),
"idle": r.config.IdleDuration.String(),
}).Infof("issuing a new access token for user, email: %s", identity.Email) }).Infof("issuing a new access token for user, email: %s", identity.Email)
// step: drop's a session cookie with the access token // step: drop's a session cookie with the access token
dropAccessTokenCookie(cx, session) dropAccessTokenCookie(cx, session, r.config.IdleDuration)
// step: does the response has a refresh token and we are NOT ignore refresh tokens? // step: does the response has a refresh token and we are NOT ignore refresh tokens?
if r.config.EnableRefreshTokens && response.RefreshToken != "" { if r.config.EnableRefreshTokens && response.RefreshToken != "" {
...@@ -352,7 +175,7 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { ...@@ -352,7 +175,7 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) {
}).Warnf("failed to save the refresh token in the store") }).Warnf("failed to save the refresh token in the store")
} }
default: default:
dropRefreshTokenCookie(cx, encrypted, time.Time{}) dropRefreshTokenCookie(cx, encrypted, r.config.IdleDuration*2)
} }
} }
...@@ -501,3 +324,115 @@ func (r oauthProxy) logoutHandler(cx *gin.Context) { ...@@ -501,3 +324,115 @@ func (r oauthProxy) logoutHandler(cx *gin.Context) {
cx.AbortWithStatus(http.StatusOK) cx.AbortWithStatus(http.StatusOK)
} }
//
// proxyHandler is responsible to proxy the requests on to the upstream endpoint
//
func (r *oauthProxy) proxyHandler(cx *gin.Context) {
// step: double check, if enforce is true and no user context it's a internal error
if _, found := cx.Get(cxEnforce); found {
if _, found := cx.Get(userContextName); !found {
log.Errorf("no user context found for a secure request")
cx.AbortWithStatus(http.StatusInternalServerError)
return
}
}
// step: retrieve the user context if any
if user, found := cx.Get(userContextName); found {
id := user.(*userContext)
cx.Request.Header.Add("X-Auth-UserId", id.id)
cx.Request.Header.Add("X-Auth-Subject", id.preferredName)
cx.Request.Header.Add("X-Auth-Username", id.name)
cx.Request.Header.Add("X-Auth-Email", id.email)
cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String())
cx.Request.Header.Add("X-Auth-Token", id.token.Encode())
cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ","))
}
// step: add the default headers
cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr)
cx.Request.Header.Set("X-Forwarded-Agent", prog)
cx.Request.Header.Set("X-Forwarded-Agent-Version", version)
// step: is this connection upgrading?
if isUpgradedConnection(cx.Request) {
log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade))
if err := tryUpdateConnection(cx, r.endpoint); err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection")
cx.AbortWithStatus(http.StatusInternalServerError)
return
}
cx.Abort()
return
}
r.upstream.ServeHTTP(cx.Writer, cx.Request)
}
//
// expirationHandler checks if the token has expired
//
func (r *oauthProxy) expirationHandler(cx *gin.Context) {
// step: get the access token from the request
user, err := getIdentity(cx)
if err != nil {
cx.AbortWithError(http.StatusUnauthorized, err)
return
}
// step: check the access is not expired
if user.isExpired() {
cx.AbortWithError(http.StatusUnauthorized, err)
return
}
cx.AbortWithStatus(http.StatusOK)
}
//
// tokenHandler display access token to screen
//
func (r *oauthProxy) tokenHandler(cx *gin.Context) {
// step: extract the access token from the request
user, err := getIdentity(cx)
if err != nil {
cx.AbortWithError(http.StatusBadRequest, fmt.Errorf("unable to retrieve session, error: %s", err))
return
}
// step: write the json content
cx.Writer.Header().Set("Content-Type", "application/json")
cx.String(http.StatusOK, fmt.Sprintf("%s", user.token.Payload))
}
//
// healthHandler is a health check handler for the service
//
func (r *oauthProxy) healthHandler(cx *gin.Context) {
cx.String(http.StatusOK, "OK")
}
//
// retrieveRefreshToken retrieves the refresh token from store or c
//
func (r oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) {
var token string
var err error
// step: get the refresh token from the store or cookie
switch r.useStore() {
case true:
token, err = r.GetRefreshToken(user.token)
default:
token, err = getRefreshTokenFromCookie(cx)
}
// step: decode the cookie
if err != nil {
return token, err
}
return decodeText(token, r.config.EncryptionKey)
}
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"strings"
"github.com/gin-gonic/gin"
)
const (
// cxEnforce is the tag name for a request requiring
cxEnforce = "Enforcing"
)
//
// entryPointHandler checks to see if the request requires authentication
//
func (r oauthProxy) entryPointHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
if strings.HasPrefix(cx.Request.URL.Path, oauthURL) {
cx.Next()
return
}
// step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes
for _, resource := range r.config.Resources {
if strings.HasPrefix(cx.Request.URL.Path, resource.URL) {
if resource.WhiteListed {
break
}
// step: inject the resource into the context, saves us from doing this again
if containedIn("ANY", resource.Methods) || containedIn(cx.Request.Method, resource.Methods) {
cx.Set(cxEnforce, resource)
}
break
}
}
// step: pass into the authentication and admission handlers
cx.Next()
// step: add a custom headers to the request
for k, v := range r.config.Header {
cx.Request.Header.Set(k, v)
}
// step: check the request has not been aborted and if not, proxy request
if !cx.IsAborted() {
r.proxyHandler(cx)
}
}
}
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"testing"
"github.com/gin-gonic/gin"
)
func TestEntrypointHandlerSecure(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/admin/white_listed",
WhiteListed: true,
},
{
URL: "/admin",
Methods: []string{"ANY"},
},
{
URL: "/",
Methods: []string{"POST"},
Roles: []string{"test"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/")},
{Context: newFakeGinContext("GET", "/admin"), Secure: true},
{Context: newFakeGinContext("GET", "/admin/white_listed")},
{Context: newFakeGinContext("GET", "/admin/white"), Secure: true},
{Context: newFakeGinContext("GET", "/not_secure")},
{Context: newFakeGinContext("POST", "/"), Secure: true},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointMethods(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/u0",
Methods: []string{"GET", "POST"},
},
{
URL: "/u1",
Methods: []string{"ANY"},
},
{
URL: "/u2",
Methods: []string{"POST", "PUT"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/u0"), Secure: true},
{Context: newFakeGinContext("POST", "/u0"), Secure: true},
{Context: newFakeGinContext("PUT", "/u0"), Secure: false},
{Context: newFakeGinContext("GET", "/u1"), Secure: true},
{Context: newFakeGinContext("POST", "/u1"), Secure: true},
{Context: newFakeGinContext("PATCH", "/u1"), Secure: true},
{Context: newFakeGinContext("POST", "/u2"), Secure: true},
{Context: newFakeGinContext("PUT", "/u2"), Secure: true},
{Context: newFakeGinContext("DELETE", "/u2"), Secure: false},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointWhiteListing(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/admin/white_listed",
WhiteListed: true,
},
{
URL: "/admin",
Methods: []string{"ANY"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/")},
{Context: newFakeGinContext("GET", "/admin"), Secure: true},
{Context: newFakeGinContext("GET", "/admin/white_listed")},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointHandler(t *testing.T) {
proxy := newFakeKeycloakProxy(t)
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", fakeAdminRoleURL), Secure: true},
{Context: newFakeGinContext("GET", fakeAdminRoleURL+"/sso"), Secure: true},
{Context: newFakeGinContext("GET", fakeAdminRoleURL+"/../sso"), Secure: true},
{Context: newFakeGinContext("GET", "/not_secure")},
{Context: newFakeGinContext("GET", fakeTestWhitelistedURL)},
{Context: newFakeGinContext("GET", oauthURL)},
{Context: newFakeGinContext("GET", fakeTestListenOrdered), Secure: true},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
...@@ -82,80 +82,6 @@ func TestExpirationHandler(t *testing.T) { ...@@ -82,80 +82,6 @@ func TestExpirationHandler(t *testing.T) {
} }
} }
func TestCrossSiteHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.crossSiteHandler()
cases := []struct {
Cors *CORS
Headers map[string]string
}{
{
Cors: &CORS{
Origins: []string{"*"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*",
},
},
{
Cors: &CORS{
Origins: []string{"*", "https://examples.com"},
Methods: []string{"GET"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*,https://examples.com",
"Access-Control-Allow-Methods": "GET",
},
},
}
for i, c := range cases {
// step: get the config
kc.config.CORS = c.Cors
// call the handler and check the responses
context := newFakeGinContext("GET", "/oauth/test")
handler(context)
// step: check the headers
for k, v := range c.Headers {
value := context.Writer.Header().Get(k)
if value == "" {
t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header())
continue
}
if value != v {
t.Errorf("case %d, expected: %s but got %s", i, k, value)
}
}
}
}
func TestSecurityHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.securityHandler()
context := newFakeGinContext("GET", "/")
handler(context)
if context.Writer.Status() != http.StatusOK {
t.Errorf("we should have received a 200")
}
kc = newFakeKeycloakProxy(t)
kc.config.Hostnames = []string{"127.0.0.1"}
handler = kc.securityHandler()
handler(context)
if context.Writer.Status() != http.StatusOK {
t.Errorf("we should have received a 200 not %d", context.Writer.Status())
}
kc = newFakeKeycloakProxy(t)
kc.config.Hostnames = []string{"127.0.0.2"}
handler = kc.securityHandler()
handler(context)
if context.Writer.Status() != http.StatusInternalServerError {
t.Errorf("we should have received a 500 not %d", context.Writer.Status())
}
}
func TestHealthHandler(t *testing.T) { func TestHealthHandler(t *testing.T) {
proxy := newFakeKeycloakProxy(t) proxy := newFakeKeycloakProxy(t)
context := newFakeGinContext("GET", healthURL) context := newFakeGinContext("GET", healthURL)
......
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"fmt"
"net/http"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/gin-gonic/gin"
"github.com/unrolled/secure"
)
//
// loggingHandler is a custom http logger
//
func (r *oauthProxy) loggingHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
start := time.Now()
cx.Next()
latency := time.Now().Sub(start)
log.WithFields(log.Fields{
"client_ip": cx.ClientIP(),
"method": cx.Request.Method,
"status": cx.Writer.Status(),
"bytes": cx.Writer.Size(),
"path": cx.Request.URL.Path,
"latency": latency.String(),
}).Infof("[%d] |%s| |%10v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path)
}
}
//
// securityHandler performs numerous security checks on the request
//
func (r *oauthProxy) securityHandler() gin.HandlerFunc {
// step: create the security options
secure := secure.New(secure.Options{
AllowedHosts: r.config.Hostnames,
BrowserXssFilter: true,
ContentTypeNosniff: true,
FrameDeny: true,
STSIncludeSubdomains: true,
STSSeconds: 31536000,
})
return func(cx *gin.Context) {
// step: pass through the security middleware
if err := secure.Process(cx.Writer, cx.Request); err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed security middleware")
cx.Abort()
return
}
// step: permit the request to continue
cx.Next()
}
}
//
// crossSiteHandler injects the CORS headers, if set, for request made to /oauth
//
func (r *oauthProxy) crossSiteHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
c := r.config.CORS
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
}
}
//
// proxyHandler is responsible to proxy the requests on to the upstream endpoint
//
func (r *oauthProxy) proxyHandler(cx *gin.Context) {
// step: double check, if enforce is true and no user context it's a internal error
if _, found := cx.Get(cxEnforce); found {
if _, found := cx.Get(userContextName); !found {
log.Errorf("no user context found for a secure request")
cx.AbortWithStatus(http.StatusInternalServerError)
return
}
}
// step: retrieve the user context if any
if user, found := cx.Get(userContextName); found {
id := user.(*userContext)
cx.Request.Header.Add("X-Auth-UserId", id.id)
cx.Request.Header.Add("X-Auth-Subject", id.preferredName)
cx.Request.Header.Add("X-Auth-Username", id.name)
cx.Request.Header.Add("X-Auth-Email", id.email)
cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String())
cx.Request.Header.Add("X-Auth-Token", id.token.Encode())
cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ","))
}
// step: add the default headers
cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr)
cx.Request.Header.Set("X-Forwarded-Agent", prog)
cx.Request.Header.Set("X-Forwarded-Agent-Version", version)
// step: is this connection upgrading?
if isUpgradedConnection(cx.Request) {
log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade))
if err := tryUpdateConnection(cx, r.endpoint); err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection")
cx.AbortWithStatus(http.StatusInternalServerError)
return
}
cx.Abort()
return
}
r.upstream.ServeHTTP(cx.Writer, cx.Request)
}
//
// expirationHandler checks if the token has expired
//
func (r *oauthProxy) expirationHandler(cx *gin.Context) {
// step: get the access token from the request
user, err := getIdentity(cx)
if err != nil {
cx.AbortWithError(http.StatusUnauthorized, err)
return
}
// step: check the access is not expired
if user.isExpired() {
cx.AbortWithError(http.StatusUnauthorized, err)
return
}
cx.AbortWithStatus(http.StatusOK)
}
//
// tokenHandler display access token to screen
//
func (r *oauthProxy) tokenHandler(cx *gin.Context) {
// step: extract the access token from the request
user, err := getIdentity(cx)
if err != nil {
cx.AbortWithError(http.StatusBadRequest, fmt.Errorf("unable to retrieve session, error: %s", err))
return
}
// step: write the json content
cx.Writer.Header().Set("Content-Type", "application/json")
cx.String(http.StatusOK, fmt.Sprintf("%s", user.token.Payload))
}
//
// healthHandler is a health check handler for the service
//
func (r *oauthProxy) healthHandler(cx *gin.Context) {
cx.String(http.StatusOK, "OK")
}
...@@ -36,8 +36,8 @@ func main() { ...@@ -36,8 +36,8 @@ func main() {
kc.Action = func(cx *cli.Context) { kc.Action = func(cx *cli.Context) {
// step: do we have a configuration file? // step: do we have a configuration file?
if filename := cx.String("config"); filename != "" { if filename := cx.String("config"); filename != "" {
if err := readConfigFile(cx.String("config"), config); err != nil { if err := readConfigFile(filename, config); err != nil {
printUsage(err.Error()) printUsage(fmt.Sprintf("unable to read the configuration file: %s, error: %s", filename, err.Error()))
} }
} }
// step: parse the command line options // step: parse the command line options
......
...@@ -16,13 +16,289 @@ limitations under the License. ...@@ -16,13 +16,289 @@ limitations under the License.
package main package main
import ( import (
"fmt"
"regexp" "regexp"
"strings"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/coreos/go-oidc/jose"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/unrolled/secure"
) )
const (
// cxEnforce is the tag name for a request requiring
cxEnforce = "Enforcing"
)
//
// loggingHandler is a custom http logger
//
func (r *oauthProxy) loggingHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
start := time.Now()
cx.Next()
latency := time.Now().Sub(start)
log.WithFields(log.Fields{
"client_ip": cx.ClientIP(),
"method": cx.Request.Method,
"status": cx.Writer.Status(),
"bytes": cx.Writer.Size(),
"path": cx.Request.URL.Path,
"latency": latency.String(),
}).Infof("[%d] |%s| |%10v| %-5s %s", cx.Writer.Status(), cx.ClientIP(), latency, cx.Request.Method, cx.Request.URL.Path)
}
}
//
// entryPointHandler checks to see if the request requires authentication
//
func (r oauthProxy) entryPointHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
if strings.HasPrefix(cx.Request.URL.Path, oauthURL) {
cx.Next()
return
}
// step: check if authentication is required - gin doesn't support wildcard url, so we have have to use prefixes
for _, resource := range r.config.Resources {
if strings.HasPrefix(cx.Request.URL.Path, resource.URL) {
if resource.WhiteListed {
break
}
// step: inject the resource into the context, saves us from doing this again
if containedIn("ANY", resource.Methods) || containedIn(cx.Request.Method, resource.Methods) {
cx.Set(cxEnforce, resource)
}
break
}
}
// step: pass into the authentication and admission handlers
cx.Next()
// step: add a custom headers to the request
for k, v := range r.config.Header {
cx.Request.Header.Set(k, v)
}
// step: check the request has not been aborted and if not, proxy request
if !cx.IsAborted() {
r.proxyHandler(cx)
}
}
}
//
// crossOriginResourceHandler injects the CORS headers, if set, for request made to /oauth
//
func (r *oauthProxy) crossOriginResourceHandler(c CORS) gin.HandlerFunc {
return func(cx *gin.Context) {
if len(c.Origins) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Origin", strings.Join(c.Origins, ","))
}
if len(c.Methods) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(c.Methods, ","))
}
if len(c.Headers) > 0 {
cx.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(c.Headers, ","))
}
if len(c.ExposedHeaders) > 0 {
cx.Writer.Header().Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
}
if c.Credentials {
cx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if c.MaxAge > 0 {
cx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", int(c.MaxAge.Seconds())))
}
}
}
//
// securityHandler performs numerous security checks on the request
//
func (r *oauthProxy) securityHandler() gin.HandlerFunc {
// step: create the security options
secure := secure.New(secure.Options{
AllowedHosts: r.config.Hostnames,
BrowserXssFilter: true,
ContentTypeNosniff: true,
FrameDeny: true,
STSIncludeSubdomains: true,
STSSeconds: 31536000,
})
return func(cx *gin.Context) {
// step: pass through the security middleware
if err := secure.Process(cx.Writer, cx.Request); err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed security middleware")
cx.Abort()
return
}
// step: permit the request to continue
cx.Next()
}
}
//
// authenticationHandler is responsible for verifying the access token
//
func (r *oauthProxy) authenticationHandler() gin.HandlerFunc {
return func(cx *gin.Context) {
// step: is authentication required on this uri?
if _, found := cx.Get(cxEnforce); !found {
log.Debugf("skipping the authentication handler, resource not protected")
cx.Next()
return
}
// step: grab the user identity from the request
user, err := getIdentity(cx)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("failed to get session, redirecting for authorization")
r.redirectToAuthorization(cx)
return
}
// step: inject the user into the context
cx.Set(userContextName, user)
// step: verify the access token
if r.config.SkipTokenVerification {
log.Warnf("skip token verification enabled, skipping verification process - FOR TESTING ONLY")
if user.isExpired() {
log.WithFields(log.Fields{
"username": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and verification switch off")
r.redirectToAuthorization(cx)
}
return
}
// step: verify the access token
if err := verifyToken(r.client, user.token); err != nil {
// step: if the error post verification is anything other than a token expired error
// we immediately throw an access forbidden - as there is something messed up in the token
if err != ErrAccessTokenExpired {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("verification of the access token failed")
r.accessForbidden(cx)
return
}
// step: are we refreshing the access tokens?
if !r.config.EnableRefreshTokens {
log.WithFields(log.Fields{
"email": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and access token refreshing is disabled")
r.redirectToAuthorization(cx)
return
}
// step: we do not refresh bearer token requests
if user.isBearer() {
log.WithFields(log.Fields{
"email": user.name,
"expired_on": user.expiresAt.String(),
}).Errorf("the session has expired and we are using bearer tokens")
r.redirectToAuthorization(cx)
return
}
log.WithFields(log.Fields{
"email": user.email,
"client_ip": cx.ClientIP(),
}).Infof("the accces token for user: %s has expired, attemping to refresh the token", user.email)
// step: check if the user has refresh token
rToken, err := r.retrieveRefreshToken(cx, user)
if err != nil {
log.WithFields(log.Fields{
"email": user.email,
"error": err.Error(),
}).Errorf("unable to find a refresh token for the client: %s", user.email)
r.redirectToAuthorization(cx)
return
}
log.WithFields(log.Fields{
"email": user.email,
}).Infof("found a refresh token, attempting to refresh access token for user: %s", user.email)
// step: attempts to refresh the access token
token, expires, err := refreshToken(r.client, rToken)
if err != nil {
// step: has the refresh token expired
switch err {
case ErrRefreshTokenExpired:
log.WithFields(log.Fields{"token": token}).Warningf("the refresh token has expired")
clearAllCookies(cx)
default:
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to refresh the access token")
}
r.redirectToAuthorization(cx)
return
}
// step: inject the refreshed access token
log.WithFields(log.Fields{
"email": user.email,
"access_expires_in": expires.Sub(time.Now()).String(),
}).Infof("injecting refreshed access token, expires on: %s", expires.Format(time.RFC1123))
// step: clear the cookie up
dropAccessTokenCookie(cx, token, r.config.IdleDuration)
if r.useStore() {
go func(t jose.JWT, rt string) {
// step: the access token has been updated, we need to delete old reference and update the store
if err := r.DeleteRefreshToken(t); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to delete the old refresh tokem from store")
}
// step: store the new refresh token reference place the session in the store
if err := r.StoreRefreshToken(t, rt); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("failed to place the refresh token in the store")
return
}
}(user.token, rToken)
} else {
// step: update the expiration on the refresh token
dropRefreshTokenCookie(cx, rToken, r.config.IdleDuration*2)
}
// step: update the with the new access token
user.token = token
// step: inject the user into the context
cx.Set(userContextName, user)
}
cx.Next()
}
}
// //
// admissionHandler is responsible checking the access token against the protected resource // admissionHandler is responsible checking the access token against the protected resource
// //
......
...@@ -24,6 +24,230 @@ import ( ...@@ -24,6 +24,230 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func TestEntrypointHandlerSecure(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/admin/white_listed",
WhiteListed: true,
},
{
URL: "/admin",
Methods: []string{"ANY"},
},
{
URL: "/",
Methods: []string{"POST"},
Roles: []string{"test"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/")},
{Context: newFakeGinContext("GET", "/admin"), Secure: true},
{Context: newFakeGinContext("GET", "/admin/white_listed")},
{Context: newFakeGinContext("GET", "/admin/white"), Secure: true},
{Context: newFakeGinContext("GET", "/not_secure")},
{Context: newFakeGinContext("POST", "/"), Secure: true},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointMethods(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/u0",
Methods: []string{"GET", "POST"},
},
{
URL: "/u1",
Methods: []string{"ANY"},
},
{
URL: "/u2",
Methods: []string{"POST", "PUT"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/u0"), Secure: true},
{Context: newFakeGinContext("POST", "/u0"), Secure: true},
{Context: newFakeGinContext("PUT", "/u0"), Secure: false},
{Context: newFakeGinContext("GET", "/u1"), Secure: true},
{Context: newFakeGinContext("POST", "/u1"), Secure: true},
{Context: newFakeGinContext("PATCH", "/u1"), Secure: true},
{Context: newFakeGinContext("POST", "/u2"), Secure: true},
{Context: newFakeGinContext("PUT", "/u2"), Secure: true},
{Context: newFakeGinContext("DELETE", "/u2"), Secure: false},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointWhiteListing(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{
URL: "/admin/white_listed",
WhiteListed: true,
},
{
URL: "/admin",
Methods: []string{"ANY"},
},
})
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", "/")},
{Context: newFakeGinContext("GET", "/admin"), Secure: true},
{Context: newFakeGinContext("GET", "/admin/white_listed")},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestEntrypointHandler(t *testing.T) {
proxy := newFakeKeycloakProxy(t)
handler := proxy.entryPointHandler()
tests := []struct {
Context *gin.Context
Secure bool
}{
{Context: newFakeGinContext("GET", fakeAdminRoleURL), Secure: true},
{Context: newFakeGinContext("GET", fakeAdminRoleURL+"/sso"), Secure: true},
{Context: newFakeGinContext("GET", fakeAdminRoleURL+"/../sso"), Secure: true},
{Context: newFakeGinContext("GET", "/not_secure")},
{Context: newFakeGinContext("GET", fakeTestWhitelistedURL)},
{Context: newFakeGinContext("GET", oauthURL)},
{Context: newFakeGinContext("GET", fakeTestListenOrdered), Secure: true},
}
for i, c := range tests {
handler(c.Context)
_, found := c.Context.Get(cxEnforce)
if c.Secure && !found {
t.Errorf("test case %d should have been set secure", i)
}
if !c.Secure && found {
t.Errorf("test case %d should not have been set secure", i)
}
}
}
func TestSecurityHandler(t *testing.T) {
kc := newFakeKeycloakProxy(t)
handler := kc.securityHandler()
context := newFakeGinContext("GET", "/")
handler(context)
if context.Writer.Status() != http.StatusOK {
t.Errorf("we should have received a 200")
}
kc = newFakeKeycloakProxy(t)
kc.config.Hostnames = []string{"127.0.0.1"}
handler = kc.securityHandler()
handler(context)
if context.Writer.Status() != http.StatusOK {
t.Errorf("we should have received a 200 not %d", context.Writer.Status())
}
kc = newFakeKeycloakProxy(t)
kc.config.Hostnames = []string{"127.0.0.2"}
handler = kc.securityHandler()
handler(context)
if context.Writer.Status() != http.StatusInternalServerError {
t.Errorf("we should have received a 500 not %d", context.Writer.Status())
}
}
func TestCrossSiteHandler(t *testing.T) {
proxy := newFakeKeycloakProxy(t)
cases := []struct {
Cors CORS
Headers map[string]string
}{
{
Cors: CORS{
Origins: []string{"*"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*",
},
},
{
Cors: CORS{
Origins: []string{"*", "https://examples.com"},
Methods: []string{"GET"},
},
Headers: map[string]string{
"Access-Control-Allow-Origin": "*,https://examples.com",
"Access-Control-Allow-Methods": "GET",
},
},
}
for i, c := range cases {
handler := proxy.crossOriginResourceHandler(c.Cors)
// call the handler and check the responses
context := newFakeGinContext("GET", "/oauth/test")
handler(context)
// step: check the headers
for k, v := range c.Headers {
value := context.Writer.Header().Get(k)
if value == "" {
t.Errorf("case %d, should have had the %s header set, headers: %v", i, k, context.Writer.Header())
continue
}
if value != v {
t.Errorf("case %d, expected: %s but got %s", i, k, value)
}
}
}
}
func TestAdmissionHandlerRoles(t *testing.T) { func TestAdmissionHandlerRoles(t *testing.T) {
proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ proxy := newFakeKeycloakProxyWithResources(t, []*Resource{
{ {
......
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
)
type fakeOAuthServer struct {
}
type fakeDiscoveryResponse struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
EndSessionEndpoint string `json:"end_session_endpoint"`
GrantTypesSupported []string `json:"grant_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
Issuer string `json:"issuer"`
JwksURI string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint"`
ResponseModesSupported []string `json:"response_modes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"`
TokenEndpoint string `json:"token_endpoint"`
TokenIntrospectionEndpoint string `json:"token_introspection_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
}
type fakeKeysResponse struct {
Keys []fakeKeyResponse `json:"keys"`
}
type fakeKeyResponse struct {
Alg string `json:"alg"`
E string `json:"e"`
Kid string `json:"kid"`
Kty string `json:"kty"`
N string `json:"n"`
Use string `json:"use"`
}
const (
fakePublicKey = "ibGNjo_opyEGbeDP3cctILhSW-sGKtG67hCZXxvHx-wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5-FMbHth-TKZiEhm-3EBadc1qgkfnpinfpxCVqHHaF8mFLC5-k3JsINIR0FAmPN9trxryI_npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh_rrLKAs0AdUYwXGAslnYDBACiR8GNrb7Q"
oauthPublicKey = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQAB"
oauthPrivateKey = "MIIEowIBAAKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABAoIBAGtfMlSmMbUKErpiIZX+uFkYgti8p92CGLOF7CN3RU3H+PgfF1m4xHGqt4xw+2JyhgQFgTY4IiIN1QuPFzI82+6jDvMqBwEi2e0TGj4RKiOX9D8b/qSL9eUSfqQKPqnPZfBymM3sqe5yddY7KVZiMXEEBu1efhhTADluIraKQjYJKgQd0P3CgfqhuUWgCqGjPwIg0BkzXofR0bjdrq8d0ul8JLnT+9ho/x8rahEN/LTHHLIwb6IYUj8X10tDZWPDk2NE5wRIy18peSXYNTeGhY1ThF75ZOAH5c1qgi0ObE+dUSqzwcDWqNDPxFvg2x67KbcMaTO6u87/mGJfuO2ekz0CgYEAwpR+tZdafTzR+MLGg55mxsfVjAWGNxp0AMwWZVTpPx1I+VgdLsMkUY8LpY2Zt8l2yInIGEzYRBNFYPrM73bW5v0bleGl60I6j3KA/Ic6RUaweycbQgMxob5PCWrMm94Jib1bGAxNU1m0Jp9rzxGUzWw3TpSw6LHNLqokwMCKG/cCgYEAtSg1oqeCvvCrIdA6AulzzWR6x2Re/Iv8MYJ5X0fNPRBHSVhwsdb2nLfjMPmLesBOPm55O/LZDFtL8unpOUc+qT8QWKAjvI0/HtYf2sec3sP/dxCYYK18grK1cvD/UAUfiljM0gAsxZRT77VbpOIMCOi9YjHoyeRgCQtxB9CuZjsCgYEAsLNfehLvpwmjeK+QzRf9J4l0AQtHPiU0sUClGfKJOrqieWUuYzftdG9d2UMFFGTNDQIqhv7J6tBBUfeQQep+8BdshKj9Hu7u9TO7tRgsr5qpS71QwJrb6JFFfzzQgL+bk800u1r4obe1pNljcxD5O6+JbkATg81rknQKmkx/XzMCgYARnyqwesjuF+0dqeqqs9jO5vJGiQ3wVRGgI0f5K7vcL8Qvb0nvErEEh6Ky9eNKeoBh9E8YtMPGPu9BXt2P8801m2vUoyc2xSqZrkyE9Jve04P7KgMYjGerMwURfD3po8XwqDisSNYSFh6gF60ledOf3jvl3GL/mJZ66sEA+JyuVwKBgGwef1FWkDTeft6VFo2obHCh8Fc8rsV2rQ0twgmA00nmuckKr5MQgyMiz2YYWanmOS18xLgl7FzvyX56clj1MvRl9xnwhSudtE4fxg6R4rzwf3jaWtAkXEHet+mqVRJgI9m5Bn8E7nVVmjgRlogZsgYq2pF3nL1sgl3ti7gOVVL6"
oauthCertificate = "MIICnzCCAYcCBgFUPZAJhjANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDDAhob2QtdGVzdDAeFw0xNjA0MjIxMDQwMzBaFw0yNjA0MjIxMDQyMTBaMBMxETAPBgNVBAMMCGhvZC10ZXN0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBFd9T/1s769tGhOMtUspP2tChKy5OWF50HkRVLny1nt12JeQUvuVSD3l7vN17hFRpMm1ktjVCTxBk5PRfPtpOcMCG2zgYbB73hIRYZaKG5X6/r2y3TllZ2UkZh0ndL+jrn1L4I2zxB5OAi3CDTxiFtjcEShAC9smjp04Omxwat53k8IxJLRgnpuC/TMbxUPHLNjuOHLLFeSN7095SuD+qzx0H7fT4sqW3+mAr7Q/kl2yq4vMXfLHt5KkOm7O5px5mRoGS4Asbkw5MQMgP618uQ9k7EQZx37jF2ol4Z7uLQWscePdWA66ajbxAtybCesNPa4uUrb1YVdx6MikWyZ0i7"
)
func newFakeOAuthServer(t *testing.T) {
s := new(fakeOAuthServer)
r := gin.New()
r.GET("/auth/realms/hod-test/.well-known/openid-configuration", s.discoveryHandler)
r.GET("/auth/realms/hod-test/protocol/openid-connect/certs", s.keysHandler)
r.POST("/auth/realms/hod-test/protocol/openid-connect/token", s.tokenHandler)
r.POST("/auth/realms/hod-test/protocol/openid-connect/auth", s.authHandler)
if err := r.Run("127.0.0.1:8080"); err != nil {
t.Fatalf("failed to start the fake oauth service, error: %s", err)
}
}
func (r fakeOAuthServer) discoveryHandler(cx *gin.Context) {
cx.JSON(http.StatusOK, fakeDiscoveryResponse{
IDTokenSigningAlgValuesSupported: []string{"RS256"},
Issuer: "http://127.0.0.1:8080/auth/realms/hod-test",
AuthorizationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/auth",
TokenEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token",
RegistrationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/clients-registrations/openid-connect",
TokenIntrospectionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token/introspect",
UserinfoEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/userinfo",
EndSessionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/logout",
JwksURI: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/certs",
GrantTypesSupported: []string{"authorization_code", "implicit", "refresh_token", "password", "client_credentials"},
ResponseModesSupported: []string{"query", "fragment", "form_post"},
ResponseTypesSupported: []string{"code", "none", "id_token", "token", "id_token token", "code id_token", "code token", "code id_token token"},
SubjectTypesSupported: []string{"public"},
})
}
func (r fakeOAuthServer) keysHandler(cx *gin.Context) {
cx.JSON(http.StatusOK, fakeKeysResponse{
Keys: []fakeKeyResponse{
{
Kid: "ing3Hnuj0ciqrHCOxt__-B53jzXcdD1n1iKbX3GsD9s",
Kty: "RSA",
Alg: "RS256",
Use: "sig",
N: fakePublicKey,
E: "AQAB",
},
},
})
}
func (r fakeOAuthServer) authHandler(cx *gin.Context) {
}
func (r fakeOAuthServer) tokenHandler(cx *gin.Context) {
}
...@@ -50,7 +50,7 @@ type oauthProxy struct { ...@@ -50,7 +50,7 @@ type oauthProxy struct {
// the upstream endpoint url // the upstream endpoint url
endpoint *url.URL endpoint *url.URL
// the store interface // the store interface
store Store store storage
} }
type reverseProxy interface { type reverseProxy interface {
...@@ -75,7 +75,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) { ...@@ -75,7 +75,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} }
log.Infof("starting %s, version: %s, author: %s", prog, version, author) log.Infof("starting %s, author: %s, version: %s, ", prog, author, version)
service := &oauthProxy{config: cfg} service := &oauthProxy{config: cfg}
...@@ -87,7 +87,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) { ...@@ -87,7 +87,7 @@ func newProxy(cfg *Config) (*oauthProxy, error) {
// step: initialize the store if any // step: initialize the store if any
if cfg.StoreURL != "" { if cfg.StoreURL != "" {
if service.store, err = newStore(cfg.StoreURL); err != nil { if service.store, err = newStorage(cfg.StoreURL); err != nil {
return nil, err return nil, err
} }
} }
...@@ -223,17 +223,33 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { ...@@ -223,17 +223,33 @@ func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) {
// setupReverseProxy create a reverse http proxy from the upstream // setupReverseProxy create a reverse http proxy from the upstream
// //
func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) { func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) {
proxy := httputil.NewSingleHostReverseProxy(upstream) // step: create the default dialer
proxy.Transport = &http.Transport{ dialer := (&net.Dialer{
Dial: (&net.Dialer{
KeepAlive: 10 * time.Second, KeepAlive: 10 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}).Dial, }).Dial
// step: are we using a unix socket?
if upstream.Scheme == "unix" {
log.Infof("using the unix domain socket: %s for upstream", upstream.Host)
socketPath := upstream.Host
dialer = func(network, address string) (net.Conn, error) {
return net.Dial("unix", socketPath)
}
upstream.Path = ""
upstream.Host = "domain-sock"
upstream.Scheme = "http"
}
// step: create the reverse proxy
proxy := httputil.NewSingleHostReverseProxy(upstream)
// step: customize the http transport
proxy.Transport = &http.Transport{
Dial: dialer,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: r.config.SkipUpstreamTLSVerify, InsecureSkipVerify: r.config.SkipUpstreamTLSVerify,
}, },
DisableKeepAlives: !r.config.Keepalives, DisableKeepAlives: !r.config.Keepalives,
TLSHandshakeTimeout: 10 * time.Second,
} }
return proxy, nil return proxy, nil
...@@ -253,7 +269,7 @@ func (r oauthProxy) setupRouter() error { ...@@ -253,7 +269,7 @@ func (r oauthProxy) setupRouter() error {
r.router.Use(r.securityHandler()) r.router.Use(r.securityHandler())
} }
// step: add the routing // step: add the routing
oauth := r.router.Group(oauthURL).Use(r.crossSiteHandler()) oauth := r.router.Group(oauthURL).Use(r.crossOriginResourceHandler(r.config.CrossOrigin))
{ {
oauth.GET(authorizationURL, r.oauthAuthorizationHandler) oauth.GET(authorizationURL, r.oauthAuthorizationHandler)
oauth.GET(callbackURL, r.oauthCallbackHandler) oauth.GET(callbackURL, r.oauthCallbackHandler)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment