From a97951a95a8aa13abc82b262808c60b34ace8bfb Mon Sep 17 00:00:00 2001 From: Rohith <gambol99@gmail.com> Date: Tue, 10 May 2016 12:16:04 +0100 Subject: [PATCH] - adding a gitsha to the version - removing the needs to client secret as a configuration optoin - cleaning up the logging messages - adding the X-Forwarded-Host header to the upstream - cleaning up some of the method names - fixed the path bug in the unix domain upstream socket - added the oxy package to the godeps - updating the godeps for go-oidc package - removing the default settings for redirection url and upstream url - adding the proxy protocol support - updating the git+sha to include the dirty flag - updated the godeps for cli package - using the cli exiterror for errors - added the options to set the upstream timeout and keepalive options - cleaning up methods names - adding the proxy protocol to godeps - adding the goproxy package to godeps - switching to using the goproxy package for reverse proxying - moved the proxy header into a seperate middleware handler - adding the mocking of a oauth service - adding the forwarding proxy - changing the state parameter to base64 encoding to permit extended urls - added the forwarding agent support - updating the readme --- .travis.yml | 1 - CHANGELOG.md | 26 +- Godeps/Godeps.json | 25 +- .../github.com/armon/go-proxyproto/.gitignore | 2 + .../github.com/armon/go-proxyproto/LICENSE | 21 + .../github.com/armon/go-proxyproto/README.md | 36 ++ .../armon/go-proxyproto/protocol.go | 194 +++++++++ .../src/github.com/codegangsta/cli/.gitignore | 1 + .../github.com/codegangsta/cli/.travis.yml | 18 +- .../github.com/codegangsta/cli/CHANGELOG.md | 315 +++++++++++++++ .../src/github.com/codegangsta/cli/README.md | 303 +++++++++++--- .../src/github.com/codegangsta/cli/app.go | 207 ++++++++-- .../github.com/codegangsta/cli/category.go | 14 + .../src/github.com/codegangsta/cli/cli.go | 23 +- .../src/github.com/codegangsta/cli/command.go | 61 +-- .../src/github.com/codegangsta/cli/context.go | 113 ++++-- .../src/github.com/codegangsta/cli/errors.go | 92 +++++ .../src/github.com/codegangsta/cli/flag.go | 199 ++++++++-- .../src/github.com/codegangsta/cli/funcs.go | 28 ++ .../src/github.com/codegangsta/cli/help.go | 107 ++--- .../src/github.com/codegangsta/cli/runtests | 95 +++++ .../coreos/go-oidc/oidc/provider.go | 2 +- .../coreos/go-oidc/oidc/transport.go | 9 + .../src/github.com/elazarl/goproxy/.gitignore | 2 + .../src/github.com/elazarl/goproxy/LICENSE | 27 ++ .../src/github.com/elazarl/goproxy/README.md | 118 ++++++ .../src/github.com/elazarl/goproxy/actions.go | 57 +++ .../src/github.com/elazarl/goproxy/all.bash | 15 + .../src/github.com/elazarl/goproxy/ca.pem | 15 + .../src/github.com/elazarl/goproxy/certs.go | 56 +++ .../src/github.com/elazarl/goproxy/chunked.go | 59 +++ .../elazarl/goproxy/counterecryptor.go | 68 ++++ .../src/github.com/elazarl/goproxy/ctx.go | 87 ++++ .../github.com/elazarl/goproxy/dispatcher.go | 325 +++++++++++++++ .../src/github.com/elazarl/goproxy/doc.go | 100 +++++ .../src/github.com/elazarl/goproxy/https.go | 370 ++++++++++++++++++ .../src/github.com/elazarl/goproxy/key.pem | 15 + .../src/github.com/elazarl/goproxy/proxy.go | 162 ++++++++ .../github.com/elazarl/goproxy/responses.go | 38 ++ .../src/github.com/elazarl/goproxy/signer.go | 87 ++++ .../src/github.com/vulcand/oxy/LICENSE | 202 ++++++++++ Makefile | 12 +- README.md | 190 ++++++--- config.go | 196 +++++++--- config_sample.yml | 3 +- cookies_test.go | 2 +- doc.go | 107 +++-- forwarding.go | 233 +++++++++++ handlers.go | 123 ++---- handlers_test.go | 108 ++++- main.go | 20 +- middleware.go | 80 ++-- middleware_test.go | 101 ++++- oauth.go | 4 +- oauth_test.go | 219 ++++++++--- server.go | 323 +++++++++------ server_test.go | 44 ++- session_test.go | 2 +- stores.go | 4 +- tests/gen_token.go | 2 +- util_test.go | 36 +- utils.go | 125 ++++-- 62 files changed, 4810 insertions(+), 819 deletions(-) create mode 100644 Godeps/_workspace/src/github.com/armon/go-proxyproto/.gitignore create mode 100644 Godeps/_workspace/src/github.com/armon/go-proxyproto/LICENSE create mode 100644 Godeps/_workspace/src/github.com/armon/go-proxyproto/README.md create mode 100644 Godeps/_workspace/src/github.com/armon/go-proxyproto/protocol.go create mode 100644 Godeps/_workspace/src/github.com/codegangsta/cli/.gitignore create mode 100644 Godeps/_workspace/src/github.com/codegangsta/cli/CHANGELOG.md create mode 100644 Godeps/_workspace/src/github.com/codegangsta/cli/errors.go create mode 100644 Godeps/_workspace/src/github.com/codegangsta/cli/funcs.go create mode 100644 Godeps/_workspace/src/github.com/codegangsta/cli/runtests create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/.gitignore create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/LICENSE create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/README.md create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/actions.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/all.bash create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/ca.pem create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/certs.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/chunked.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/counterecryptor.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/ctx.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/dispatcher.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/doc.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/https.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/key.pem create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/proxy.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/responses.go create mode 100644 Godeps/_workspace/src/github.com/elazarl/goproxy/signer.go create mode 100644 Godeps/_workspace/src/github.com/vulcand/oxy/LICENSE create mode 100644 forwarding.go diff --git a/.travis.yml b/.travis.yml index 12f8ece..84c8e12 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,5 +21,4 @@ install: script: - make test - docker login -u ${REGISTRY_USERNAME} -p ${REGISTRY_TOKEN} -e ${AUTHOR_EMAIL} ${REGISTRY} - - if [ "$TRAVIS_BRANCH" == "master" ]; then VERSION=latest make docker-release; fi - if [ -n "$TRAVIS_TAG" ]; then VERSION=$TRAVIS_TAG make docker-release; fi diff --git a/CHANGELOG.md b/CHANGELOG.md index af5aa4c..2cc24d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,35 @@ +#### **1.1.0 (unreleased)** + +FIXES: + * Added a auto build to quay.io on the travis build for master and tags + * Fixed the host header to proxy to upstreams outside of the proxy domain (https://github.com/golang/go/issues/7618) + * Adding a git+sha to the usage + * Defaulting to gin mode release unless verbose is true + * Removed the gin debug logging for tests and builds + * Removed the default upstream, as it caught people by surprise and some accidentally forwarded to themselves + * Changed the state parameter (which is used as a redirect) to base64 the value allowing you to use complex urls + + +FEATURES: + * Adding environment variables to some of the command line options + * Adding the option of a forwarding agent, i.e. you can seat the proxy front of your application, + login to keycloak and use the proxy as forwarding agent to sign outbound requests. + * Adding the version information into a header on /oauth/health endpoint + * Removed the need to specify a client-secret, which means to cope with authz only or public endpoints + * Added role url tokenizer, /auth/%role%/ will extract the role element and check the token as it + * Added proxy protocol support for the listening socket (--enable-proxy-protocol=true) + +BREAKING CHANGES: + * Changed the X-Auth-Subject, it not is the actual subject from the token (makes more sense). + X-Auth-UserID will either be the subject id or the preferred username + #### **1.0.6 (May 6th, 2016)** FIXES: * Fixed the logout endpoint, ensuring users sessions are revoked. Note: i've not really tested this against Keycloak and Google. Revocation or logouts seems to have somewhat scattered implementation across providers. - #### **1.0.5 (May 3th, 2016)** FEATURES: diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 4279b32..df7fc93 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -1,13 +1,17 @@ { "ImportPath": "github.com/gambol99/keycloak-proxy", "GoVersion": "go1.5", - "GodepVersion": "v62", + "GodepVersion": "v70", "Deps": [ { "ImportPath": "github.com/Sirupsen/logrus", "Comment": "v0.10.0-14-g081307d", "Rev": "081307d9bc1364753142d5962fc1d795c742baaf" }, + { + "ImportPath": "github.com/armon/go-proxyproto", + "Rev": "609d6338d3a76ec26ac3fe7045a164d9a58436e7" + }, { "ImportPath": "github.com/boltdb/bolt", "Comment": "v1.2.0-13-g144418e", @@ -15,28 +19,28 @@ }, { "ImportPath": "github.com/codegangsta/cli", - "Comment": "1.2.0-237-g71f57d3", - "Rev": "71f57d300dd6a780ac1856c005c4b518cfd498ec" + "Comment": "v1.17.0-21-g0eb4e0b", + "Rev": "0eb4e0be6c214f8904ef6989b11072c7b897c657" }, { "ImportPath": "github.com/coreos/go-oidc/http", - "Rev": "a443fa229e11fd2d5da13c76763b29c447c451b0" + "Rev": "e6174c764e906bd60c76fdfc33faf5e0bdc875d6" }, { "ImportPath": "github.com/coreos/go-oidc/jose", - "Rev": "a443fa229e11fd2d5da13c76763b29c447c451b0" + "Rev": "e6174c764e906bd60c76fdfc33faf5e0bdc875d6" }, { "ImportPath": "github.com/coreos/go-oidc/key", - "Rev": "a443fa229e11fd2d5da13c76763b29c447c451b0" + "Rev": "e6174c764e906bd60c76fdfc33faf5e0bdc875d6" }, { "ImportPath": "github.com/coreos/go-oidc/oauth2", - "Rev": "a443fa229e11fd2d5da13c76763b29c447c451b0" + "Rev": "e6174c764e906bd60c76fdfc33faf5e0bdc875d6" }, { "ImportPath": "github.com/coreos/go-oidc/oidc", - "Rev": "a443fa229e11fd2d5da13c76763b29c447c451b0" + "Rev": "e6174c764e906bd60c76fdfc33faf5e0bdc875d6" }, { "ImportPath": "github.com/coreos/go-systemd/journal", @@ -63,6 +67,11 @@ "ImportPath": "github.com/davecgh/go-spew/spew", "Rev": "5215b55f46b2b919f50a1df0eaa5886afe4e3b3d" }, + { + "ImportPath": "github.com/elazarl/goproxy", + "Comment": "v1.0-80-g970f4ed", + "Rev": "970f4ed8995ab98f808e4abf06f52660aeaec7a9" + }, { "ImportPath": "github.com/gin-gonic/gin", "Comment": "v1.0rc1-148-g52fcc5d", diff --git a/Godeps/_workspace/src/github.com/armon/go-proxyproto/.gitignore b/Godeps/_workspace/src/github.com/armon/go-proxyproto/.gitignore new file mode 100644 index 0000000..dd2440d --- /dev/null +++ b/Godeps/_workspace/src/github.com/armon/go-proxyproto/.gitignore @@ -0,0 +1,2 @@ +*.test +*~ diff --git a/Godeps/_workspace/src/github.com/armon/go-proxyproto/LICENSE b/Godeps/_workspace/src/github.com/armon/go-proxyproto/LICENSE new file mode 100644 index 0000000..3ed5f43 --- /dev/null +++ b/Godeps/_workspace/src/github.com/armon/go-proxyproto/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Armon Dadgar + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/armon/go-proxyproto/README.md b/Godeps/_workspace/src/github.com/armon/go-proxyproto/README.md new file mode 100644 index 0000000..25a779c --- /dev/null +++ b/Godeps/_workspace/src/github.com/armon/go-proxyproto/README.md @@ -0,0 +1,36 @@ +# proxyproto + +This library provides the `proxyproto` package which can be used for servers +listening behind HAProxy of Amazon ELB load balancers. Those load balancers +support the use of a proxy protocol (http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt), +which provides a simple mechansim for the server to get the address of the client +instead of the load balancer. + +This library provides both a net.Listener and net.Conn implementation that +can be used to handle situation in which you may be using the proxy protocol. +Only proxy protocol version 1, the human-readable form, is understood. + +The only caveat is that we check for the "PROXY " prefix to determine if the protocol +is being used. If that string may occur as part of your input, then it is ambiguous +if the protocol is being used and you may have problems. + +# Documentation + +Full documentation can be found [here](http://godoc.org/github.com/armon/go-proxyproto). + +# Examples + +Using the library is very simple: + +``` + +// Create a listener +list, err := net.Listen("tcp", "...") + +// Wrap listener in a proxyproto listener +proxyList := &proxyproto.Listener{list} +conn, err :=proxyList.Accept() + +... +``` + diff --git a/Godeps/_workspace/src/github.com/armon/go-proxyproto/protocol.go b/Godeps/_workspace/src/github.com/armon/go-proxyproto/protocol.go new file mode 100644 index 0000000..2fc1dfc --- /dev/null +++ b/Godeps/_workspace/src/github.com/armon/go-proxyproto/protocol.go @@ -0,0 +1,194 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "fmt" + "io" + "log" + "net" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // prefix is the string we look for at the start of a connection + // to check if this connection is using the proxy protocol + prefix = []byte("PROXY ") + prefixLen = len(prefix) +) + +// Listener is used to wrap an underlying listener, +// whose connections may be using the HAProxy Proxy Protocol (version 1). +// If the connection is using the protocol, the RemoteAddr() will return +// the correct client address. +type Listener struct { + Listener net.Listener +} + +// Conn is used to wrap and underlying connection which +// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will +// return the address of the client instead of the proxy address. +type Conn struct { + bufReader *bufio.Reader + conn net.Conn + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr + once sync.Once +} + +// Accept waits for and returns the next connection to the listener. +func (p *Listener) Accept() (net.Conn, error) { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + return NewConn(conn), nil +} + +// Close closes the underlying listener. +func (p *Listener) Close() error { + return p.Listener.Close() +} + +// Addr returns the underlying listener's network address. +func (p *Listener) Addr() net.Addr { + return p.Listener.Addr() +} + +// NewConn is used to wrap a net.Conn that may be speaking +// the proxy protocol into a proxyproto.Conn +func NewConn(conn net.Conn) *Conn { + pConn := &Conn{ + bufReader: bufio.NewReader(conn), + conn: conn, + } + return pConn +} + +// Read is check for the proxy protocol header when doing +// the initial scan. If there is an error parsing the header, +// it is returned and the socket is closed. +func (p *Conn) Read(b []byte) (int, error) { + var err error + p.once.Do(func() { err = p.checkPrefix() }) + if err != nil { + return 0, err + } + return p.bufReader.Read(b) +} + +func (p *Conn) Write(b []byte) (int, error) { + return p.conn.Write(b) +} + +func (p *Conn) Close() error { + return p.conn.Close() +} + +func (p *Conn) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// RemoteAddr returns the address of the client if the proxy +// protocol is being used, otherwise just returns the address of +// the socket peer. If there is an error parsing the header, the +// address of the client is not returned, and the socket is closed. +// Once implication of this is that the call could block if the +// client is slow. Using a Deadline is recommended if this is called +// before Read() +func (p *Conn) RemoteAddr() net.Addr { + p.once.Do(func() { + if err := p.checkPrefix(); err != nil && err != io.EOF { + log.Printf("[ERR] Failed to read proxy prefix: %v", err) + } + }) + if p.srcAddr != nil { + return p.srcAddr + } + return p.conn.RemoteAddr() +} + +func (p *Conn) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +func (p *Conn) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +func (p *Conn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} + +func (p *Conn) checkPrefix() error { + // Incrementally check each byte of the prefix + for i := 1; i <= prefixLen; i++ { + inp, err := p.bufReader.Peek(i) + if err != nil { + return err + } + + // Check for a prefix mis-match, quit early + if !bytes.Equal(inp, prefix[:i]) { + return nil + } + } + + // Read the header line + header, err := p.bufReader.ReadString('\n') + if err != nil { + p.conn.Close() + return err + } + + // Strip the carriage return and new line + header = header[:len(header)-2] + + // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>) + parts := strings.Split(header, " ") + if len(parts) != 6 { + p.conn.Close() + return fmt.Errorf("Invalid header line: %s", header) + } + + // Verify the type is known + switch parts[1] { + case "TCP4": + case "TCP6": + default: + p.conn.Close() + return fmt.Errorf("Unhandled address type: %s", parts[1]) + } + + // Parse out the source address + ip := net.ParseIP(parts[2]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid source ip: %s", parts[2]) + } + port, err := strconv.Atoi(parts[4]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid source port: %s", parts[4]) + } + p.srcAddr = &net.TCPAddr{IP: ip, Port: port} + + // Parse out the destination address + ip = net.ParseIP(parts[3]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid destination ip: %s", parts[3]) + } + port, err = strconv.Atoi(parts[5]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid destination port: %s", parts[5]) + } + p.dstAddr = &net.TCPAddr{IP: ip, Port: port} + + return nil +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/.gitignore b/Godeps/_workspace/src/github.com/codegangsta/cli/.gitignore new file mode 100644 index 0000000..7823778 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/.gitignore @@ -0,0 +1 @@ +*.coverprofile diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml index c2b5c8d..b117165 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml @@ -1,18 +1,24 @@ language: go + sudo: false go: - 1.1.2 - 1.2.2 - 1.3.3 -- 1.4.2 -- 1.5.1 -- tip +- 1.4 +- 1.5.4 +- 1.6.2 +- master matrix: allow_failures: - - go: tip + - go: master + +before_script: +- go get github.com/meatballhat/gfmxr/... script: -- go vet ./... -- go test -v ./... +- ./runtests vet +- ./runtests test +- ./runtests gfmxr diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/CHANGELOG.md b/Godeps/_workspace/src/github.com/codegangsta/cli/CHANGELOG.md new file mode 100644 index 0000000..87a3ed2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/CHANGELOG.md @@ -0,0 +1,315 @@ +# Change Log + +**ATTN**: This project uses [semantic versioning](http://semver.org/). + +## [Unreleased] +### Added +- `./runtests` test runner with coverage tracking by default + +### Fixed +- Printing of command aliases in help text + +## [1.17.0] - 2016-05-09 +### Added +- Pluggable flag-level help text rendering via `cli.DefaultFlagStringFunc` +- `context.GlobalBoolT` was added as an analogue to `context.GlobalBool` +- Support for hiding commands by setting `Hidden: true` -- this will hide the + commands in help output + +### Changed +- `Float64Flag`, `IntFlag`, and `DurationFlag` default values are no longer + quoted in help text output. +- All flag types now include `(default: {value})` strings following usage when a + default value can be (reasonably) detected. +- `IntSliceFlag` and `StringSliceFlag` usage strings are now more consistent + with non-slice flag types +- Apps now exit with a code of 3 if an unknown subcommand is specified + (previously they printed "No help topic for...", but still exited 0. This + makes it easier to script around apps built using `cli` since they can trust + that a 0 exit code indicated a successful execution. +- cleanups based on [Go Report Card + feedback](https://goreportcard.com/report/github.com/codegangsta/cli) + +## [1.16.0] - 2016-05-02 +### Added +- `Hidden` field on all flag struct types to omit from generated help text + +### Changed +- `BashCompletionFlag` (`--enable-bash-completion`) is now omitted from +generated help text via the `Hidden` field + +### Fixed +- handling of error values in `HandleAction` and `HandleExitCoder` + +## [1.15.0] - 2016-04-30 +### Added +- This file! +- Support for placeholders in flag usage strings +- `App.Metadata` map for arbitrary data/state management +- `Set` and `GlobalSet` methods on `*cli.Context` for altering values after +parsing. +- Support for nested lookup of dot-delimited keys in structures loaded from +YAML. + +### Changed +- The `App.Action` and `Command.Action` now prefer a return signature of +`func(*cli.Context) error`, as defined by `cli.ActionFunc`. If a non-nil +`error` is returned, there may be two outcomes: + - If the error fulfills `cli.ExitCoder`, then `os.Exit` will be called + automatically + - Else the error is bubbled up and returned from `App.Run` +- Specifying an `Action` with the legacy return signature of +`func(*cli.Context)` will produce a deprecation message to stderr +- Specifying an `Action` that is not a `func` type will produce a non-zero exit +from `App.Run` +- Specifying an `Action` func that has an invalid (input) signature will +produce a non-zero exit from `App.Run` + +### Deprecated +- <a name="deprecated-cli-app-runandexitonerror"></a> +`cli.App.RunAndExitOnError`, which should now be done by returning an error +that fulfills `cli.ExitCoder` to `cli.App.Run`. +- <a name="deprecated-cli-app-action-signature"></a> the legacy signature for +`cli.App.Action` of `func(*cli.Context)`, which should now have a return +signature of `func(*cli.Context) error`, as defined by `cli.ActionFunc`. + +### Fixed +- Added missing `*cli.Context.GlobalFloat64` method + +## [1.14.0] - 2016-04-03 (backfilled 2016-04-25) +### Added +- Codebeat badge +- Support for categorization via `CategorizedHelp` and `Categories` on app. + +### Changed +- Use `filepath.Base` instead of `path.Base` in `Name` and `HelpName`. + +### Fixed +- Ensure version is not shown in help text when `HideVersion` set. + +## [1.13.0] - 2016-03-06 (backfilled 2016-04-25) +### Added +- YAML file input support. +- `NArg` method on context. + +## [1.12.0] - 2016-02-17 (backfilled 2016-04-25) +### Added +- Custom usage error handling. +- Custom text support in `USAGE` section of help output. +- Improved help messages for empty strings. +- AppVeyor CI configuration. + +### Changed +- Removed `panic` from default help printer func. +- De-duping and optimizations. + +### Fixed +- Correctly handle `Before`/`After` at command level when no subcommands. +- Case of literal `-` argument causing flag reordering. +- Environment variable hints on Windows. +- Docs updates. + +## [1.11.1] - 2015-12-21 (backfilled 2016-04-25) +### Changed +- Use `path.Base` in `Name` and `HelpName` +- Export `GetName` on flag types. + +### Fixed +- Flag parsing when skipping is enabled. +- Test output cleanup. +- Move completion check to account for empty input case. + +## [1.11.0] - 2015-11-15 (backfilled 2016-04-25) +### Added +- Destination scan support for flags. +- Testing against `tip` in Travis CI config. + +### Changed +- Go version in Travis CI config. + +### Fixed +- Removed redundant tests. +- Use correct example naming in tests. + +## [1.10.2] - 2015-10-29 (backfilled 2016-04-25) +### Fixed +- Remove unused var in bash completion. + +## [1.10.1] - 2015-10-21 (backfilled 2016-04-25) +### Added +- Coverage and reference logos in README. + +### Fixed +- Use specified values in help and version parsing. +- Only display app version and help message once. + +## [1.10.0] - 2015-10-06 (backfilled 2016-04-25) +### Added +- More tests for existing functionality. +- `ArgsUsage` at app and command level for help text flexibility. + +### Fixed +- Honor `HideHelp` and `HideVersion` in `App.Run`. +- Remove juvenile word from README. + +## [1.9.0] - 2015-09-08 (backfilled 2016-04-25) +### Added +- `FullName` on command with accompanying help output update. +- Set default `$PROG` in bash completion. + +### Changed +- Docs formatting. + +### Fixed +- Removed self-referential imports in tests. + +## [1.8.0] - 2015-06-30 (backfilled 2016-04-25) +### Added +- Support for `Copyright` at app level. +- `Parent` func at context level to walk up context lineage. + +### Fixed +- Global flag processing at top level. + +## [1.7.1] - 2015-06-11 (backfilled 2016-04-25) +### Added +- Aggregate errors from `Before`/`After` funcs. +- Doc comments on flag structs. +- Include non-global flags when checking version and help. +- Travis CI config updates. + +### Fixed +- Ensure slice type flags have non-nil values. +- Collect global flags from the full command hierarchy. +- Docs prose. + +## [1.7.0] - 2015-05-03 (backfilled 2016-04-25) +### Changed +- `HelpPrinter` signature includes output writer. + +### Fixed +- Specify go 1.1+ in docs. +- Set `Writer` when running command as app. + +## [1.6.0] - 2015-03-23 (backfilled 2016-04-25) +### Added +- Multiple author support. +- `NumFlags` at context level. +- `Aliases` at command level. + +### Deprecated +- `ShortName` at command level. + +### Fixed +- Subcommand help output. +- Backward compatible support for deprecated `Author` and `Email` fields. +- Docs regarding `Names`/`Aliases`. + +## [1.5.0] - 2015-02-20 (backfilled 2016-04-25) +### Added +- `After` hook func support at app and command level. + +### Fixed +- Use parsed context when running command as subcommand. +- Docs prose. + +## [1.4.1] - 2015-01-09 (backfilled 2016-04-25) +### Added +- Support for hiding `-h / --help` flags, but not `help` subcommand. +- Stop flag parsing after `--`. + +### Fixed +- Help text for generic flags to specify single value. +- Use double quotes in output for defaults. +- Use `ParseInt` instead of `ParseUint` for int environment var values. +- Use `0` as base when parsing int environment var values. + +## [1.4.0] - 2014-12-12 (backfilled 2016-04-25) +### Added +- Support for environment variable lookup "cascade". +- Support for `Stdout` on app for output redirection. + +### Fixed +- Print command help instead of app help in `ShowCommandHelp`. + +## [1.3.1] - 2014-11-13 (backfilled 2016-04-25) +### Added +- Docs and example code updates. + +### Changed +- Default `-v / --version` flag made optional. + +## [1.3.0] - 2014-08-10 (backfilled 2016-04-25) +### Added +- `FlagNames` at context level. +- Exposed `VersionPrinter` var for more control over version output. +- Zsh completion hook. +- `AUTHOR` section in default app help template. +- Contribution guidelines. +- `DurationFlag` type. + +## [1.2.0] - 2014-08-02 +### Added +- Support for environment variable defaults on flags plus tests. + +## [1.1.0] - 2014-07-15 +### Added +- Bash completion. +- Optional hiding of built-in help command. +- Optional skipping of flag parsing at command level. +- `Author`, `Email`, and `Compiled` metadata on app. +- `Before` hook func support at app and command level. +- `CommandNotFound` func support at app level. +- Command reference available on context. +- `GenericFlag` type. +- `Float64Flag` type. +- `BoolTFlag` type. +- `IsSet` flag helper on context. +- More flag lookup funcs at context level. +- More tests & docs. + +### Changed +- Help template updates to account for presence/absence of flags. +- Separated subcommand help template. +- Exposed `HelpPrinter` var for more control over help output. + +## [1.0.0] - 2013-11-01 +### Added +- `help` flag in default app flag set and each command flag set. +- Custom handling of argument parsing errors. +- Command lookup by name at app level. +- `StringSliceFlag` type and supporting `StringSlice` type. +- `IntSliceFlag` type and supporting `IntSlice` type. +- Slice type flag lookups by name at context level. +- Export of app and command help functions. +- More tests & docs. + +## 0.1.0 - 2013-07-22 +### Added +- Initial implementation. + +[Unreleased]: https://github.com/codegangsta/cli/compare/v1.17.0...HEAD +[1.17.0]: https://github.com/codegangsta/cli/compare/v1.16.0...v1.17.0 +[1.16.0]: https://github.com/codegangsta/cli/compare/v1.15.0...v1.16.0 +[1.15.0]: https://github.com/codegangsta/cli/compare/v1.14.0...v1.15.0 +[1.14.0]: https://github.com/codegangsta/cli/compare/v1.13.0...v1.14.0 +[1.13.0]: https://github.com/codegangsta/cli/compare/v1.12.0...v1.13.0 +[1.12.0]: https://github.com/codegangsta/cli/compare/v1.11.1...v1.12.0 +[1.11.1]: https://github.com/codegangsta/cli/compare/v1.11.0...v1.11.1 +[1.11.0]: https://github.com/codegangsta/cli/compare/v1.10.2...v1.11.0 +[1.10.2]: https://github.com/codegangsta/cli/compare/v1.10.1...v1.10.2 +[1.10.1]: https://github.com/codegangsta/cli/compare/v1.10.0...v1.10.1 +[1.10.0]: https://github.com/codegangsta/cli/compare/v1.9.0...v1.10.0 +[1.9.0]: https://github.com/codegangsta/cli/compare/v1.8.0...v1.9.0 +[1.8.0]: https://github.com/codegangsta/cli/compare/v1.7.1...v1.8.0 +[1.7.1]: https://github.com/codegangsta/cli/compare/v1.7.0...v1.7.1 +[1.7.0]: https://github.com/codegangsta/cli/compare/v1.6.0...v1.7.0 +[1.6.0]: https://github.com/codegangsta/cli/compare/v1.5.0...v1.6.0 +[1.5.0]: https://github.com/codegangsta/cli/compare/v1.4.1...v1.5.0 +[1.4.1]: https://github.com/codegangsta/cli/compare/v1.4.0...v1.4.1 +[1.4.0]: https://github.com/codegangsta/cli/compare/v1.3.1...v1.4.0 +[1.3.1]: https://github.com/codegangsta/cli/compare/v1.3.0...v1.3.1 +[1.3.0]: https://github.com/codegangsta/cli/compare/v1.2.0...v1.3.0 +[1.2.0]: https://github.com/codegangsta/cli/compare/v1.1.0...v1.2.0 +[1.1.0]: https://github.com/codegangsta/cli/compare/v1.0.0...v1.1.0 +[1.0.0]: https://github.com/codegangsta/cli/compare/v0.1.0...v1.0.0 diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/README.md b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md index d9371cf..a5bab32 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/README.md +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md @@ -1,23 +1,26 @@ -[](http://gocover.io/github.com/codegangsta/cli) [](https://travis-ci.org/codegangsta/cli) [](https://godoc.org/github.com/codegangsta/cli) [](https://codebeat.co/projects/github-com-codegangsta-cli) +[](https://goreportcard.com/report/codegangsta/cli) +[](http://gocover.io/github.com/codegangsta/cli) / +[](http://gocover.io/github.com/codegangsta/cli/altsrc) -# cli.go -`cli.go` is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way. +# cli + +cli is a simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way. ## Overview Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app. -**This is where `cli.go` comes into play.** `cli.go` makes command line programming fun, organized, and expressive! +**This is where cli comes into play.** cli makes command line programming fun, organized, and expressive! ## Installation Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html). -To install `cli.go`, simply run: +To install cli, simply run: ``` $ go get github.com/codegangsta/cli ``` @@ -27,9 +30,49 @@ Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands c export PATH=$PATH:$GOPATH/bin ``` +### Using the `v2` branch + +There is currently a long-lived branch named `v2` that is intended to land as +the new `master` branch once development there has settled down. The current +`master` branch (mirrored as `v1`) is being manually merged into `v2` on +an irregular human-based schedule, but generally if one wants to "upgrade" to +`v2` *now* and accept the volatility (read: "awesomeness") that comes along with +that, please use whatever version pinning of your preference, such as via +`gopkg.in`: + +``` +$ go get gopkg.in/codegangsta/cli.v2 +``` + +``` go +... +import ( + "gopkg.in/codegangsta/cli.v2" // imports as package "cli" +) +... +``` + +### Pinning to the `v1` branch + +Similarly to the section above describing use of the `v2` branch, if one wants +to avoid any unexpected compatibility pains once `v2` becomes `master`, then +pinning to the `v1` branch is an acceptable option, e.g.: + +``` +$ go get gopkg.in/codegangsta/cli.v1 +``` + +``` go +... +import ( + "gopkg.in/codegangsta/cli.v1" // imports as package "cli" +) +... +``` + ## Getting Started -One of the philosophies behind `cli.go` is that an API should be playful and full of discovery. So a `cli.go` app can be as little as one line of code in `main()`. +One of the philosophies behind cli is that an API should be playful and full of discovery. So a cli app can be as little as one line of code in `main()`. ``` go package main @@ -46,11 +89,16 @@ func main() { This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation: +<!-- { + "output": "boom! I say!" +} --> ``` go package main import ( + "fmt" "os" + "github.com/codegangsta/cli" ) @@ -58,8 +106,9 @@ func main() { app := cli.NewApp() app.Name = "boom" app.Usage = "make an explosive entrance" - app.Action = func(c *cli.Context) { - println("boom! I say!") + app.Action = func(c *cli.Context) error { + fmt.Println("boom! I say!") + return nil } app.Run(os.Args) @@ -74,11 +123,16 @@ Being a programmer can be a lonely job. Thankfully by the power of automation th Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it: +<!-- { + "output": "Hello friend!" +} --> ``` go package main import ( + "fmt" "os" + "github.com/codegangsta/cli" ) @@ -86,8 +140,9 @@ func main() { app := cli.NewApp() app.Name = "greet" app.Usage = "fight the loneliness!" - app.Action = func(c *cli.Context) { - println("Hello friend!") + app.Action = func(c *cli.Context) error { + fmt.Println("Hello friend!") + return nil } app.Run(os.Args) @@ -107,7 +162,7 @@ $ greet Hello friend! ``` -`cli.go` also generates neat help text: +cli also generates neat help text: ``` $ greet help @@ -133,8 +188,9 @@ You can lookup arguments by calling the `Args` function on `cli.Context`. ``` go ... -app.Action = func(c *cli.Context) { - println("Hello", c.Args()[0]) +app.Action = func(c *cli.Context) error { + fmt.Println("Hello", c.Args()[0]) + return nil } ... ``` @@ -152,16 +208,17 @@ app.Flags = []cli.Flag { Usage: "language for the greeting", }, } -app.Action = func(c *cli.Context) { +app.Action = func(c *cli.Context) error { name := "someone" if c.NArg() > 0 { name = c.Args()[0] } if c.String("lang") == "spanish" { - println("Hola", name) + fmt.Println("Hola", name) } else { - println("Hello", name) + fmt.Println("Hello", name) } + return nil } ... ``` @@ -179,22 +236,45 @@ app.Flags = []cli.Flag { Destination: &language, }, } -app.Action = func(c *cli.Context) { +app.Action = func(c *cli.Context) error { name := "someone" if c.NArg() > 0 { name = c.Args()[0] } if language == "spanish" { - println("Hola", name) + fmt.Println("Hola", name) } else { - println("Hello", name) + fmt.Println("Hello", name) } + return nil } ... ``` See full list of flags at http://godoc.org/github.com/codegangsta/cli +#### Placeholder Values + +Sometimes it's useful to specify a flag's value within the usage string itself. Such placeholders are +indicated with back quotes. + +For example this: + +```go +cli.StringFlag{ + Name: "config, c", + Usage: "Load configuration from `FILE`", +} +``` + +Will result in help output like: + +``` +--config FILE, -c FILE Load configuration from FILE +``` + +Note that only the first placeholder is used. Subsequent back-quoted words will be left as-is. + #### Alternate Names You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g. @@ -255,8 +335,8 @@ Initialization must also occur for these flags. Below is an example initializing command.Before = altsrc.InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) ``` -The code above will use the "load" string as a flag name to get the file name of a yaml file from the cli.Context. -It will then use that file name to initialize the yaml input source for any flags that are defined on that command. +The code above will use the "load" string as a flag name to get the file name of a yaml file from the cli.Context. +It will then use that file name to initialize the yaml input source for any flags that are defined on that command. As a note the "load" flag used would also have to be defined on the command flags in order for this code snipped to work. Currently only YAML files are supported but developers can add support for other input sources by implementing the @@ -265,20 +345,21 @@ altsrc.InputSourceContext for their given sources. Here is a more complete sample of a command using YAML support: ``` go - command := &cli.Command{ - Name: "test-cmd", - Aliases: []string{"tc"}, - Usage: "this is for testing", - Description: "testing", - Action: func(c *cli.Context) { - // Action to run - }, - Flags: []cli.Flag{ - NewIntFlag(cli.IntFlag{Name: "test"}), - cli.StringFlag{Name: "load"}}, - } - command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) - err := command.Run(c) + command := &cli.Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(c *cli.Context) error { + // Action to run + return nil + }, + Flags: []cli.Flag{ + NewIntFlag(cli.IntFlag{Name: "test"}), + cli.StringFlag{Name: "load"}}, + } + command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) + err := command.Run(c) ``` ### Subcommands @@ -292,16 +373,18 @@ app.Commands = []cli.Command{ Name: "add", Aliases: []string{"a"}, Usage: "add a task to the list", - Action: func(c *cli.Context) { - println("added task: ", c.Args().First()) + Action: func(c *cli.Context) error { + fmt.Println("added task: ", c.Args().First()) + return nil }, }, { Name: "complete", Aliases: []string{"c"}, Usage: "complete a task on the list", - Action: func(c *cli.Context) { - println("completed task: ", c.Args().First()) + Action: func(c *cli.Context) error { + fmt.Println("completed task: ", c.Args().First()) + return nil }, }, { @@ -312,15 +395,17 @@ app.Commands = []cli.Command{ { Name: "add", Usage: "add a new template", - Action: func(c *cli.Context) { - println("new task template: ", c.Args().First()) + Action: func(c *cli.Context) error { + fmt.Println("new task template: ", c.Args().First()) + return nil }, }, { Name: "remove", Usage: "remove an existing template", - Action: func(c *cli.Context) { - println("removed task template: ", c.Args().First()) + Action: func(c *cli.Context) error { + fmt.Println("removed task template: ", c.Args().First()) + return nil }, }, }, @@ -339,19 +424,19 @@ E.g. ```go ... - app.Commands = []cli.Command{ - { - Name: "noop", - }, - { - Name: "add", - Category: "template", - }, - { - Name: "remove", - Category: "template", - }, - } + app.Commands = []cli.Command{ + { + Name: "noop", + }, + { + Name: "add", + Category: "template", + }, + { + Name: "remove", + Category: "template", + }, + } ... ``` @@ -368,6 +453,41 @@ COMMANDS: ... ``` +### Exit code + +Calling `App.Run` will not automatically call `os.Exit`, which means that by +default the exit code will "fall through" to being `0`. An explicit exit code +may be set by returning a non-nil error that fulfills `cli.ExitCoder`, *or* a +`cli.MultiError` that includes an error that fulfills `cli.ExitCoder`, e.g.: + +``` go +package main + +import ( + "os" + + "github.com/codegangsta/cli" +) + +func main() { + app := cli.NewApp() + app.Flags = []cli.Flag{ + cli.BoolTFlag{ + Name: "ginger-crouton", + Usage: "is it in the soup?", + }, + } + app.Action = func(ctx *cli.Context) error { + if !ctx.Bool("ginger-crouton") { + return cli.NewExitError("it is not in the soup", 86) + } + return nil + } + + app.Run(os.Args) +} +``` + ### Bash Completion You can enable completion commands by setting the `EnableBashCompletion` @@ -385,8 +505,9 @@ app.Commands = []cli.Command{ Name: "complete", Aliases: []string{"c"}, Usage: "complete a task on the list", - Action: func(c *cli.Context) { - println("completed task: ", c.Args().First()) + Action: func(c *cli.Context) error { + fmt.Println("completed task: ", c.Args().First()) + return nil }, BashComplete: func(c *cli.Context) { // This will complete if no args are passed @@ -425,6 +546,72 @@ Alternatively, you can just document that users should source the generic `autocomplete/bash_autocomplete` in their bash configuration with `$PROG` set to the name of their program (as above). +### Generated Help Text Customization + +All of the help text generation may be customized, and at multiple levels. The +templates are exposed as variables `AppHelpTemplate`, `CommandHelpTemplate`, and +`SubcommandHelpTemplate` which may be reassigned or augmented, and full override +is possible by assigning a compatible func to the `cli.HelpPrinter` variable, +e.g.: + +<!-- { + "output": "Ha HA. I pwnd the help!!1" +} --> +``` go +package main + +import ( + "fmt" + "io" + "os" + + "github.com/codegangsta/cli" +) + +func main() { + // EXAMPLE: Append to an existing template + cli.AppHelpTemplate = fmt.Sprintf(`%s + +WEBSITE: http://awesometown.example.com + +SUPPORT: support@awesometown.example.com + +`, cli.AppHelpTemplate) + + // EXAMPLE: Override a template + cli.AppHelpTemplate = `NAME: + {{.Name}} - {{.Usage}} +USAGE: + {{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command +[command options]{{end}} {{if +.ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}} + {{if len .Authors}} +AUTHOR(S): + {{range .Authors}}{{ . }}{{end}} + {{end}}{{if .Commands}} +COMMANDS: +{{range .Commands}}{{if not .HideHelp}} {{join .Names ", "}}{{ "\t" +}}{{.Usage}}{{ "\n" }}{{end}}{{end}}{{end}}{{if .VisibleFlags}} +GLOBAL OPTIONS: + {{range .VisibleFlags}}{{.}} + {{end}}{{end}}{{if .Copyright }} +COPYRIGHT: + {{.Copyright}} + {{end}}{{if .Version}} +VERSION: + {{.Version}} + {{end}} +` + + // EXAMPLE: Replace the `HelpPrinter` func + cli.HelpPrinter = func(w io.Writer, templ string, data interface{}) { + fmt.Println("Ha HA. I pwnd the help!!1") + } + + cli.NewApp().Run(os.Args) +} +``` + ## Contribution Guidelines Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch. diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/app.go b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go index bd20a2d..7c9b958 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/app.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go @@ -6,10 +6,26 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "sort" "time" ) +var ( + changeLogURL = "https://github.com/codegangsta/cli/blob/master/CHANGELOG.md" + appActionDeprecationURL = fmt.Sprintf("%s#deprecated-cli-app-action-signature", changeLogURL) + runAndExitOnErrorDeprecationURL = fmt.Sprintf("%s#deprecated-cli-app-runandexitonerror", changeLogURL) + + contactSysadmin = "This is an error in the application. Please contact the distributor of this application if this is not you." + + errNonFuncAction = NewExitError("ERROR invalid Action type. "+ + fmt.Sprintf("Must be a func of type `cli.ActionFunc`. %s", contactSysadmin)+ + fmt.Sprintf("See %s", appActionDeprecationURL), 2) + errInvalidActionSignature = NewExitError("ERROR invalid Action signature. "+ + fmt.Sprintf("Must be `cli.ActionFunc`. %s", contactSysadmin)+ + fmt.Sprintf("See %s", appActionDeprecationURL), 2) +) + // App is the main structure of a cli application. It is recommended that // an app be created with the cli.NewApp() function type App struct { @@ -35,24 +51,25 @@ type App struct { HideHelp bool // Boolean to hide built-in version flag and the VERSION section of help HideVersion bool - // Populate on app startup, only gettable throught method Categories() + // Populate on app startup, only gettable through method Categories() categories CommandCategories // An action to execute when the bash-completion flag is set - BashComplete func(context *Context) + BashComplete BashCompleteFunc // An action to execute before any subcommands are run, but after the context is ready // If a non-nil error is returned, no subcommands are run - Before func(context *Context) error + Before BeforeFunc // An action to execute after any subcommands are run, but after the subcommand has finished // It is run even if Action() panics - After func(context *Context) error + After AfterFunc // The action to execute when no subcommands are specified - Action func(context *Context) + Action interface{} + // TODO: replace `Action: interface{}` with `Action: ActionFunc` once some kind + // of deprecation period has passed, maybe? + // Execute this function if the proper command cannot be found - CommandNotFound func(context *Context, command string) - // Execute this function, if an usage error occurs. This is useful for displaying customized usage error messages. - // This function is able to replace the original error messages. - // If this function is not set, the "Incorrect usage" is displayed and the execution is interrupted. - OnUsageError func(context *Context, err error, isSubcommand bool) error + CommandNotFound CommandNotFoundFunc + // Execute this function if an usage error occurs + OnUsageError OnUsageErrorFunc // Compilation date Compiled time.Time // List of all authors who contributed @@ -65,6 +82,12 @@ type App struct { Email string // Writer writer to write output to Writer io.Writer + // ErrWriter writes error output + ErrWriter io.Writer + // Other custom info + Metadata map[string]interface{} + + didSetup bool } // Tries to find out when this binary was compiled. @@ -77,7 +100,8 @@ func compileTime() time.Time { return info.ModTime() } -// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action. +// NewApp creates a new cli Application with some reasonable defaults for Name, +// Usage, Version and Action. func NewApp() *App { return &App{ Name: filepath.Base(os.Args[0]), @@ -92,8 +116,16 @@ func NewApp() *App { } } -// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination -func (a *App) Run(arguments []string) (err error) { +// Setup runs initialization code to ensure all data structures are ready for +// `Run` or inspection prior to `Run`. It is internally called by `Run`, but +// will return early if setup has already happened. +func (a *App) Setup() { + if a.didSetup { + return + } + + a.didSetup = true + if a.Author != "" || a.Email != "" { a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email}) } @@ -129,6 +161,12 @@ func (a *App) Run(arguments []string) (err error) { if !a.HideVersion { a.appendFlag(VersionFlag) } +} + +// Run is the entry point to the cli app. Parses the arguments slice and routes +// to the proper flag/args combination +func (a *App) Run(arguments []string) (err error) { + a.Setup() // parse flags set := flagSet(a.Name, a.Flags) @@ -149,12 +187,12 @@ func (a *App) Run(arguments []string) (err error) { if err != nil { if a.OnUsageError != nil { err := a.OnUsageError(context, err, false) - return err - } else { - fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.") - ShowAppHelp(context) + HandleExitCoder(err) return err } + fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.") + ShowAppHelp(context) + return err } if !a.HideHelp && checkHelp(context) { @@ -180,10 +218,12 @@ func (a *App) Run(arguments []string) (err error) { } if a.Before != nil { - err = a.Before(context) - if err != nil { - fmt.Fprintf(a.Writer, "%v\n\n", err) + beforeErr := a.Before(context) + if beforeErr != nil { + fmt.Fprintf(a.Writer, "%v\n\n", beforeErr) ShowAppHelp(context) + HandleExitCoder(beforeErr) + err = beforeErr return err } } @@ -198,19 +238,25 @@ func (a *App) Run(arguments []string) (err error) { } // Run default Action - a.Action(context) - return nil + err = HandleAction(a.Action, context) + + HandleExitCoder(err) + return err } -// Another entry point to the cli app, takes care of passing arguments and error handling +// DEPRECATED: Another entry point to the cli app, takes care of passing arguments and error handling func (a *App) RunAndExitOnError() { + fmt.Fprintf(a.errWriter(), + "DEPRECATED cli.App.RunAndExitOnError. %s See %s\n", + contactSysadmin, runAndExitOnErrorDeprecationURL) if err := a.Run(os.Args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + fmt.Fprintln(a.errWriter(), err) + OsExiter(1) } } -// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags +// RunAsSubcommand invokes the subcommand given the context, parses ctx.Args() to +// generate command-specific flags func (a *App) RunAsSubcommand(ctx *Context) (err error) { // append help to commands if len(a.Commands) > 0 { @@ -261,12 +307,12 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { if err != nil { if a.OnUsageError != nil { err = a.OnUsageError(context, err, true) - return err - } else { - fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.") - ShowSubcommandHelp(context) + HandleExitCoder(err) return err } + fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.") + ShowSubcommandHelp(context) + return err } if len(a.Commands) > 0 { @@ -283,6 +329,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { defer func() { afterErr := a.After(context) if afterErr != nil { + HandleExitCoder(err) if err != nil { err = NewMultiError(err, afterErr) } else { @@ -293,8 +340,10 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { } if a.Before != nil { - err := a.Before(context) - if err != nil { + beforeErr := a.Before(context) + if beforeErr != nil { + HandleExitCoder(beforeErr) + err = beforeErr return err } } @@ -309,12 +358,13 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { } // Run default Action - a.Action(context) + err = HandleAction(a.Action, context) - return nil + HandleExitCoder(err) + return err } -// Returns the named command on App. Returns nil if the command does not exist +// Command returns the named command on App. Returns nil if the command does not exist func (a *App) Command(name string) *Command { for _, c := range a.Commands { if c.HasName(name) { @@ -325,11 +375,46 @@ func (a *App) Command(name string) *Command { return nil } -// Returnes the array containing all the categories with the commands they contain +// Categories returns a slice containing all the categories with the commands they contain func (a *App) Categories() CommandCategories { return a.categories } +// VisibleCategories returns a slice of categories and commands that are +// Hidden=false +func (a *App) VisibleCategories() []*CommandCategory { + ret := []*CommandCategory{} + for _, category := range a.categories { + if visible := func() *CommandCategory { + for _, command := range category.Commands { + if !command.Hidden { + return category + } + } + return nil + }(); visible != nil { + ret = append(ret, visible) + } + } + return ret +} + +// VisibleCommands returns a slice of the Commands with Hidden=false +func (a *App) VisibleCommands() []Command { + ret := []Command{} + for _, command := range a.Commands { + if !command.Hidden { + ret = append(ret, command) + } + } + return ret +} + +// VisibleFlags returns a slice of the Flags with Hidden=false +func (a *App) VisibleFlags() []Flag { + return visibleFlags(a.Flags) +} + func (a *App) hasFlag(flag Flag) bool { for _, f := range a.Flags { if flag == f { @@ -340,6 +425,16 @@ func (a *App) hasFlag(flag Flag) bool { return false } +func (a *App) errWriter() io.Writer { + + // When the app ErrWriter is nil use the package level one. + if a.ErrWriter == nil { + return ErrWriter + } + + return a.ErrWriter +} + func (a *App) appendFlag(flag Flag) { if !a.hasFlag(flag) { a.Flags = append(a.Flags, flag) @@ -361,3 +456,43 @@ func (a Author) String() string { return fmt.Sprintf("%v %v", a.Name, e) } + +// HandleAction uses ✧✧✧reflection✧✧✧ to figure out if the given Action is an +// ActionFunc, a func with the legacy signature for Action, or some other +// invalid thing. If it's an ActionFunc or a func with the legacy signature for +// Action, the func is run! +func HandleAction(action interface{}, context *Context) (err error) { + defer func() { + if r := recover(); r != nil { + switch r.(type) { + case error: + err = r.(error) + default: + err = NewExitError(fmt.Sprintf("ERROR unknown Action error: %v. See %s", r, appActionDeprecationURL), 2) + } + } + }() + + if reflect.TypeOf(action).Kind() != reflect.Func { + return errNonFuncAction + } + + vals := reflect.ValueOf(action).Call([]reflect.Value{reflect.ValueOf(context)}) + + if len(vals) == 0 { + fmt.Fprintf(ErrWriter, + "DEPRECATED Action signature. Must be `cli.ActionFunc`. %s See %s\n", + contactSysadmin, appActionDeprecationURL) + return nil + } + + if len(vals) > 1 { + return errInvalidActionSignature + } + + if retErr, ok := vals[0].Interface().(error); vals[0].IsValid() && ok { + return retErr + } + + return err +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/category.go b/Godeps/_workspace/src/github.com/codegangsta/cli/category.go index 7dbf218..1a60550 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/category.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/category.go @@ -1,7 +1,9 @@ package cli +// CommandCategories is a slice of *CommandCategory. type CommandCategories []*CommandCategory +// CommandCategory is a category containing commands. type CommandCategory struct { Name string Commands Commands @@ -19,6 +21,7 @@ func (c CommandCategories) Swap(i, j int) { c[i], c[j] = c[j], c[i] } +// AddCommand adds a command to a category. func (c CommandCategories) AddCommand(category string, command Command) CommandCategories { for _, commandCategory := range c { if commandCategory.Name == category { @@ -28,3 +31,14 @@ func (c CommandCategories) AddCommand(category string, command Command) CommandC } return append(c, &CommandCategory{Name: category, Commands: []Command{command}}) } + +// VisibleCommands returns a slice of the Commands with Hidden=false +func (c *CommandCategory) VisibleCommands() []Command { + ret := []Command{} + for _, command := range c.Commands { + if !command.Hidden { + ret = append(ret, command) + } + } + return ret +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go index 31dc912..f0440c5 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go @@ -10,31 +10,10 @@ // app := cli.NewApp() // app.Name = "greet" // app.Usage = "say a greeting" -// app.Action = func(c *cli.Context) { +// app.Action = func(c *cli.Context) error { // println("Greetings") // } // // app.Run(os.Args) // } package cli - -import ( - "strings" -) - -type MultiError struct { - Errors []error -} - -func NewMultiError(err ...error) MultiError { - return MultiError{Errors: err} -} - -func (m MultiError) Error() string { - errs := make([]string, len(m.Errors)) - for i, err := range m.Errors { - errs[i] = err.Error() - } - - return strings.Join(errs, "\n") -} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/command.go b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go index 1a05b54..8950cca 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/command.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go @@ -26,19 +26,20 @@ type Command struct { // The category the command is part of Category string // The function to call when checking for bash command completions - BashComplete func(context *Context) + BashComplete BashCompleteFunc // An action to execute before any sub-subcommands are run, but after the context is ready // If a non-nil error is returned, no sub-subcommands are run - Before func(context *Context) error - // An action to execute after any subcommands are run, but before the subcommand has finished + Before BeforeFunc + // An action to execute after any subcommands are run, but after the subcommand has finished // It is run even if Action() panics - After func(context *Context) error + After AfterFunc // The function to call when this command is invoked - Action func(context *Context) - // Execute this function, if an usage error occurs. This is useful for displaying customized usage error messages. - // This function is able to replace the original error messages. - // If this function is not set, the "Incorrect usage" is displayed and the execution is interrupted. - OnUsageError func(context *Context, err error) error + Action interface{} + // TODO: replace `Action: interface{}` with `Action: ActionFunc` once some kind + // of deprecation period has passed, maybe? + + // Execute this function if a usage error occurs. + OnUsageError OnUsageErrorFunc // List of child commands Subcommands Commands // List of flags to parse @@ -47,13 +48,15 @@ type Command struct { SkipFlagParsing bool // Boolean to hide built-in help command HideHelp bool + // Boolean to hide this command from help or completion + Hidden bool // Full name of command for help, defaults to full command name, including parent commands. HelpName string commandNamePath []string } -// Returns the full name of the command. +// FullName returns the full name of the command. // For subcommands this ensures that parent commands are part of the command path func (c Command) FullName() string { if c.commandNamePath == nil { @@ -62,9 +65,10 @@ func (c Command) FullName() string { return strings.Join(c.commandNamePath, " ") } +// Commands is a slice of Command type Commands []Command -// Invokes the command given the context, parses ctx.Args() to generate command-specific flags +// Run invokes the command given the context, parses ctx.Args() to generate command-specific flags func (c Command) Run(ctx *Context) (err error) { if len(c.Subcommands) > 0 { return c.startApp(ctx) @@ -125,14 +129,14 @@ func (c Command) Run(ctx *Context) (err error) { if err != nil { if c.OnUsageError != nil { - err := c.OnUsageError(ctx, err) - return err - } else { - fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.") - fmt.Fprintln(ctx.App.Writer) - ShowCommandHelp(ctx, c.Name) + err := c.OnUsageError(ctx, err, false) + HandleExitCoder(err) return err } + fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.") + fmt.Fprintln(ctx.App.Writer) + ShowCommandHelp(ctx, c.Name) + return err } nerr := normalizeFlags(c.Flags, set) @@ -142,6 +146,7 @@ func (c Command) Run(ctx *Context) (err error) { ShowCommandHelp(ctx, c.Name) return nerr } + context := NewContext(ctx.App, set, ctx) if checkCommandCompletions(context, c.Name) { @@ -156,6 +161,7 @@ func (c Command) Run(ctx *Context) (err error) { defer func() { afterErr := c.After(context) if afterErr != nil { + HandleExitCoder(err) if err != nil { err = NewMultiError(err, afterErr) } else { @@ -166,20 +172,26 @@ func (c Command) Run(ctx *Context) (err error) { } if c.Before != nil { - err := c.Before(context) + err = c.Before(context) if err != nil { fmt.Fprintln(ctx.App.Writer, err) fmt.Fprintln(ctx.App.Writer) ShowCommandHelp(ctx, c.Name) + HandleExitCoder(err) return err } } context.Command = c - c.Action(context) - return nil + err = HandleAction(c.Action, context) + + if err != nil { + HandleExitCoder(err) + } + return err } +// Names returns the names including short names and aliases. func (c Command) Names() []string { names := []string{c.Name} @@ -190,7 +202,7 @@ func (c Command) Names() []string { return append(names, c.Aliases...) } -// Returns true if Command.Name or Command.ShortName matches given name +// HasName returns true if Command.Name or Command.ShortName matches given name func (c Command) HasName(name string) bool { for _, n := range c.Names() { if n == name { @@ -202,7 +214,7 @@ func (c Command) HasName(name string) bool { func (c Command) startApp(ctx *Context) error { app := NewApp() - + app.Metadata = ctx.App.Metadata // set the name and usage app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name) if c.HelpName == "" { @@ -260,3 +272,8 @@ func (c Command) startApp(ctx *Context) error { return app.RunAsSubcommand(ctx) } + +// VisibleFlags returns a slice of the Flags with Hidden=false +func (c Command) VisibleFlags() []Flag { + return visibleFlags(c.Flags) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/context.go b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go index b66d278..c342463 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/context.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go @@ -21,57 +21,62 @@ type Context struct { parentContext *Context } -// Creates a new context. For use in when invoking an App or Command action. +// NewContext creates a new context. For use in when invoking an App or Command action. func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { return &Context{App: app, flagSet: set, parentContext: parentCtx} } -// Looks up the value of a local int flag, returns 0 if no int flag exists +// Int looks up the value of a local int flag, returns 0 if no int flag exists func (c *Context) Int(name string) int { return lookupInt(name, c.flagSet) } -// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists +// Duration looks up the value of a local time.Duration flag, returns 0 if no +// time.Duration flag exists func (c *Context) Duration(name string) time.Duration { return lookupDuration(name, c.flagSet) } -// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists +// Float64 looks up the value of a local float64 flag, returns 0 if no float64 +// flag exists func (c *Context) Float64(name string) float64 { return lookupFloat64(name, c.flagSet) } -// Looks up the value of a local bool flag, returns false if no bool flag exists +// Bool looks up the value of a local bool flag, returns false if no bool flag exists func (c *Context) Bool(name string) bool { return lookupBool(name, c.flagSet) } -// Looks up the value of a local boolT flag, returns false if no bool flag exists +// BoolT looks up the value of a local boolT flag, returns false if no bool flag exists func (c *Context) BoolT(name string) bool { return lookupBoolT(name, c.flagSet) } -// Looks up the value of a local string flag, returns "" if no string flag exists +// String looks up the value of a local string flag, returns "" if no string flag exists func (c *Context) String(name string) string { return lookupString(name, c.flagSet) } -// Looks up the value of a local string slice flag, returns nil if no string slice flag exists +// StringSlice looks up the value of a local string slice flag, returns nil if no +// string slice flag exists func (c *Context) StringSlice(name string) []string { return lookupStringSlice(name, c.flagSet) } -// Looks up the value of a local int slice flag, returns nil if no int slice flag exists +// IntSlice looks up the value of a local int slice flag, returns nil if no int +// slice flag exists func (c *Context) IntSlice(name string) []int { return lookupIntSlice(name, c.flagSet) } -// Looks up the value of a local generic flag, returns nil if no generic flag exists +// Generic looks up the value of a local generic flag, returns nil if no generic +// flag exists func (c *Context) Generic(name string) interface{} { return lookupGeneric(name, c.flagSet) } -// Looks up the value of a global int flag, returns 0 if no int flag exists +// GlobalInt looks up the value of a global int flag, returns 0 if no int flag exists func (c *Context) GlobalInt(name string) int { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupInt(name, fs) @@ -79,7 +84,17 @@ func (c *Context) GlobalInt(name string) int { return 0 } -// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists +// GlobalFloat64 looks up the value of a global float64 flag, returns float64(0) +// if no float64 flag exists +func (c *Context) GlobalFloat64(name string) float64 { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupFloat64(name, fs) + } + return float64(0) +} + +// GlobalDuration looks up the value of a global time.Duration flag, returns 0 +// if no time.Duration flag exists func (c *Context) GlobalDuration(name string) time.Duration { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupDuration(name, fs) @@ -87,7 +102,8 @@ func (c *Context) GlobalDuration(name string) time.Duration { return 0 } -// Looks up the value of a global bool flag, returns false if no bool flag exists +// GlobalBool looks up the value of a global bool flag, returns false if no bool +// flag exists func (c *Context) GlobalBool(name string) bool { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupBool(name, fs) @@ -95,7 +111,17 @@ func (c *Context) GlobalBool(name string) bool { return false } -// Looks up the value of a global string flag, returns "" if no string flag exists +// GlobalBoolT looks up the value of a global bool flag, returns true if no bool +// flag exists +func (c *Context) GlobalBoolT(name string) bool { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupBoolT(name, fs) + } + return false +} + +// GlobalString looks up the value of a global string flag, returns "" if no +// string flag exists func (c *Context) GlobalString(name string) string { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupString(name, fs) @@ -103,7 +129,8 @@ func (c *Context) GlobalString(name string) string { return "" } -// Looks up the value of a global string slice flag, returns nil if no string slice flag exists +// GlobalStringSlice looks up the value of a global string slice flag, returns +// nil if no string slice flag exists func (c *Context) GlobalStringSlice(name string) []string { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupStringSlice(name, fs) @@ -111,7 +138,8 @@ func (c *Context) GlobalStringSlice(name string) []string { return nil } -// Looks up the value of a global int slice flag, returns nil if no int slice flag exists +// GlobalIntSlice looks up the value of a global int slice flag, returns nil if +// no int slice flag exists func (c *Context) GlobalIntSlice(name string) []int { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupIntSlice(name, fs) @@ -119,7 +147,8 @@ func (c *Context) GlobalIntSlice(name string) []int { return nil } -// Looks up the value of a global generic flag, returns nil if no generic flag exists +// GlobalGeneric looks up the value of a global generic flag, returns nil if no +// generic flag exists func (c *Context) GlobalGeneric(name string) interface{} { if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupGeneric(name, fs) @@ -127,12 +156,22 @@ func (c *Context) GlobalGeneric(name string) interface{} { return nil } -// Returns the number of flags set +// NumFlags returns the number of flags set func (c *Context) NumFlags() int { return c.flagSet.NFlag() } -// Determines if the flag was actually set +// Set sets a context flag to a value. +func (c *Context) Set(name, value string) error { + return c.flagSet.Set(name, value) +} + +// GlobalSet sets a context flag to a value on the global flagset +func (c *Context) GlobalSet(name, value string) error { + return globalContext(c).flagSet.Set(name, value) +} + +// IsSet determines if the flag was actually set func (c *Context) IsSet(name string) bool { if c.setFlags == nil { c.setFlags = make(map[string]bool) @@ -143,7 +182,7 @@ func (c *Context) IsSet(name string) bool { return c.setFlags[name] == true } -// Determines if the global flag was actually set +// GlobalIsSet determines if the global flag was actually set func (c *Context) GlobalIsSet(name string) bool { if c.globalSetFlags == nil { c.globalSetFlags = make(map[string]bool) @@ -160,7 +199,7 @@ func (c *Context) GlobalIsSet(name string) bool { return c.globalSetFlags[name] } -// Returns a slice of flag names used in this context. +// FlagNames returns a slice of flag names used in this context. func (c *Context) FlagNames() (names []string) { for _, flag := range c.Command.Flags { name := strings.Split(flag.GetName(), ",")[0] @@ -172,7 +211,7 @@ func (c *Context) FlagNames() (names []string) { return } -// Returns a slice of global flag names used by the app. +// GlobalFlagNames returns a slice of global flag names used by the app. func (c *Context) GlobalFlagNames() (names []string) { for _, flag := range c.App.Flags { name := strings.Split(flag.GetName(), ",")[0] @@ -184,25 +223,26 @@ func (c *Context) GlobalFlagNames() (names []string) { return } -// Returns the parent context, if any +// Parent returns the parent context, if any func (c *Context) Parent() *Context { return c.parentContext } +// Args contains apps console arguments type Args []string -// Returns the command line arguments associated with the context. +// Args returns the command line arguments associated with the context. func (c *Context) Args() Args { args := Args(c.flagSet.Args()) return args } -// Returns the number of the command line arguments. +// NArg returns the number of the command line arguments. func (c *Context) NArg() int { return len(c.Args()) } -// Returns the nth argument, or else a blank string +// Get returns the nth argument, or else a blank string func (a Args) Get(n int) string { if len(a) > n { return a[n] @@ -210,12 +250,12 @@ func (a Args) Get(n int) string { return "" } -// Returns the first argument, or else a blank string +// First returns the first argument, or else a blank string func (a Args) First() string { return a.Get(0) } -// Return the rest of the arguments (not the first one) +// Tail returns the rest of the arguments (not the first one) // or else an empty string slice func (a Args) Tail() []string { if len(a) >= 2 { @@ -224,12 +264,12 @@ func (a Args) Tail() []string { return []string{} } -// Checks if there are any arguments present +// Present checks if there are any arguments present func (a Args) Present() bool { return len(a) != 0 } -// Swaps arguments at the given indexes +// Swap swaps arguments at the given indexes func (a Args) Swap(from, to int) error { if from >= len(a) || to >= len(a) { return errors.New("index out of range") @@ -238,6 +278,19 @@ func (a Args) Swap(from, to int) error { return nil } +func globalContext(ctx *Context) *Context { + if ctx == nil { + return nil + } + + for { + if ctx.parentContext == nil { + return ctx + } + ctx = ctx.parentContext + } +} + func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet { if ctx.parentContext != nil { ctx = ctx.parentContext diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/errors.go b/Godeps/_workspace/src/github.com/codegangsta/cli/errors.go new file mode 100644 index 0000000..ea551be --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/errors.go @@ -0,0 +1,92 @@ +package cli + +import ( + "fmt" + "io" + "os" + "strings" +) + +// OsExiter is the function used when the app exits. If not set defaults to os.Exit. +var OsExiter = os.Exit + +// ErrWriter is used to write errors to the user. This can be anything +// implementing the io.Writer interface and defaults to os.Stderr. +var ErrWriter io.Writer = os.Stderr + +// MultiError is an error that wraps multiple errors. +type MultiError struct { + Errors []error +} + +// NewMultiError creates a new MultiError. Pass in one or more errors. +func NewMultiError(err ...error) MultiError { + return MultiError{Errors: err} +} + +// Error implents the error interface. +func (m MultiError) Error() string { + errs := make([]string, len(m.Errors)) + for i, err := range m.Errors { + errs[i] = err.Error() + } + + return strings.Join(errs, "\n") +} + +// ExitCoder is the interface checked by `App` and `Command` for a custom exit +// code +type ExitCoder interface { + error + ExitCode() int +} + +// ExitError fulfills both the builtin `error` interface and `ExitCoder` +type ExitError struct { + exitCode int + message string +} + +// NewExitError makes a new *ExitError +func NewExitError(message string, exitCode int) *ExitError { + return &ExitError{ + exitCode: exitCode, + message: message, + } +} + +// Error returns the string message, fulfilling the interface required by +// `error` +func (ee *ExitError) Error() string { + return ee.message +} + +// ExitCode returns the exit code, fulfilling the interface required by +// `ExitCoder` +func (ee *ExitError) ExitCode() int { + return ee.exitCode +} + +// HandleExitCoder checks if the error fulfills the ExitCoder interface, and if +// so prints the error to stderr (if it is non-empty) and calls OsExiter with the +// given exit code. If the given error is a MultiError, then this func is +// called on all members of the Errors slice. +func HandleExitCoder(err error) { + if err == nil { + return + } + + if exitErr, ok := err.(ExitCoder); ok { + if err.Error() != "" { + fmt.Fprintln(ErrWriter, err) + } + OsExiter(exitErr.ExitCode()) + return + } + + if multiErr, ok := err.(MultiError); ok { + for _, merr := range multiErr.Errors { + HandleExitCoder(merr) + } + } +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go index e951c2d..1e8112e 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go @@ -4,24 +4,28 @@ import ( "flag" "fmt" "os" + "reflect" "runtime" "strconv" "strings" "time" ) -// This flag enables bash-completion for all commands and subcommands +const defaultPlaceholder = "value" + +// BashCompletionFlag enables bash-completion for all commands and subcommands var BashCompletionFlag = BoolFlag{ - Name: "generate-bash-completion", + Name: "generate-bash-completion", + Hidden: true, } -// This flag prints the version for the application +// VersionFlag prints the version for the application var VersionFlag = BoolFlag{ Name: "version, v", Usage: "print the version", } -// This flag prints the help for all commands and subcommands +// HelpFlag prints the help for all commands and subcommands // Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand // unless HideHelp is set to true) var HelpFlag = BoolFlag{ @@ -29,6 +33,10 @@ var HelpFlag = BoolFlag{ Usage: "show help", } +// FlagStringer converts a flag definition to a string. This is used by help +// to display a flag. +var FlagStringer FlagStringFunc = stringifyFlag + // Flag is a common interface related to parsing flags in cli. // For more advanced flag parsing techniques, it is recommended that // this interface be implemented. @@ -68,24 +76,14 @@ type GenericFlag struct { Value Generic Usage string EnvVar string + Hidden bool } // String returns the string representation of the generic flag to display the // help text to the user (uses the String() method of the generic flag to show // the value) func (f GenericFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s %v\t%v", prefixedNames(f.Name), f.FormatValueHelp(), f.Usage)) -} - -func (f GenericFlag) FormatValueHelp() string { - if f.Value == nil { - return "" - } - s := f.Value.String() - if len(s) == 0 { - return "" - } - return fmt.Sprintf("\"%s\"", s) + return FlagStringer(f) } // Apply takes the flagset and calls Set on the generic flag with the value @@ -107,6 +105,7 @@ func (f GenericFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of a flag. func (f GenericFlag) GetName() string { return f.Name } @@ -130,20 +129,19 @@ func (f *StringSlice) Value() []string { return *f } -// StringSlice is a string flag that can be specified multiple times on the +// StringSliceFlag is a string flag that can be specified multiple times on the // command-line type StringSliceFlag struct { Name string Value *StringSlice Usage string EnvVar string + Hidden bool } // String returns the usage func (f StringSliceFlag) String() string { - firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") - pref := prefixFor(firstName) - return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -171,11 +169,12 @@ func (f StringSliceFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of a flag. func (f StringSliceFlag) GetName() string { return f.Name } -// StringSlice is an opaque type for []int to satisfy flag.Value +// IntSlice is an opaque type for []int to satisfy flag.Value type IntSlice []int // Set parses the value into an integer and appends it to the list of values @@ -183,9 +182,8 @@ func (f *IntSlice) Set(value string) error { tmp, err := strconv.Atoi(value) if err != nil { return err - } else { - *f = append(*f, tmp) } + *f = append(*f, tmp) return nil } @@ -206,13 +204,12 @@ type IntSliceFlag struct { Value *IntSlice Usage string EnvVar string + Hidden bool } // String returns the usage func (f IntSliceFlag) String() string { - firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") - pref := prefixFor(firstName) - return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -226,7 +223,7 @@ func (f IntSliceFlag) Apply(set *flag.FlagSet) { s = strings.TrimSpace(s) err := newVal.Set(s) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + fmt.Fprintf(ErrWriter, err.Error()) } } f.Value = newVal @@ -243,6 +240,7 @@ func (f IntSliceFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f IntSliceFlag) GetName() string { return f.Name } @@ -253,11 +251,12 @@ type BoolFlag struct { Usage string EnvVar string Destination *bool + Hidden bool } // String returns a readable representation of this value (for usage defaults) func (f BoolFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -285,6 +284,7 @@ func (f BoolFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f BoolFlag) GetName() string { return f.Name } @@ -296,11 +296,12 @@ type BoolTFlag struct { Usage string EnvVar string Destination *bool + Hidden bool } // String returns a readable representation of this value (for usage defaults) func (f BoolTFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -328,6 +329,7 @@ func (f BoolTFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f BoolTFlag) GetName() string { return f.Name } @@ -339,19 +341,12 @@ type StringFlag struct { Usage string EnvVar string Destination *string + Hidden bool } // String returns the usage func (f StringFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s %v\t%v", prefixedNames(f.Name), f.FormatValueHelp(), f.Usage)) -} - -func (f StringFlag) FormatValueHelp() string { - s := f.Value - if len(s) == 0 { - return "" - } - return fmt.Sprintf("\"%s\"", s) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -375,6 +370,7 @@ func (f StringFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f StringFlag) GetName() string { return f.Name } @@ -387,11 +383,12 @@ type IntFlag struct { Usage string EnvVar string Destination *int + Hidden bool } // String returns the usage func (f IntFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -418,6 +415,7 @@ func (f IntFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f IntFlag) GetName() string { return f.Name } @@ -430,11 +428,12 @@ type DurationFlag struct { Usage string EnvVar string Destination *time.Duration + Hidden bool } // String returns a readable representation of this value (for usage defaults) func (f DurationFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -461,6 +460,7 @@ func (f DurationFlag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f DurationFlag) GetName() string { return f.Name } @@ -473,11 +473,12 @@ type Float64Flag struct { Usage string EnvVar string Destination *float64 + Hidden bool } // String returns the usage func (f Float64Flag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return FlagStringer(f) } // Apply populates the flag given the flag set and environment @@ -503,10 +504,21 @@ func (f Float64Flag) Apply(set *flag.FlagSet) { }) } +// GetName returns the name of the flag. func (f Float64Flag) GetName() string { return f.Name } +func visibleFlags(fl []Flag) []Flag { + visible := []Flag{} + for _, flag := range fl { + if !reflect.ValueOf(flag).FieldByName("Hidden").Bool() { + visible = append(visible, flag) + } + } + return visible +} + func prefixFor(name string) (prefix string) { if len(name) == 1 { prefix = "-" @@ -517,16 +529,37 @@ func prefixFor(name string) (prefix string) { return } -func prefixedNames(fullName string) (prefixed string) { +// Returns the placeholder, if any, and the unquoted usage string. +func unquoteUsage(usage string) (string, string) { + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name := usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break + } + } + return "", usage +} + +func prefixedNames(fullName, placeholder string) string { + var prefixed string parts := strings.Split(fullName, ",") for i, name := range parts { name = strings.Trim(name, " ") prefixed += prefixFor(name) + name + if placeholder != "" { + prefixed += " " + placeholder + } if i < len(parts)-1 { prefixed += ", " } } - return + return prefixed } func withEnvHint(envVar, str string) string { @@ -544,3 +577,83 @@ func withEnvHint(envVar, str string) string { } return str + envText } + +func stringifyFlag(f Flag) string { + fv := reflect.ValueOf(f) + + switch f.(type) { + case IntSliceFlag: + return withEnvHint(fv.FieldByName("EnvVar").String(), + stringifyIntSliceFlag(f.(IntSliceFlag))) + case StringSliceFlag: + return withEnvHint(fv.FieldByName("EnvVar").String(), + stringifyStringSliceFlag(f.(StringSliceFlag))) + } + + placeholder, usage := unquoteUsage(fv.FieldByName("Usage").String()) + + needsPlaceholder := false + defaultValueString := "" + val := fv.FieldByName("Value") + + if val.IsValid() { + needsPlaceholder = true + defaultValueString = fmt.Sprintf(" (default: %v)", val.Interface()) + + if val.Kind() == reflect.String && val.String() != "" { + defaultValueString = fmt.Sprintf(" (default: %q)", val.String()) + } + } + + if defaultValueString == " (default: )" { + defaultValueString = "" + } + + if needsPlaceholder && placeholder == "" { + placeholder = defaultPlaceholder + } + + usageWithDefault := strings.TrimSpace(fmt.Sprintf("%s%s", usage, defaultValueString)) + + return withEnvHint(fv.FieldByName("EnvVar").String(), + fmt.Sprintf("%s\t%s", prefixedNames(fv.FieldByName("Name").String(), placeholder), usageWithDefault)) +} + +func stringifyIntSliceFlag(f IntSliceFlag) string { + defaultVals := []string{} + if f.Value != nil && len(f.Value.Value()) > 0 { + for _, i := range f.Value.Value() { + defaultVals = append(defaultVals, fmt.Sprintf("%d", i)) + } + } + + return stringifySliceFlag(f.Usage, f.Name, defaultVals) +} + +func stringifyStringSliceFlag(f StringSliceFlag) string { + defaultVals := []string{} + if f.Value != nil && len(f.Value.Value()) > 0 { + for _, s := range f.Value.Value() { + if len(s) > 0 { + defaultVals = append(defaultVals, fmt.Sprintf("%q", s)) + } + } + } + + return stringifySliceFlag(f.Usage, f.Name, defaultVals) +} + +func stringifySliceFlag(usage, name string, defaultVals []string) string { + placeholder, usage := unquoteUsage(usage) + if placeholder == "" { + placeholder = defaultPlaceholder + } + + defaultVal := "" + if len(defaultVals) > 0 { + defaultVal = fmt.Sprintf(" (default: %s)", strings.Join(defaultVals, ", ")) + } + + usageWithDefault := strings.TrimSpace(fmt.Sprintf("%s%s", usage, defaultVal)) + return fmt.Sprintf("%s\t%s", prefixedNames(name, placeholder), usageWithDefault) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/funcs.go b/Godeps/_workspace/src/github.com/codegangsta/cli/funcs.go new file mode 100644 index 0000000..cba5e6c --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/funcs.go @@ -0,0 +1,28 @@ +package cli + +// BashCompleteFunc is an action to execute when the bash-completion flag is set +type BashCompleteFunc func(*Context) + +// BeforeFunc is an action to execute before any subcommands are run, but after +// the context is ready if a non-nil error is returned, no subcommands are run +type BeforeFunc func(*Context) error + +// AfterFunc is an action to execute after any subcommands are run, but after the +// subcommand has finished it is run even if Action() panics +type AfterFunc func(*Context) error + +// ActionFunc is the action to execute when no subcommands are specified +type ActionFunc func(*Context) error + +// CommandNotFoundFunc is executed if the proper command cannot be found +type CommandNotFoundFunc func(*Context, string) + +// OnUsageErrorFunc is executed if an usage error occurs. This is useful for displaying +// customized usage error messages. This function is able to replace the +// original error messages. If this function is not set, the "Incorrect usage" +// is displayed and the execution is interrupted. +type OnUsageErrorFunc func(context *Context, err error, isSubcommand bool) error + +// FlagStringFunc is used by the help generation to display a flag, which is +// expected to be a single line. +type FlagStringFunc func(Flag) string diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/help.go b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go index adf157d..a9e7327 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/help.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go @@ -3,73 +3,74 @@ package cli import ( "fmt" "io" + "os" "strings" "text/tabwriter" "text/template" ) -// The text template for the Default help topic. +// AppHelpTemplate is the text template for the Default help topic. // cli.go uses text/template to render templates. You can // render custom help text by setting this variable. var AppHelpTemplate = `NAME: {{.Name}} - {{.Usage}} USAGE: - {{if .UsageText}}{{.UsageText}}{{else}}{{.HelpName}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}} + {{if .UsageText}}{{.UsageText}}{{else}}{{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}} {{if .Version}}{{if not .HideVersion}} VERSION: {{.Version}} {{end}}{{end}}{{if len .Authors}} AUTHOR(S): - {{range .Authors}}{{ . }}{{end}} - {{end}}{{if .Commands}} -COMMANDS:{{range .Categories}}{{if .Name}} - {{.Name}}{{ ":" }}{{end}}{{range .Commands}} - {{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}{{end}} -{{end}}{{end}}{{if .Flags}} + {{range .Authors}}{{.}}{{end}} + {{end}}{{if .VisibleCommands}} +COMMANDS:{{range .VisibleCategories}}{{if .Name}} + {{.Name}}:{{end}}{{range .VisibleCommands}} + {{join .Names ", "}}{{"\t"}}{{.Usage}}{{end}} +{{end}}{{end}}{{if .VisibleFlags}} GLOBAL OPTIONS: - {{range .Flags}}{{.}} - {{end}}{{end}}{{if .Copyright }} + {{range .VisibleFlags}}{{.}} + {{end}}{{end}}{{if .Copyright}} COPYRIGHT: {{.Copyright}} {{end}} ` -// The text template for the command help topic. +// CommandHelpTemplate is the text template for the command help topic. // cli.go uses text/template to render templates. You can // render custom help text by setting this variable. var CommandHelpTemplate = `NAME: {{.HelpName}} - {{.Usage}} USAGE: - {{.HelpName}}{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{if .Category}} + {{.HelpName}}{{if .VisibleFlags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{if .Category}} CATEGORY: {{.Category}}{{end}}{{if .Description}} DESCRIPTION: - {{.Description}}{{end}}{{if .Flags}} + {{.Description}}{{end}}{{if .VisibleFlags}} OPTIONS: - {{range .Flags}}{{.}} - {{end}}{{ end }} + {{range .VisibleFlags}}{{.}} + {{end}}{{end}} ` -// The text template for the subcommand help topic. +// SubcommandHelpTemplate is the text template for the subcommand help topic. // cli.go uses text/template to render templates. You can // render custom help text by setting this variable. var SubcommandHelpTemplate = `NAME: {{.HelpName}} - {{.Usage}} USAGE: - {{.HelpName}} command{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}} + {{.HelpName}} command{{if .VisibleFlags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}} -COMMANDS:{{range .Categories}}{{if .Name}} - {{.Name}}{{ ":" }}{{end}}{{range .Commands}} - {{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}{{end}} -{{end}}{{if .Flags}} +COMMANDS:{{range .VisibleCategories}}{{if .Name}} + {{.Name}}:{{end}}{{range .VisibleCommands}} + {{join .Names ", "}}{{"\t"}}{{.Usage}}{{end}} +{{end}}{{if .VisibleFlags}} OPTIONS: - {{range .Flags}}{{.}} + {{range .VisibleFlags}}{{.}} {{end}}{{end}} ` @@ -78,13 +79,14 @@ var helpCommand = Command{ Aliases: []string{"h"}, Usage: "Shows a list of commands or help for one command", ArgsUsage: "[command]", - Action: func(c *Context) { + Action: func(c *Context) error { args := c.Args() if args.Present() { - ShowCommandHelp(c, args.First()) - } else { - ShowAppHelp(c) + return ShowCommandHelp(c, args.First()) } + + ShowAppHelp(c) + return nil }, } @@ -93,65 +95,73 @@ var helpSubcommand = Command{ Aliases: []string{"h"}, Usage: "Shows a list of commands or help for one command", ArgsUsage: "[command]", - Action: func(c *Context) { + Action: func(c *Context) error { args := c.Args() if args.Present() { - ShowCommandHelp(c, args.First()) - } else { - ShowSubcommandHelp(c) + return ShowCommandHelp(c, args.First()) } + + return ShowSubcommandHelp(c) }, } // Prints help for the App or Command type helpPrinter func(w io.Writer, templ string, data interface{}) +// HelpPrinter is a function that writes the help output. If not set a default +// is used. The function signature is: +// func(w io.Writer, templ string, data interface{}) var HelpPrinter helpPrinter = printHelp -// Prints version for the App +// VersionPrinter prints the version for the App var VersionPrinter = printVersion +// ShowAppHelp is an action that displays the help. func ShowAppHelp(c *Context) { HelpPrinter(c.App.Writer, AppHelpTemplate, c.App) } -// Prints the list of subcommands as the default app completion method +// DefaultAppComplete prints the list of subcommands as the default app completion method func DefaultAppComplete(c *Context) { for _, command := range c.App.Commands { + if command.Hidden { + continue + } for _, name := range command.Names() { fmt.Fprintln(c.App.Writer, name) } } } -// Prints help for the given command -func ShowCommandHelp(ctx *Context, command string) { +// ShowCommandHelp prints help for the given command +func ShowCommandHelp(ctx *Context, command string) error { // show the subcommand help for a command with subcommands if command == "" { HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App) - return + return nil } for _, c := range ctx.App.Commands { if c.HasName(command) { HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c) - return + return nil } } - if ctx.App.CommandNotFound != nil { - ctx.App.CommandNotFound(ctx, command) - } else { - fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command) + if ctx.App.CommandNotFound == nil { + return NewExitError(fmt.Sprintf("No help topic for '%v'", command), 3) } + + ctx.App.CommandNotFound(ctx, command) + return nil } -// Prints help for the given subcommand -func ShowSubcommandHelp(c *Context) { - ShowCommandHelp(c, c.Command.Name) +// ShowSubcommandHelp prints help for the given subcommand +func ShowSubcommandHelp(c *Context) error { + return ShowCommandHelp(c, c.Command.Name) } -// Prints the version number of the App +// ShowVersion prints the version number of the App func ShowVersion(c *Context) { VersionPrinter(c) } @@ -160,7 +170,7 @@ func printVersion(c *Context) { fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version) } -// Prints the lists of commands within a given context +// ShowCompletions prints the lists of commands within a given context func ShowCompletions(c *Context) { a := c.App if a != nil && a.BashComplete != nil { @@ -168,7 +178,7 @@ func ShowCompletions(c *Context) { } } -// Prints the custom completions for a given command +// ShowCommandCompletions prints the custom completions for a given command func ShowCommandCompletions(ctx *Context, command string) { c := ctx.App.Command(command) if c != nil && c.BashComplete != nil { @@ -186,7 +196,10 @@ func printHelp(out io.Writer, templ string, data interface{}) { err := t.Execute(w, data) if err != nil { // If the writer is closed, t.Execute will fail, and there's nothing - // we can do to recover. We could send this to os.Stderr if we need. + // we can do to recover. + if os.Getenv("CLI_TEMPLATE_ERROR_DEBUG") != "" { + fmt.Fprintf(ErrWriter, "CLI TEMPLATE ERROR: %#v\n", err) + } return } w.Flush() diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/runtests b/Godeps/_workspace/src/github.com/codegangsta/cli/runtests new file mode 100644 index 0000000..feacff3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/runtests @@ -0,0 +1,95 @@ +#!/usr/bin/env python +from __future__ import print_function + +import argparse +import os +import sys +import tempfile + +from subprocess import check_call, check_output + + +PACKAGE_NAME = os.environ.get( + 'CLI_PACKAGE_NAME', 'github.com/codegangsta/cli' +) + + +def main(sysargs=sys.argv[:]): + targets = { + 'vet': _vet, + 'test': _test, + 'gfmxr': _gfmxr + } + + parser = argparse.ArgumentParser() + parser.add_argument( + 'target', nargs='?', choices=tuple(targets.keys()), default='test' + ) + args = parser.parse_args(sysargs[1:]) + + targets[args.target]() + return 0 + + +def _test(): + if check_output('go version'.split()).split()[2] < 'go1.2': + _run('go test -v ./...'.split()) + return + + coverprofiles = [] + for subpackage in ['', 'altsrc']: + coverprofile = 'cli.coverprofile' + if subpackage != '': + coverprofile = '{}.coverprofile'.format(subpackage) + + coverprofiles.append(coverprofile) + + _run('go test -v'.split() + [ + '-coverprofile={}'.format(coverprofile), + ('{}/{}'.format(PACKAGE_NAME, subpackage)).rstrip('/') + ]) + + combined = _combine_coverprofiles(coverprofiles) + _run('go tool cover -func={}'.format(combined.name).split()) + combined.close() + + +def _gfmxr(): + _run(['gfmxr', '-c', str(_gfmxr_count()), '-s', 'README.md']) + + +def _vet(): + _run('go vet ./...'.split()) + + +def _run(command): + print('runtests: {}'.format(' '.join(command)), file=sys.stderr) + check_call(command) + + +def _gfmxr_count(): + with open('README.md') as infile: + lines = infile.read().splitlines() + return len(filter(_is_go_runnable, lines)) + + +def _is_go_runnable(line): + return line.startswith('package main') + + +def _combine_coverprofiles(coverprofiles): + combined = tempfile.NamedTemporaryFile(suffix='.coverprofile') + combined.write('mode: set\n') + + for coverprofile in coverprofiles: + with open(coverprofile, 'r') as infile: + for line in infile.readlines(): + if not line.startswith('mode: '): + combined.write(line) + + combined.flush() + return combined + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/provider.go b/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/provider.go index 1235890..dcae4c9 100644 --- a/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/provider.go +++ b/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/provider.go @@ -537,7 +537,7 @@ func (s *ProviderConfigSyncer) sync() (time.Duration, error) { s.initialSyncDone = true } - log.Infof("Updating provider config: config=%#v", cfg) + log.Debugf("Updating provider config: config=%#v", cfg) return nextSyncAfter(cfg.ExpiresAt, s.clock), nil } diff --git a/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/transport.go b/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/transport.go index 93ff9e1..61c926d 100644 --- a/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/transport.go +++ b/Godeps/_workspace/src/github.com/coreos/go-oidc/oidc/transport.go @@ -67,6 +67,15 @@ func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) { return t.jwt, nil } +// SetJWT sets the JWT held by the Transport. +// This is useful for cases in which you want to set an initial JWT. +func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) { + t.mu.Lock() + defer t.mu.Unlock() + + t.jwt = jwt +} + func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) { jwt, err := t.verifiedJWT() if err != nil { diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/.gitignore b/Godeps/_workspace/src/github.com/elazarl/goproxy/.gitignore new file mode 100644 index 0000000..1005f6f --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/.gitignore @@ -0,0 +1,2 @@ +bin +*.swp diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/LICENSE b/Godeps/_workspace/src/github.com/elazarl/goproxy/LICENSE new file mode 100644 index 0000000..2067e56 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 Elazar Leibovich. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Elazar Leibovich. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/README.md b/Godeps/_workspace/src/github.com/elazarl/goproxy/README.md new file mode 100644 index 0000000..347993f --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/README.md @@ -0,0 +1,118 @@ +# Introduction + +[](https://godoc.org/github.com/elazarl/goproxy) +[](https://gitter.im/elazarl/goproxy?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +Package goproxy provides a customizable HTTP proxy library for Go (golang), + +It supports regular HTTP proxy, HTTPS through CONNECT, and "hijacking" HTTPS +connection using "Man in the Middle" style attack. + +The intent of the proxy, is to be usable with reasonable amount of traffic +yet, customizable and programable. + +The proxy itself is simply a `net/http` handler. + +In order to use goproxy, one should set their browser to use goproxy as an HTTP +proxy. Here is how you do that [in Chrome](https://support.google.com/chrome/answer/96815?hl=en) +and [in Firefox](http://www.wikihow.com/Enter-Proxy-Settings-in-Firefox). + +For example, the URL you should use as proxy when running `./bin/basic` is +`localhost:8080`, as this is the default binding for the basic proxy. + +## Mailing List + +New features would be discussed on the [mailing list](https://groups.google.com/forum/#!forum/goproxy-dev) +before their development. + +## Latest Stable Release + +Get the latest goproxy from `gopkg.in/elazarl/goproxy.v1`. + +# Why not Fiddler2? + +Fiddler is an excellent software with similar intent. However, Fiddler is not +as customable as goproxy intend to be. The main difference is, Fiddler is not +intended to be used as a real proxy. + +A possible use case that suits goproxy but +not Fiddler, is, gathering statisitics on page load times for a certain website over a week. +With goproxy you could ask all your users to set their proxy to a dedicated machine running a +goproxy server. Fiddler is a GUI app not designed to be ran like a server for multiple users. + +# A taste of goproxy + +To get a taste of `goproxy`, a basic HTTP/HTTPS transparent proxy + + + package main + + import ( + "github.com/elazarl/goproxy" + "log" + "net/http" + ) + + func main() { + proxy := goproxy.NewProxyHttpServer() + proxy.Verbose = true + log.Fatal(http.ListenAndServe(":8080", proxy)) + } + + +This line will add `X-GoProxy: yxorPoG-X` header to all requests sent through the proxy + + proxy.OnRequest().DoFunc( + func(r *http.Request,ctx *goproxy.ProxyCtx)(*http.Request,*http.Response) { + r.Header.Set("X-GoProxy","yxorPoG-X") + return r,nil + }) + +`DoFunc` will process all incoming requests to the proxy. It will add a header to the request +and return it. The proxy will send the modified request. + +Note that we returned nil value as the response. Have we returned a response, goproxy would +have discarded the request and sent the new response to the client. + +In order to refuse connections to reddit at work time + + proxy.OnRequest(goproxy.DstHostIs("www.reddit.com")).DoFunc( + func(r *http.Request,ctx *goproxy.ProxyCtx)(*http.Request,*http.Response) { + if h,_,_ := time.Now().Clock(); h >= 8 && h <= 17 { + return r,goproxy.NewResponse(r, + goproxy.ContentTypeText,http.StatusForbidden, + "Don't waste your time!") + } + return r,nil + }) + +`DstHostIs` returns a `ReqCondition`, that is a function receiving a `Request` and returning a boolean +we will only process requests that matches the condition. `DstHostIs("www.reddit.com")` will return +a `ReqCondition` accepting only requests directed to "www.reddit.com". + +`DoFunc` will recieve a function that will preprocess the request. We can change the request, or +return a response. If the time is between 8:00am and 17:00pm, we will neglect the request, and +return a precanned text response saying "do not waste your time". + +See additional examples in the examples directory. + +# What's New + + 1. Ability to `Hijack` CONNECT requests. See +[the eavesdropper example](https://github.com/elazarl/goproxy/blob/master/examples/goproxy-eavesdropper/main.go#L27) +2. Transparent proxy support for http/https including MITM certificate generation for TLS. See the [transparent example.](https://github.com/elazarl/goproxy/tree/master/examples/goproxy-transparent) + +# License + +I put the software temporarily under the Go-compatible BSD license, +if this prevents someone from using the software, do let mee know and I'll consider changing it. + +At any rate, user feedback is very important for me, so I'll be delighted to know if you're using this package. + +# Beta Software + +I've received a positive feedback from a few people who use goproxy in production settings. +I believe it is good enough for usage. + +I'll try to keep reasonable backwards compatability. In case of a major API change, +I'll change the import path. diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/actions.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/actions.go new file mode 100644 index 0000000..e1a3e7f --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/actions.go @@ -0,0 +1,57 @@ +package goproxy + +import "net/http" + +// ReqHandler will "tamper" with the request coming to the proxy server +// If Handle returns req,nil the proxy will send the returned request +// to the destination server. If it returns nil,resp the proxy will +// skip sending any requests, and will simply return the response `resp` +// to the client. +type ReqHandler interface { + Handle(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) +} + +// A wrapper that would convert a function to a ReqHandler interface type +type FuncReqHandler func(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) + +// FuncReqHandler.Handle(req,ctx) <=> FuncReqHandler(req,ctx) +func (f FuncReqHandler) Handle(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) { + return f(req, ctx) +} + +// after the proxy have sent the request to the destination server, it will +// "filter" the response through the RespHandlers it has. +// The proxy server will send to the client the response returned by the RespHandler. +// In case of error, resp will be nil, and ctx.RoundTrip.Error will contain the error +type RespHandler interface { + Handle(resp *http.Response, ctx *ProxyCtx) *http.Response +} + +// A wrapper that would convert a function to a RespHandler interface type +type FuncRespHandler func(resp *http.Response, ctx *ProxyCtx) *http.Response + +// FuncRespHandler.Handle(req,ctx) <=> FuncRespHandler(req,ctx) +func (f FuncRespHandler) Handle(resp *http.Response, ctx *ProxyCtx) *http.Response { + return f(resp, ctx) +} + +// When a client send a CONNECT request to a host, the request is filtered through +// all the HttpsHandlers the proxy has, and if one returns true, the connection is +// sniffed using Man in the Middle attack. +// That is, the proxy will create a TLS connection with the client, another TLS +// connection with the destination the client wished to connect to, and would +// send back and forth all messages from the server to the client and vice versa. +// The request and responses sent in this Man In the Middle channel are filtered +// through the usual flow (request and response filtered through the ReqHandlers +// and RespHandlers) +type HttpsHandler interface { + HandleConnect(req string, ctx *ProxyCtx) (*ConnectAction, string) +} + +// A wrapper that would convert a function to a HttpsHandler interface type +type FuncHttpsHandler func(host string, ctx *ProxyCtx) (*ConnectAction, string) + +// FuncHttpsHandler should implement the RespHandler interface +func (f FuncHttpsHandler) HandleConnect(host string, ctx *ProxyCtx) (*ConnectAction, string) { + return f(host, ctx) +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/all.bash b/Godeps/_workspace/src/github.com/elazarl/goproxy/all.bash new file mode 100644 index 0000000..6503e73 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/all.bash @@ -0,0 +1,15 @@ +#!/bin/bash + +go test || exit +for action in $@; do go $action; done + +mkdir -p bin +find regretable examples/* ext/* -maxdepth 0 -type d | while read d; do + (cd $d + go build -o ../../bin/$(basename $d) + find *_test.go -maxdepth 0 2>/dev/null|while read f;do + for action in $@; do go $action; done + go test + break + done) +done diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/ca.pem b/Godeps/_workspace/src/github.com/elazarl/goproxy/ca.pem new file mode 100644 index 0000000..f138424 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/ca.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICSjCCAbWgAwIBAgIBADALBgkqhkiG9w0BAQUwSjEjMCEGA1UEChMaZ2l0aHVi +LmNvbS9lbGF6YXJsL2dvcHJveHkxIzAhBgNVBAMTGmdpdGh1Yi5jb20vZWxhemFy +bC9nb3Byb3h5MB4XDTAwMDEwMTAwMDAwMFoXDTQ5MTIzMTIzNTk1OVowSjEjMCEG +A1UEChMaZ2l0aHViLmNvbS9lbGF6YXJsL2dvcHJveHkxIzAhBgNVBAMTGmdpdGh1 +Yi5jb20vZWxhemFybC9nb3Byb3h5MIGdMAsGCSqGSIb3DQEBAQOBjQAwgYkCgYEA +vz9BbCaJjxs73Tvcq3leP32hAGerQ1RgvlZ68Z4nZmoVHfl+2Nr/m0dmW+GdOfpT +cs/KzfJjYGr/84x524fiuR8GdZ0HOtXJzyF5seoWnbBIuyr1PbEpgRhGQMqqOUuj +YExeLbfNHPIoJ8XZ1Vzyv3YxjbmjWA+S/uOe9HWtDbMCAwEAAaNGMEQwDgYDVR0P +AQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8w +DAYDVR0RBAUwA4IBKjALBgkqhkiG9w0BAQUDgYEAIcL8huSmGMompNujsvePTUnM +oEUKtX4Eh/+s+DSfV/TyI0I+3GiPpLplEgFWuoBIJGios0r1dKh5N0TGjxX/RmGm +qo7E4jjJuo8Gs5U8/fgThZmshax2lwLtbRNwhvUVr65GdahLsZz8I+hySLuatVvR +qHHq/FQORIiNyNpq/Hg= +-----END CERTIFICATE----- diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/certs.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/certs.go new file mode 100644 index 0000000..8da2e62 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/certs.go @@ -0,0 +1,56 @@ +package goproxy + +import ( + "crypto/tls" + "crypto/x509" +) + +func init() { + if goproxyCaErr != nil { + panic("Error parsing builtin CA " + goproxyCaErr.Error()) + } + var err error + if GoproxyCa.Leaf, err = x509.ParseCertificate(GoproxyCa.Certificate[0]); err != nil { + panic("Error parsing builtin CA " + err.Error()) + } +} + +var tlsClientSkipVerify = &tls.Config{InsecureSkipVerify: true} + +var defaultTLSConfig = &tls.Config{ + InsecureSkipVerify: true, +} + +var CA_CERT = []byte(`-----BEGIN CERTIFICATE----- +MIICSjCCAbWgAwIBAgIBADALBgkqhkiG9w0BAQUwSjEjMCEGA1UEChMaZ2l0aHVi +LmNvbS9lbGF6YXJsL2dvcHJveHkxIzAhBgNVBAMTGmdpdGh1Yi5jb20vZWxhemFy +bC9nb3Byb3h5MB4XDTAwMDEwMTAwMDAwMFoXDTQ5MTIzMTIzNTk1OVowSjEjMCEG +A1UEChMaZ2l0aHViLmNvbS9lbGF6YXJsL2dvcHJveHkxIzAhBgNVBAMTGmdpdGh1 +Yi5jb20vZWxhemFybC9nb3Byb3h5MIGdMAsGCSqGSIb3DQEBAQOBjQAwgYkCgYEA +vz9BbCaJjxs73Tvcq3leP32hAGerQ1RgvlZ68Z4nZmoVHfl+2Nr/m0dmW+GdOfpT +cs/KzfJjYGr/84x524fiuR8GdZ0HOtXJzyF5seoWnbBIuyr1PbEpgRhGQMqqOUuj +YExeLbfNHPIoJ8XZ1Vzyv3YxjbmjWA+S/uOe9HWtDbMCAwEAAaNGMEQwDgYDVR0P +AQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8w +DAYDVR0RBAUwA4IBKjALBgkqhkiG9w0BAQUDgYEAIcL8huSmGMompNujsvePTUnM +oEUKtX4Eh/+s+DSfV/TyI0I+3GiPpLplEgFWuoBIJGios0r1dKh5N0TGjxX/RmGm +qo7E4jjJuo8Gs5U8/fgThZmshax2lwLtbRNwhvUVr65GdahLsZz8I+hySLuatVvR +qHHq/FQORIiNyNpq/Hg= +-----END CERTIFICATE-----`) + +var CA_KEY = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQC/P0FsJomPGzvdO9yreV4/faEAZ6tDVGC+VnrxnidmahUd+X7Y +2v+bR2Zb4Z05+lNyz8rN8mNgav/zjHnbh+K5HwZ1nQc61cnPIXmx6hadsEi7KvU9 +sSmBGEZAyqo5S6NgTF4tt80c8ignxdnVXPK/djGNuaNYD5L+4570da0NswIDAQAB +AoGBALzIv1b4D7ARTR3NOr6V9wArjiOtMjUrdLhO+9vIp9IEA8ZsA9gjDlCEwbkP +VDnoLjnWfraff5Os6+3JjHy1fYpUiCdnk2XA6iJSL1XWKQZPt3wOunxP4lalDgED +QTRReFbA/y/Z4kSfTXpVj68ytcvSRW/N7q5/qRtbN9804jpBAkEA0s6lvH2btSLA +mcEdwhs7zAslLbdld7rvfUeP82gPPk0S6yUqTNyikqshM9AwAktHY7WvYdKl+ghZ +HTxKVC4DoQJBAOg/IAW5RbXknP+Lf7AVtBgw3E+Yfa3mcdLySe8hjxxyZq825Zmu +Rt5Qj4Lw6ifSFNy4kiiSpE/ZCukYvUXGENMCQFkPxSWlS6tzSzuqQxBGwTSrYMG3 +wb6b06JyIXcMd6Qym9OMmBpw/J5KfnSNeDr/4uFVWQtTG5xO+pdHaX+3EQECQQDl +qcbY4iX1gWVfr2tNjajSYz751yoxVbkpiT9joiQLVXYFvpu+JYEfRzsjmWl0h2Lq +AftG8/xYmaEYcMZ6wSrRAkBUwiom98/8wZVlB6qbwhU1EKDFANvICGSWMIhPx3v7 +MJqTIj4uJhte2/uyVvZ6DC6noWYgy+kLgqG0S97tUEG8 +-----END RSA PRIVATE KEY-----`) + +var GoproxyCa, goproxyCaErr = tls.X509KeyPair(CA_CERT, CA_KEY) diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/chunked.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/chunked.go new file mode 100644 index 0000000..83654f6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/chunked.go @@ -0,0 +1,59 @@ +// Taken from $GOROOT/src/pkg/net/http/chunked +// needed to write https responses to client. +package goproxy + +import ( + "io" + "strconv" +) + +// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// newChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func newChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" + + if _, err = io.WriteString(cw.Wire, head); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + _, err = io.WriteString(cw.Wire, "\r\n") + + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/counterecryptor.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/counterecryptor.go new file mode 100644 index 0000000..494e7a4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/counterecryptor.go @@ -0,0 +1,68 @@ +package goproxy + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "errors" +) + +type CounterEncryptorRand struct { + cipher cipher.Block + counter []byte + rand []byte + ix int +} + +func NewCounterEncryptorRandFromKey(key interface{}, seed []byte) (r CounterEncryptorRand, err error) { + var keyBytes []byte + switch key := key.(type) { + case *rsa.PrivateKey: + keyBytes = x509.MarshalPKCS1PrivateKey(key) + default: + err = errors.New("only RSA keys supported") + return + } + h := sha256.New() + if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil { + return + } + r.counter = make([]byte, r.cipher.BlockSize()) + if seed != nil { + copy(r.counter, h.Sum(seed)[:r.cipher.BlockSize()]) + } + r.rand = make([]byte, r.cipher.BlockSize()) + r.ix = len(r.rand) + return +} + +func (c *CounterEncryptorRand) Seed(b []byte) { + if len(b) != len(c.counter) { + panic("SetCounter: wrong counter size") + } + copy(c.counter, b) +} + +func (c *CounterEncryptorRand) refill() { + c.cipher.Encrypt(c.rand, c.counter) + for i := 0; i < len(c.counter); i++ { + if c.counter[i]++; c.counter[i] != 0 { + break + } + } + c.ix = 0 +} + +func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) { + if c.ix == len(c.rand) { + c.refill() + } + if n = len(c.rand) - c.ix; n > len(b) { + n = len(b) + } + copy(b, c.rand[c.ix:c.ix+n]) + c.ix += n + return +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/ctx.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/ctx.go new file mode 100644 index 0000000..95bfd80 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/ctx.go @@ -0,0 +1,87 @@ +package goproxy + +import ( + "net/http" + "regexp" +) + +// ProxyCtx is the Proxy context, contains useful information about every request. It is passed to +// every user function. Also used as a logger. +type ProxyCtx struct { + // Will contain the client request from the proxy + Req *http.Request + // Will contain the remote server's response (if available. nil if the request wasn't send yet) + Resp *http.Response + RoundTripper RoundTripper + // will contain the recent error that occured while trying to send receive or parse traffic + Error error + // A handle for the user to keep data in the context, from the call of ReqHandler to the + // call of RespHandler + UserData interface{} + // Will connect a request to a response + Session int64 + proxy *ProxyHttpServer +} + +type RoundTripper interface { + RoundTrip(req *http.Request, ctx *ProxyCtx) (*http.Response, error) +} + +type RoundTripperFunc func(req *http.Request, ctx *ProxyCtx) (*http.Response, error) + +func (f RoundTripperFunc) RoundTrip(req *http.Request, ctx *ProxyCtx) (*http.Response, error) { + return f(req, ctx) +} + +func (ctx *ProxyCtx) RoundTrip(req *http.Request) (*http.Response, error) { + if ctx.RoundTripper != nil { + return ctx.RoundTripper.RoundTrip(req, ctx) + } + return ctx.proxy.Tr.RoundTrip(req) +} + +func (ctx *ProxyCtx) printf(msg string, argv ...interface{}) { + ctx.proxy.Logger.Printf("[%03d] "+msg+"\n", append([]interface{}{ctx.Session & 0xFF}, argv...)...) +} + +// Logf prints a message to the proxy's log. Should be used in a ProxyHttpServer's filter +// This message will be printed only if the Verbose field of the ProxyHttpServer is set to true +// +// proxy.OnRequest().DoFunc(func(r *http.Request,ctx *goproxy.ProxyCtx) (*http.Request, *http.Response){ +// nr := atomic.AddInt32(&counter,1) +// ctx.Printf("So far %d requests",nr) +// return r, nil +// }) +func (ctx *ProxyCtx) Logf(msg string, argv ...interface{}) { + if ctx.proxy.Verbose { + ctx.printf("INFO: "+msg, argv...) + } +} + +// Warnf prints a message to the proxy's log. Should be used in a ProxyHttpServer's filter +// This message will always be printed. +// +// proxy.OnRequest().DoFunc(func(r *http.Request,ctx *goproxy.ProxyCtx) (*http.Request, *http.Response){ +// f,err := os.OpenFile(cachedContent) +// if err != nil { +// ctx.Warnf("error open file %v: %v",cachedContent,err) +// return r, nil +// } +// return r, nil +// }) +func (ctx *ProxyCtx) Warnf(msg string, argv ...interface{}) { + ctx.printf("WARN: "+msg, argv...) +} + +var charsetFinder = regexp.MustCompile("charset=([^ ;]*)") + +// Will try to infer the character set of the request from the headers. +// Returns the empty string if we don't know which character set it used. +// Currently it will look for charset=<charset> in the Content-Type header of the request. +func (ctx *ProxyCtx) Charset() string { + charsets := charsetFinder.FindStringSubmatch(ctx.Resp.Header.Get("Content-Type")) + if charsets == nil { + return "" + } + return charsets[1] +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/dispatcher.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/dispatcher.go new file mode 100644 index 0000000..4e7c9cb --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/dispatcher.go @@ -0,0 +1,325 @@ +package goproxy + +import ( + "bytes" + "io/ioutil" + "net" + "net/http" + "regexp" + "strings" +) + +// ReqCondition.HandleReq will decide whether or not to use the ReqHandler on an HTTP request +// before sending it to the remote server +type ReqCondition interface { + RespCondition + HandleReq(req *http.Request, ctx *ProxyCtx) bool +} + +// RespCondition.HandleReq will decide whether or not to use the RespHandler on an HTTP response +// before sending it to the proxy client. Note that resp might be nil, in case there was an +// error sending the request. +type RespCondition interface { + HandleResp(resp *http.Response, ctx *ProxyCtx) bool +} + +// ReqConditionFunc.HandleReq(req,ctx) <=> ReqConditionFunc(req,ctx) +type ReqConditionFunc func(req *http.Request, ctx *ProxyCtx) bool + +// RespConditionFunc.HandleResp(resp,ctx) <=> RespConditionFunc(resp,ctx) +type RespConditionFunc func(resp *http.Response, ctx *ProxyCtx) bool + +func (c ReqConditionFunc) HandleReq(req *http.Request, ctx *ProxyCtx) bool { + return c(req, ctx) +} + +// ReqConditionFunc cannot test responses. It only satisfies RespCondition interface so that +// to be usable as RespCondition. +func (c ReqConditionFunc) HandleResp(resp *http.Response, ctx *ProxyCtx) bool { + return c(ctx.Req, ctx) +} + +func (c RespConditionFunc) HandleResp(resp *http.Response, ctx *ProxyCtx) bool { + return c(resp, ctx) +} + +// UrlHasPrefix returns a ReqCondition checking wether the destination URL the proxy client has requested +// has the given prefix, with or without the host. +// For example UrlHasPrefix("host/x") will match requests of the form 'GET host/x', and will match +// requests to url 'http://host/x' +func UrlHasPrefix(prefix string) ReqConditionFunc { + return func(req *http.Request, ctx *ProxyCtx) bool { + return strings.HasPrefix(req.URL.Path, prefix) || + strings.HasPrefix(req.URL.Host+req.URL.Path, prefix) || + strings.HasPrefix(req.URL.Scheme+req.URL.Host+req.URL.Path, prefix) + } +} + +// UrlIs returns a ReqCondition, testing whether or not the request URL is one of the given strings +// with or without the host prefix. +// UrlIs("google.com/","foo") will match requests 'GET /' to 'google.com', requests `'GET google.com/' to +// any host, and requests of the form 'GET foo'. +func UrlIs(urls ...string) ReqConditionFunc { + urlSet := make(map[string]bool) + for _, u := range urls { + urlSet[u] = true + } + return func(req *http.Request, ctx *ProxyCtx) bool { + _, pathOk := urlSet[req.URL.Path] + _, hostAndOk := urlSet[req.URL.Host+req.URL.Path] + return pathOk || hostAndOk + } +} + +// ReqHostMatches returns a ReqCondition, testing whether the host to which the request was directed to matches +// any of the given regular expressions. +func ReqHostMatches(regexps ...*regexp.Regexp) ReqConditionFunc { + return func(req *http.Request, ctx *ProxyCtx) bool { + for _, re := range regexps { + if re.MatchString(req.Host) { + return true + } + } + return false + } +} + +// ReqHostIs returns a ReqCondition, testing whether the host to which the request is directed to equal +// to one of the given strings +func ReqHostIs(hosts ...string) ReqConditionFunc { + hostSet := make(map[string]bool) + for _, h := range hosts { + hostSet[h] = true + } + return func(req *http.Request, ctx *ProxyCtx) bool { + _, ok := hostSet[req.URL.Host] + return ok + } +} + +var localHostIpv4 = regexp.MustCompile(`127\.0\.0\.\d+`) + +// IsLocalHost checks whether the destination host is explicitly local host +// (buggy, there can be IPv6 addresses it doesn't catch) +var IsLocalHost ReqConditionFunc = func(req *http.Request, ctx *ProxyCtx) bool { + return req.URL.Host == "::1" || + req.URL.Host == "0:0:0:0:0:0:0:1" || + localHostIpv4.MatchString(req.URL.Host) || + req.URL.Host == "localhost" +} + +// UrlMatches returns a ReqCondition testing whether the destination URL +// of the request matches the given regexp, with or without prefix +func UrlMatches(re *regexp.Regexp) ReqConditionFunc { + return func(req *http.Request, ctx *ProxyCtx) bool { + return re.MatchString(req.URL.Path) || + re.MatchString(req.URL.Host+req.URL.Path) + } +} + +// DstHostIs returns a ReqCondition testing wether the host in the request url is the given string +func DstHostIs(host string) ReqConditionFunc { + return func(req *http.Request, ctx *ProxyCtx) bool { + return req.URL.Host == host + } +} + +// SrcIpIs returns a ReqCondition testing whether the source IP of the request is one of the given strings +func SrcIpIs(ips ...string) ReqCondition { + return ReqConditionFunc(func(req *http.Request, ctx *ProxyCtx) bool { + for _, ip := range ips { + if strings.HasPrefix(req.RemoteAddr, ip+":") { + return true + } + } + return false + }) +} + +// Not returns a ReqCondition negating the given ReqCondition +func Not(r ReqCondition) ReqConditionFunc { + return func(req *http.Request, ctx *ProxyCtx) bool { + return !r.HandleReq(req, ctx) + } +} + +// ContentTypeIs returns a RespCondition testing whether the HTTP response has Content-Type header equal +// to one of the given strings. +func ContentTypeIs(typ string, types ...string) RespCondition { + types = append(types, typ) + return RespConditionFunc(func(resp *http.Response, ctx *ProxyCtx) bool { + if resp == nil { + return false + } + contentType := resp.Header.Get("Content-Type") + for _, typ := range types { + if contentType == typ || strings.HasPrefix(contentType, typ+";") { + return true + } + } + return false + }) +} + +// ProxyHttpServer.OnRequest Will return a temporary ReqProxyConds struct, aggregating the given condtions. +// You will use the ReqProxyConds struct to register a ReqHandler, that would filter +// the request, only if all the given ReqCondition matched. +// Typical usage: +// proxy.OnRequest(UrlIs("example.com/foo"),UrlMatches(regexp.MustParse(`.*\.exampl.\com\./.*`)).Do(...) +func (proxy *ProxyHttpServer) OnRequest(conds ...ReqCondition) *ReqProxyConds { + return &ReqProxyConds{proxy, conds} +} + +// ReqProxyConds aggregate ReqConditions for a ProxyHttpServer. Upon calling Do, it will register a ReqHandler that would +// handle the request if all conditions on the HTTP request are met. +type ReqProxyConds struct { + proxy *ProxyHttpServer + reqConds []ReqCondition +} + +// DoFunc is equivalent to proxy.OnRequest().Do(FuncReqHandler(f)) +func (pcond *ReqProxyConds) DoFunc(f func(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response)) { + pcond.Do(FuncReqHandler(f)) +} + +// ReqProxyConds.Do will register the ReqHandler on the proxy, +// the ReqHandler will handle the HTTP request if all the conditions +// aggregated in the ReqProxyConds are met. Typical usage: +// proxy.OnRequest().Do(handler) // will call handler.Handle(req,ctx) on every request to the proxy +// proxy.OnRequest(cond1,cond2).Do(handler) +// // given request to the proxy, will test if cond1.HandleReq(req,ctx) && cond2.HandleReq(req,ctx) are true +// // if they are, will call handler.Handle(req,ctx) +func (pcond *ReqProxyConds) Do(h ReqHandler) { + pcond.proxy.reqHandlers = append(pcond.proxy.reqHandlers, + FuncReqHandler(func(r *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) { + for _, cond := range pcond.reqConds { + if !cond.HandleReq(r, ctx) { + return r, nil + } + } + return h.Handle(r, ctx) + })) +} + +// HandleConnect is used when proxy receives an HTTP CONNECT request, +// it'll then use the HttpsHandler to determine what should it +// do with this request. The handler returns a ConnectAction struct, the Action field in the ConnectAction +// struct returned will determine what to do with this request. ConnectAccept will simply accept the request +// forwarding all bytes from the client to the remote host, ConnectReject will close the connection with the +// client, and ConnectMitm, will assume the underlying connection is an HTTPS connection, and will use Man +// in the Middle attack to eavesdrop the connection. All regular handler will be active on this eavesdropped +// connection. +// The ConnectAction struct contains possible tlsConfig that will be used for eavesdropping. If nil, the proxy +// will use the default tls configuration. +// proxy.OnRequest().HandleConnect(goproxy.AlwaysReject) // rejects all CONNECT requests +func (pcond *ReqProxyConds) HandleConnect(h HttpsHandler) { + pcond.proxy.httpsHandlers = append(pcond.proxy.httpsHandlers, + FuncHttpsHandler(func(host string, ctx *ProxyCtx) (*ConnectAction, string) { + for _, cond := range pcond.reqConds { + if !cond.HandleReq(ctx.Req, ctx) { + return nil, "" + } + } + return h.HandleConnect(host, ctx) + })) +} + +// HandleConnectFunc is equivalent to HandleConnect, +// for example, accepting CONNECT request if they contain a password in header +// io.WriteString(h,password) +// passHash := h.Sum(nil) +// proxy.OnRequest().HandleConnectFunc(func(host string, ctx *ProxyCtx) (*ConnectAction, string) { +// c := sha1.New() +// io.WriteString(c,ctx.Req.Header.Get("X-GoProxy-Auth")) +// if c.Sum(nil) == passHash { +// return OkConnect, host +// } +// return RejectConnect, host +// }) +func (pcond *ReqProxyConds) HandleConnectFunc(f func(host string, ctx *ProxyCtx) (*ConnectAction, string)) { + pcond.HandleConnect(FuncHttpsHandler(f)) +} + +func (pcond *ReqProxyConds) HijackConnect(f func(req *http.Request, client net.Conn, ctx *ProxyCtx)) { + pcond.proxy.httpsHandlers = append(pcond.proxy.httpsHandlers, + FuncHttpsHandler(func(host string, ctx *ProxyCtx) (*ConnectAction, string) { + for _, cond := range pcond.reqConds { + if !cond.HandleReq(ctx.Req, ctx) { + return nil, "" + } + } + return &ConnectAction{Action: ConnectHijack, Hijack: f}, host + })) +} + +// ProxyConds is used to aggregate RespConditions for a ProxyHttpServer. +// Upon calling ProxyConds.Do, it will register a RespHandler that would +// handle the HTTP response from remote server if all conditions on the HTTP response are met. +type ProxyConds struct { + proxy *ProxyHttpServer + reqConds []ReqCondition + respCond []RespCondition +} + +// ProxyConds.DoFunc is equivalent to proxy.OnResponse().Do(FuncRespHandler(f)) +func (pcond *ProxyConds) DoFunc(f func(resp *http.Response, ctx *ProxyCtx) *http.Response) { + pcond.Do(FuncRespHandler(f)) +} + +// ProxyConds.Do will register the RespHandler on the proxy, h.Handle(resp,ctx) will be called on every +// request that matches the conditions aggregated in pcond. +func (pcond *ProxyConds) Do(h RespHandler) { + pcond.proxy.respHandlers = append(pcond.proxy.respHandlers, + FuncRespHandler(func(resp *http.Response, ctx *ProxyCtx) *http.Response { + for _, cond := range pcond.reqConds { + if !cond.HandleReq(ctx.Req, ctx) { + return resp + } + } + for _, cond := range pcond.respCond { + if !cond.HandleResp(resp, ctx) { + return resp + } + } + return h.Handle(resp, ctx) + })) +} + +// OnResponse is used when adding a response-filter to the HTTP proxy, usual pattern is +// proxy.OnResponse(cond1,cond2).Do(handler) // handler.Handle(resp,ctx) will be used +// // if cond1.HandleResp(resp) && cond2.HandleResp(resp) +func (proxy *ProxyHttpServer) OnResponse(conds ...RespCondition) *ProxyConds { + return &ProxyConds{proxy, make([]ReqCondition, 0), conds} +} + +// AlwaysMitm is a HttpsHandler that always eavesdrop https connections, for example to +// eavesdrop all https connections to www.google.com, we can use +// proxy.OnRequest(goproxy.ReqHostIs("www.google.com")).HandleConnect(goproxy.AlwaysMitm) +var AlwaysMitm FuncHttpsHandler = func(host string, ctx *ProxyCtx) (*ConnectAction, string) { + return MitmConnect, host +} + +// AlwaysReject is a HttpsHandler that drops any CONNECT request, for example, this code will disallow +// connections to hosts on any other port than 443 +// proxy.OnRequest(goproxy.Not(goproxy.ReqHostMatches(regexp.MustCompile(":443$"))). +// HandleConnect(goproxy.AlwaysReject) +var AlwaysReject FuncHttpsHandler = func(host string, ctx *ProxyCtx) (*ConnectAction, string) { + return RejectConnect, host +} + +// HandleBytes will return a RespHandler that read the entire body of the request +// to a byte array in memory, would run the user supplied f function on the byte arra, +// and will replace the body of the original response with the resulting byte array. +func HandleBytes(f func(b []byte, ctx *ProxyCtx) []byte) RespHandler { + return FuncRespHandler(func(resp *http.Response, ctx *ProxyCtx) *http.Response { + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + ctx.Warnf("Cannot read response %s", err) + return resp + } + resp.Body.Close() + + resp.Body = ioutil.NopCloser(bytes.NewBuffer(f(b, ctx))) + return resp + }) +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/doc.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/doc.go new file mode 100644 index 0000000..50aaa71 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/doc.go @@ -0,0 +1,100 @@ +/* +Package goproxy provides a customizable HTTP proxy, +supporting hijacking HTTPS connection. + +The intent of the proxy, is to be usable with reasonable amount of traffic +yet, customizable and programable. + +The proxy itself is simply an `net/http` handler. + +Typical usage is + + proxy := goproxy.NewProxyHttpServer() + proxy.OnRequest(..conditions..).Do(..requesthandler..) + proxy.OnRequest(..conditions..).DoFunc(..requesthandlerFunction..) + proxy.OnResponse(..conditions..).Do(..responesHandler..) + proxy.OnResponse(..conditions..).DoFunc(..responesHandlerFunction..) + http.ListenAndServe(":8080", proxy) + +Adding a header to each request + + proxy.OnRequest().DoFunc(func(r *http.Request,ctx *goproxy.ProxyCtx) (*http.Request, *http.Response){ + r.Header.Set("X-GoProxy","1") + return r, nil + }) + +Note that the function is called before the proxy sends the request to the server + +For printing the content type of all incoming responses + + proxy.OnResponse().DoFunc(func(r *http.Response, ctx *goproxy.ProxyCtx)*http.Response{ + println(ctx.Req.Host,"->",r.Header.Get("Content-Type")) + return r + }) + +note that we used the ProxyCtx context variable here. It contains the request +and the response (Req and Resp, Resp is nil if unavailable) of this specific client +interaction with the proxy. + +To print the content type of all responses from a certain url, we'll add a +ReqCondition to the OnResponse function: + + proxy.OnResponse(goproxy.UrlIs("golang.org/pkg")).DoFunc(func(r *http.Response, ctx *goproxy.ProxyCtx)*http.Response{ + println(ctx.Req.Host,"->",r.Header.Get("Content-Type")) + return r + }) + +We can write the condition ourselves, conditions can be set on request and on response + + var random = ReqConditionFunc(func(r *http.Request) bool { + return rand.Intn(1) == 0 + }) + var hasGoProxyHeader = RespConditionFunc(func(resp *http.Response,req *http.Request)bool { + return resp.Header.Get("X-GoProxy") != "" + }) + +Caution! If you give a RespCondition to the OnRequest function, you'll get a run time panic! It doesn't +make sense to read the response, if you still haven't got it! + +Finally, we have convenience function to throw a quick response + + proxy.OnResponse(hasGoProxyHeader).DoFunc(func(r*http.Response,ctx *goproxy.ProxyCtx)*http.Response { + r.Body.Close() + return goproxy.ForbiddenTextResponse(ctx.Req,"Can't see response with X-GoProxy header!") + }) + +we close the body of the original repsonse, and return a new 403 response with a short message. + +Example use cases: + +1. https://github.com/elazarl/goproxy/tree/master/examples/goproxy-avgsize + +To measure the average size of an Html served in your site. One can ask +all the QA team to access the website by a proxy, and the proxy will +measure the average size of all text/html responses from your host. + +2. [not yet implemented] + +All requests to your web servers should be directed through the proxy, +when the proxy will detect html pieces sent as a response to AJAX +request, it'll send a warning email. + +3. https://github.com/elazarl/goproxy/blob/master/examples/goproxy-httpdump/ + +Generate a real traffic to your website by real users using through +proxy. Record the traffic, and try it again for more real load testing. + +4. https://github.com/elazarl/goproxy/tree/master/examples/goproxy-no-reddit-at-worktime + +Will allow browsing to reddit.com between 8:00am and 17:00pm + +5. https://github.com/elazarl/goproxy/tree/master/examples/goproxy-jquery-version + +Will warn if multiple versions of jquery are used in the same domain. + +6. https://github.com/elazarl/goproxy/blob/master/examples/goproxy-upside-down-ternet/ + +Modifies image files in an HTTP response via goproxy's image extension found in ext/. + +*/ +package goproxy diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/https.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/https.go new file mode 100644 index 0000000..1341fbc --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/https.go @@ -0,0 +1,370 @@ +package goproxy + +import ( + "bufio" + "crypto/tls" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync/atomic" +) + +type ConnectActionLiteral int + +const ( + ConnectAccept = iota + ConnectReject + ConnectMitm + ConnectHijack + ConnectHTTPMitm + ConnectProxyAuthHijack +) + +var ( + OkConnect = &ConnectAction{Action: ConnectAccept, TLSConfig: TLSConfigFromCA(&GoproxyCa)} + MitmConnect = &ConnectAction{Action: ConnectMitm, TLSConfig: TLSConfigFromCA(&GoproxyCa)} + HTTPMitmConnect = &ConnectAction{Action: ConnectHTTPMitm, TLSConfig: TLSConfigFromCA(&GoproxyCa)} + RejectConnect = &ConnectAction{Action: ConnectReject, TLSConfig: TLSConfigFromCA(&GoproxyCa)} +) + +type ConnectAction struct { + Action ConnectActionLiteral + Hijack func(req *http.Request, client net.Conn, ctx *ProxyCtx) + TLSConfig func(host string, ctx *ProxyCtx) (*tls.Config, error) +} + +func stripPort(s string) string { + ix := strings.IndexRune(s, ':') + if ix == -1 { + return s + } + return s[:ix] +} + +func (proxy *ProxyHttpServer) dial(network, addr string) (c net.Conn, err error) { + if proxy.Tr.Dial != nil { + return proxy.Tr.Dial(network, addr) + } + return net.Dial(network, addr) +} + +func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err error) { + if proxy.ConnectDial == nil { + return proxy.dial(network, addr) + } + return proxy.ConnectDial(network, addr) +} + +func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request) { + ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} + + hij, ok := w.(http.Hijacker) + if !ok { + panic("httpserver does not support hijacking") + } + + proxyClient, _, e := hij.Hijack() + if e != nil { + panic("Cannot hijack connection " + e.Error()) + } + + ctx.Logf("Running %d CONNECT handlers", len(proxy.httpsHandlers)) + todo, host := OkConnect, r.URL.Host + for i, h := range proxy.httpsHandlers { + newtodo, newhost := h.HandleConnect(host, ctx) + + // If found a result, break the loop immediately + if newtodo != nil { + todo, host = newtodo, newhost + ctx.Logf("on %dth handler: %v %s", i, todo, host) + break + } + } + switch todo.Action { + case ConnectAccept: + if !hasPort.MatchString(host) { + host += ":80" + } + targetSiteCon, err := proxy.connectDial("tcp", host) + if err != nil { + httpError(proxyClient, ctx, err) + return + } + ctx.Logf("Accepting CONNECT to %s", host) + proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + go copyAndClose(ctx, targetSiteCon, proxyClient) + go copyAndClose(ctx, proxyClient, targetSiteCon) + case ConnectHijack: + ctx.Logf("Hijacking CONNECT to %s", host) + proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + todo.Hijack(r, proxyClient, ctx) + case ConnectHTTPMitm: + proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") + targetSiteCon, err := proxy.connectDial("tcp", host) + if err != nil { + ctx.Warnf("Error dialing to %s: %s", host, err.Error()) + return + } + for { + client := bufio.NewReader(proxyClient) + remote := bufio.NewReader(targetSiteCon) + req, err := http.ReadRequest(client) + if err != nil && err != io.EOF { + ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) + } + if err != nil { + return + } + req, resp := proxy.filterRequest(req, ctx) + if resp == nil { + if err := req.Write(targetSiteCon); err != nil { + httpError(proxyClient, ctx, err) + return + } + resp, err = http.ReadResponse(remote, req) + if err != nil { + httpError(proxyClient, ctx, err) + return + } + } + resp = proxy.filterResponse(resp, ctx) + if err := resp.Write(proxyClient); err != nil { + httpError(proxyClient, ctx, err) + return + } + } + case ConnectMitm: + proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + ctx.Logf("Assuming CONNECT is TLS, mitm proxying it") + // this goes in a separate goroutine, so that the net/http server won't think we're + // still handling the request even after hijacking the connection. Those HTTP CONNECT + // request can take forever, and the server will be stuck when "closed". + // TODO: Allow Server.Close() mechanism to shut down this connection as nicely as possible + tlsConfig := defaultTLSConfig + if todo.TLSConfig != nil { + var err error + tlsConfig, err = todo.TLSConfig(host, ctx) + if err != nil { + httpError(proxyClient, ctx, err) + return + } + } + go func() { + //TODO: cache connections to the remote website + rawClientTls := tls.Server(proxyClient, tlsConfig) + if err := rawClientTls.Handshake(); err != nil { + ctx.Warnf("Cannot handshake client %v %v", r.Host, err) + return + } + defer rawClientTls.Close() + clientTlsReader := bufio.NewReader(rawClientTls) + for !isEof(clientTlsReader) { + req, err := http.ReadRequest(clientTlsReader) + if err != nil && err != io.EOF { + return + } + if err != nil { + ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) + return + } + req.RemoteAddr = r.RemoteAddr // since we're converting the request, need to carry over the original connecting IP as well + ctx.Logf("req %v", r.Host) + req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) + + // Bug fix which goproxy fails to provide request + // information URL in the context when does HTTPS MITM + ctx.Req = req + + req, resp := proxy.filterRequest(req, ctx) + if resp == nil { + if err != nil { + ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path) + return + } + removeProxyHeaders(ctx, req) + resp, err = ctx.RoundTrip(req) + if err != nil { + ctx.Warnf("Cannot read TLS response from mitm'd server %v", err) + return + } + ctx.Logf("resp %v", resp.Status) + } + resp = proxy.filterResponse(resp, ctx) + text := resp.Status + statusCode := strconv.Itoa(resp.StatusCode) + " " + if strings.HasPrefix(text, statusCode) { + text = text[len(statusCode):] + } + // always use 1.1 to support chunked encoding + if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil { + ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) + return + } + // Since we don't know the length of resp, return chunked encoded response + // TODO: use a more reasonable scheme + resp.Header.Del("Content-Length") + resp.Header.Set("Transfer-Encoding", "chunked") + if err := resp.Header.Write(rawClientTls); err != nil { + ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) + return + } + if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { + ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err) + return + } + chunked := newChunkedWriter(rawClientTls) + if _, err := io.Copy(chunked, resp.Body); err != nil { + ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) + return + } + if err := chunked.Close(); err != nil { + ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err) + return + } + if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { + ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err) + return + } + } + ctx.Logf("Exiting on EOF") + }() + case ConnectProxyAuthHijack: + proxyClient.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n")) + todo.Hijack(r, proxyClient, ctx) + case ConnectReject: + if ctx.Resp != nil { + if err := ctx.Resp.Write(proxyClient); err != nil { + ctx.Warnf("Cannot write response that reject http CONNECT: %v", err) + } + } + proxyClient.Close() + } +} + +func httpError(w io.WriteCloser, ctx *ProxyCtx, err error) { + if _, err := io.WriteString(w, "HTTP/1.1 502 Bad Gateway\r\n\r\n"); err != nil { + ctx.Warnf("Error responding to client: %s", err) + } + if err := w.Close(); err != nil { + ctx.Warnf("Error closing client connection: %s", err) + } +} + +func copyAndClose(ctx *ProxyCtx, w, r net.Conn) { + connOk := true + if _, err := io.Copy(w, r); err != nil { + connOk = false + ctx.Warnf("Error copying to client: %s", err) + } + if err := r.Close(); err != nil && connOk { + ctx.Warnf("Error closing: %s", err) + } +} + +func dialerFromEnv(proxy *ProxyHttpServer) func(network, addr string) (net.Conn, error) { + https_proxy := os.Getenv("HTTPS_PROXY") + if https_proxy == "" { + https_proxy = os.Getenv("https_proxy") + } + if https_proxy == "" { + return nil + } + return proxy.NewConnectDialToProxy(https_proxy) +} + +func (proxy *ProxyHttpServer) NewConnectDialToProxy(https_proxy string) func(network, addr string) (net.Conn, error) { + u, err := url.Parse(https_proxy) + if err != nil { + return nil + } + if u.Scheme == "" || u.Scheme == "http" { + if strings.IndexRune(u.Host, ':') == -1 { + u.Host += ":80" + } + return func(network, addr string) (net.Conn, error) { + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } + c, err := proxy.dial(network, u.Host) + if err != nil { + return nil, err + } + connectReq.Write(c) + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(c) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + c.Close() + return nil, err + } + if resp.StatusCode != 200 { + resp, _ := ioutil.ReadAll(resp.Body) + c.Close() + return nil, errors.New("proxy refused connection" + string(resp)) + } + return c, nil + } + } + if u.Scheme == "https" { + if strings.IndexRune(u.Host, ':') == -1 { + u.Host += ":443" + } + return func(network, addr string) (net.Conn, error) { + c, err := proxy.dial(network, u.Host) + if err != nil { + return nil, err + } + c = tls.Client(c, proxy.Tr.TLSClientConfig) + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } + connectReq.Write(c) + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(c) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + c.Close() + return nil, err + } + if resp.StatusCode != 200 { + body, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 500)) + resp.Body.Close() + c.Close() + return nil, errors.New("proxy refused connection" + string(body)) + } + return c, nil + } + } + return nil +} + +func TLSConfigFromCA(ca *tls.Certificate) func(host string, ctx *ProxyCtx) (*tls.Config, error) { + return func(host string, ctx *ProxyCtx) (*tls.Config, error) { + config := *defaultTLSConfig + ctx.Logf("signing for %s", stripPort(host)) + cert, err := signHost(*ca, []string{stripPort(host)}) + if err != nil { + ctx.Warnf("Cannot sign host certificate with provided CA: %s", err) + return nil, err + } + config.Certificates = append(config.Certificates, cert) + return &config, nil + } +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/key.pem b/Godeps/_workspace/src/github.com/elazarl/goproxy/key.pem new file mode 100644 index 0000000..2438b37 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/key.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQC/P0FsJomPGzvdO9yreV4/faEAZ6tDVGC+VnrxnidmahUd+X7Y +2v+bR2Zb4Z05+lNyz8rN8mNgav/zjHnbh+K5HwZ1nQc61cnPIXmx6hadsEi7KvU9 +sSmBGEZAyqo5S6NgTF4tt80c8ignxdnVXPK/djGNuaNYD5L+4570da0NswIDAQAB +AoGBALzIv1b4D7ARTR3NOr6V9wArjiOtMjUrdLhO+9vIp9IEA8ZsA9gjDlCEwbkP +VDnoLjnWfraff5Os6+3JjHy1fYpUiCdnk2XA6iJSL1XWKQZPt3wOunxP4lalDgED +QTRReFbA/y/Z4kSfTXpVj68ytcvSRW/N7q5/qRtbN9804jpBAkEA0s6lvH2btSLA +mcEdwhs7zAslLbdld7rvfUeP82gPPk0S6yUqTNyikqshM9AwAktHY7WvYdKl+ghZ +HTxKVC4DoQJBAOg/IAW5RbXknP+Lf7AVtBgw3E+Yfa3mcdLySe8hjxxyZq825Zmu +Rt5Qj4Lw6ifSFNy4kiiSpE/ZCukYvUXGENMCQFkPxSWlS6tzSzuqQxBGwTSrYMG3 +wb6b06JyIXcMd6Qym9OMmBpw/J5KfnSNeDr/4uFVWQtTG5xO+pdHaX+3EQECQQDl +qcbY4iX1gWVfr2tNjajSYz751yoxVbkpiT9joiQLVXYFvpu+JYEfRzsjmWl0h2Lq +AftG8/xYmaEYcMZ6wSrRAkBUwiom98/8wZVlB6qbwhU1EKDFANvICGSWMIhPx3v7 +MJqTIj4uJhte2/uyVvZ6DC6noWYgy+kLgqG0S97tUEG8 +-----END RSA PRIVATE KEY----- diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/proxy.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/proxy.go new file mode 100644 index 0000000..e4ed060 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/proxy.go @@ -0,0 +1,162 @@ +package goproxy + +import ( + "bufio" + "io" + "log" + "net" + "net/http" + "os" + "regexp" + "sync/atomic" +) + +// The basic proxy type. Implements http.Handler. +type ProxyHttpServer struct { + // session variable must be aligned in i386 + // see http://golang.org/src/pkg/sync/atomic/doc.go#L41 + sess int64 + // setting Verbose to true will log information on each request sent to the proxy + Verbose bool + Logger *log.Logger + NonproxyHandler http.Handler + reqHandlers []ReqHandler + respHandlers []RespHandler + httpsHandlers []HttpsHandler + Tr *http.Transport + // ConnectDial will be used to create TCP connections for CONNECT requests + // if nil Tr.Dial will be used + ConnectDial func(network string, addr string) (net.Conn, error) +} + +var hasPort = regexp.MustCompile(`:\d+$`) + +func copyHeaders(dst, src http.Header) { + for k, _ := range dst { + dst.Del(k) + } + for k, vs := range src { + for _, v := range vs { + dst.Add(k, v) + } + } +} + +func isEof(r *bufio.Reader) bool { + _, err := r.Peek(1) + if err == io.EOF { + return true + } + return false +} + +func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req *http.Request, resp *http.Response) { + req = r + for _, h := range proxy.reqHandlers { + req, resp = h.Handle(r, ctx) + // non-nil resp means the handler decided to skip sending the request + // and return canned response instead. + if resp != nil { + break + } + } + return +} +func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *ProxyCtx) (resp *http.Response) { + resp = respOrig + for _, h := range proxy.respHandlers { + ctx.Resp = resp + resp = h.Handle(resp, ctx) + } + return +} + +func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) { + r.RequestURI = "" // this must be reset when serving a request with the client + ctx.Logf("Sending request %v %v", r.Method, r.URL.String()) + // If no Accept-Encoding header exists, Transport will add the headers it can accept + // and would wrap the response body with the relevant reader. + r.Header.Del("Accept-Encoding") + // curl can add that, see + // http://homepage.ntlworld.com/jonathan.deboynepollard/FGA/web-proxy-connection-header.html + r.Header.Del("Proxy-Connection") + r.Header.Del("Proxy-Authenticate") + r.Header.Del("Proxy-Authorization") + // Connection, Authenticate and Authorization are single hop Header: + // http://www.w3.org/Protocols/rfc2616/rfc2616.txt + // 14.10 Connection + // The Connection general-header field allows the sender to specify + // options that are desired for that particular connection and MUST NOT + // be communicated by proxies over further connections. + r.Header.Del("Connection") +} + +// Standard net/http function. Shouldn't be used directly, http.Serve will use it. +func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //r.Header["X-Forwarded-For"] = w.RemoteAddr() + if r.Method == "CONNECT" { + proxy.handleHttps(w, r) + } else { + ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy} + + var err error + ctx.Logf("Got request %v %v %v %v", r.URL.Path, r.Host, r.Method, r.URL.String()) + if !r.URL.IsAbs() { + proxy.NonproxyHandler.ServeHTTP(w, r) + return + } + r, resp := proxy.filterRequest(r, ctx) + + if resp == nil { + removeProxyHeaders(ctx, r) + resp, err = ctx.RoundTrip(r) + if err != nil { + ctx.Error = err + resp = proxy.filterResponse(nil, ctx) + if resp == nil { + ctx.Logf("error read response %v %v:", r.URL.Host, err.Error()) + http.Error(w, err.Error(), 500) + return + } + } + ctx.Logf("Received response %v", resp.Status) + } + origBody := resp.Body + resp = proxy.filterResponse(resp, ctx) + + ctx.Logf("Copying response to client %v [%d]", resp.Status, resp.StatusCode) + // http.ResponseWriter will take care of filling the correct response length + // Setting it now, might impose wrong value, contradicting the actual new + // body the user returned. + // We keep the original body to remove the header only if things changed. + // This will prevent problems with HEAD requests where there's no body, yet, + // the Content-Length header should be set. + if origBody != resp.Body { + resp.Header.Del("Content-Length") + } + copyHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + nr, err := io.Copy(w, resp.Body) + if err := resp.Body.Close(); err != nil { + ctx.Warnf("Can't close response body %v", err) + } + ctx.Logf("Copied %v bytes to client error=%v", nr, err) + } +} + +// New proxy server, logs to StdErr by default +func NewProxyHttpServer() *ProxyHttpServer { + proxy := ProxyHttpServer{ + Logger: log.New(os.Stderr, "", log.LstdFlags), + reqHandlers: []ReqHandler{}, + respHandlers: []RespHandler{}, + httpsHandlers: []HttpsHandler{}, + NonproxyHandler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + http.Error(w, "This is a proxy server. Does not respond to non-proxy requests.", 500) + }), + Tr: &http.Transport{TLSClientConfig: tlsClientSkipVerify, + Proxy: http.ProxyFromEnvironment}, + } + proxy.ConnectDial = dialerFromEnv(&proxy) + return &proxy +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/responses.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/responses.go new file mode 100644 index 0000000..b304b88 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/responses.go @@ -0,0 +1,38 @@ +package goproxy + +import ( + "bytes" + "io/ioutil" + "net/http" +) + +// Will generate a valid http response to the given request the response will have +// the given contentType, and http status. +// Typical usage, refuse to process requests to local addresses: +// +// proxy.OnRequest(IsLocalHost()).DoFunc(func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request,*http.Response) { +// return nil,NewResponse(r,goproxy.ContentTypeHtml,http.StatusUnauthorized, +// `<!doctype html><html><head><title>Can't use proxy for local addresses</title></head><body/></html>`) +// }) +func NewResponse(r *http.Request, contentType string, status int, body string) *http.Response { + resp := &http.Response{} + resp.Request = r + resp.TransferEncoding = r.TransferEncoding + resp.Header = make(http.Header) + resp.Header.Add("Content-Type", contentType) + resp.StatusCode = status + buf := bytes.NewBufferString(body) + resp.ContentLength = int64(buf.Len()) + resp.Body = ioutil.NopCloser(buf) + return resp +} + +const ( + ContentTypeText = "text/plain" + ContentTypeHtml = "text/html" +) + +// Alias for NewResponse(r,ContentTypeText,http.StatusAccepted,text) +func TextResponse(r *http.Request, text string) *http.Response { + return NewResponse(r, ContentTypeText, http.StatusAccepted, text) +} diff --git a/Godeps/_workspace/src/github.com/elazarl/goproxy/signer.go b/Godeps/_workspace/src/github.com/elazarl/goproxy/signer.go new file mode 100644 index 0000000..f6d99fc --- /dev/null +++ b/Godeps/_workspace/src/github.com/elazarl/goproxy/signer.go @@ -0,0 +1,87 @@ +package goproxy + +import ( + "crypto/rsa" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "runtime" + "sort" + "time" +) + +func hashSorted(lst []string) []byte { + c := make([]string, len(lst)) + copy(c, lst) + sort.Strings(c) + h := sha1.New() + for _, s := range c { + h.Write([]byte(s + ",")) + } + return h.Sum(nil) +} + +func hashSortedBigInt(lst []string) *big.Int { + rv := new(big.Int) + rv.SetBytes(hashSorted(lst)) + return rv +} + +var goproxySignerVersion = ":goroxy1" + +func signHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err error) { + var x509ca *x509.Certificate + + // Use the provided ca and not the global GoproxyCa for certificate generation. + if x509ca, err = x509.ParseCertificate(ca.Certificate[0]); err != nil { + return + } + start := time.Unix(0, 0) + end, err := time.Parse("2006-01-02", "2049-12-31") + if err != nil { + panic(err) + } + hash := hashSorted(append(hosts, goproxySignerVersion, ":"+runtime.Version())) + serial := new(big.Int) + serial.SetBytes(hash) + template := x509.Certificate{ + // TODO(elazar): instead of this ugly hack, just encode the certificate and hash the binary form. + SerialNumber: serial, + Issuer: x509ca.Subject, + Subject: pkix.Name{ + Organization: []string{"GoProxy untrusted MITM proxy Inc"}, + }, + NotBefore: start, + NotAfter: end, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + var csprng CounterEncryptorRand + if csprng, err = NewCounterEncryptorRandFromKey(ca.PrivateKey, hash); err != nil { + return + } + var certpriv *rsa.PrivateKey + if certpriv, err = rsa.GenerateKey(&csprng, 1024); err != nil { + return + } + var derBytes []byte + if derBytes, err = x509.CreateCertificate(&csprng, &template, x509ca, &certpriv.PublicKey, ca.PrivateKey); err != nil { + return + } + return tls.Certificate{ + Certificate: [][]byte{derBytes, ca.Certificate[0]}, + PrivateKey: certpriv, + }, nil +} diff --git a/Godeps/_workspace/src/github.com/vulcand/oxy/LICENSE b/Godeps/_workspace/src/github.com/vulcand/oxy/LICENSE new file mode 100644 index 0000000..e06d208 --- /dev/null +++ b/Godeps/_workspace/src/github.com/vulcand/oxy/LICENSE @@ -0,0 +1,202 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/Makefile b/Makefile index dbdc466..1c5da46 100644 --- a/Makefile +++ b/Makefile @@ -3,15 +3,17 @@ NAME=keycloak-proxy AUTHOR=gambol99 AUTHOR_EMAIL=gambol99@gmail.com REGISTRY=quay.io -GOVERSION=1.6.0 +GOVERSION=1.6.2 SUDO= ROOT_DIR=${PWD} HARDWARE=$(shell uname -m) -GIT_SHA=$(shell git --no-pager describe --tags --always --dirty) +GIT_SHA=$(shell git --no-pager describe --always --dirty) +BUILD_TIME=$(shell date -u '+%Y-%m-%d_%I:%M:%S%p') VERSION ?= $(shell awk '/version.*=/ { print $$3 }' doc.go | sed 's/"//g') DEPS=$(shell go list -f '{{range .TestImports}}{{.}} {{end}}' ./...) PACKAGES=$(shell go list ./...) -VETARGS?=-asmdecl -atomic -bool -buildtags -copylocks -methods -nilfunc -printf -rangeloops -shift -structtags -unsafeptr +LFLAGS ?= -X main.gitsha=${GIT_SHA} +VETARGS ?= -asmdecl -atomic -bool -buildtags -copylocks -methods -nilfunc -printf -rangeloops -shift -structtags -unsafeptr .PHONY: test authors changelog build docker static release lint cover vet @@ -27,12 +29,12 @@ version: build: @echo "--> Compiling the project" mkdir -p bin - godep go build -o bin/${NAME} + godep go build -ldflags "${LFLAGS}" -o bin/${NAME} static: golang deps @echo "--> Compiling the static binary" mkdir -p bin - CGO_ENABLED=0 GOOS=linux godep go build -a -tags netgo -ldflags '-w' -o bin/${NAME} + CGO_ENABLED=0 GOOS=linux godep go build -a -tags netgo -ldflags "-w ${LFLAGS}" -o bin/${NAME} docker-build: @echo "--> Compiling the project" diff --git a/README.md b/README.md index 8da36f6..5d7e4ab 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ - TLS and mutual TLS support - JSON field bases access logs - Custom Sign-in and access forbidden pages + - Forwarding proxy support, to sign outbound requests + - URL Role Tokenization ---- @@ -24,63 +26,73 @@ NAME: keycloak-proxy - is a proxy using the keycloak service for auth and authorization USAGE: - keycloak-proxy [global options] command [command options] [arguments...] + keycloak-proxy [options] VERSION: - v1.0.5 + v1.1.0 (git+sha: 1209149) AUTHOR(S): Rohith <gambol99@gmail.com> COMMANDS: GLOBAL OPTIONS: - --config the path to the configuration file for the keycloak proxy - --listen "127.0.0.1:3000" the interface the service should be listening on - --client-secret the client secret used to authenticate to the oauth server - --client-id the client id used to authenticate to the oauth serves - --discovery-url the discovery url to retrieve the openid configuration - --scope [--scope option --scope option] a variable list of scopes requested when authenticating the user - --idle-duration "0" the expiration of the access token cookie, if not used within this time its removed - --redirection-url redirection url for the oauth callback url (/oauth is added) - --upstream-url "http://127.0.0.1:8081" the url for the upstream endpoint you wish to proxy to - --revocation-url "/oauth2/revoke" the url for the revocation endpoint to revoke refresh token - --store-url url for the storage subsystem, e.g redis://127.0.0.1:6379, file:///etc/tokens.file - --upstream-keepalives enables or disables the keepalive connections for upstream endpoint - --enable-refresh-tokens enables the handling of the refresh tokens - --secure-cookie enforces the cookie to be secure, default to true - --cookie-access-name "kc-access" the name of the cookie use to hold the access token - --cookie-refresh-name "kc-state" the name of the cookie used to hold the encrypted refresh token - --encryption-key the encryption key used to encrpytion the session state - --no-redirects do not have back redirects when no authentication is present, 401 them - --hostname [--hostname option --hostname option] a list of hostnames the service will respond to, defaults to all - --tls-cert the path to a certificate file used for TLS - --tls-private-key the path to the private key for TLS support - --tls-ca-certificate the path to the ca certificate used for mutual TLS - --skip-upstream-tls-verify whether to skip the verification of any upstream TLS (defaults to true) - --match-claims [--match-claims option --match-claims option] keypair values for matching access token claims e.g. aud=myapp, iss=http://example.* - --add-claims [--add-claims option --add-claims option] retrieve extra claims from the token and inject into headers, e.g given_name -> X-Auth-Given-Name - --resource [--resource option --resource option] a list of resources 'uri=/admin|methods=GET|roles=role1,role2' - --signin-page a custom template displayed for signin - --forbidden-page a custom template used for access forbidden - --tag [--tag option --tag option] keypair's passed to the templates at render,e.g title='My Page' - --cors-origins [--cors-origins option --cors-origins option] list of origins to add to the CORE origins control (Access-Control-Allow-Origin) - --cors-methods [--cors-methods option --cors-methods option] the method permitted in the access control (Access-Control-Allow-Methods) - --cors-headers [--cors-headers option --cors-headers option] a set of headers to add to the CORS access control (Access-Control-Allow-Headers) - --cors-exposes-headers [--cors-exposes-headers option --cors-exposes-headers option] set the expose cors headers access control (Access-Control-Expose-Headers) - --cors-max-age "0" the max age applied to cors headers (Access-Control-Max-Age) - --cors-credentials the credentials access control header (Access-Control-Allow-Credentials) - --headers [--headers option --headers option] Add custom headers to the upstream request, key=value - --enable-security-filter enables the security filter handler - --skip-token-verification TESTING ONLY; bypass's token verification, expiration and roles enforced - --offline-session enables the offline session of tokens via offline access (defaults false) - --json-logging switch on json logging rather than text (defaults true) - --log-requests switch on logging of all incoming requests (defaults true) - --verbose switch on debug / verbose logging - --help, -h show help - --version, -v print the version - + --config value the path to the configuration file for the keycloak proxy [$PROXY_CONFIG_FILE] + --listen value the interface the service should be listening on (default: "127.0.0.1:3000") + --client-secret value the client secret used to authenticate to the oauth server (access_type: confidential) [$PROXY_CLIENT_SECRET] + --client-id value the client id used to authenticate to the oauth service [$PROXY_CLIENT_ID] + --discovery-url value the discovery url to retrieve the openid configuration [$PROXY_DISCOVERY_URL] + --scope value a variable list of scopes requested when authenticating the user + --token-validate-only validate the token and roles only, no required implement oauth + --idle-duration value the expiration of the access token cookie, if not used within this time its removed (default: 0) + --redirection-url value redirection url for the oauth callback url (/oauth is added) [$PROXY_REDIRECTION_URL] + --revocation-url value the url for the revocation endpoint to revoke refresh token (default: "/oauth2/revoke") + --store-url value url for the storage subsystem, e.g redis://127.0.0.1:6379, file:///etc/tokens.file [$PROXY_STORE_URL] + --upstream-url value the url for the upstream endpoint you wish to proxy to [$PROXY_UPSTREAM_URL] + --upstream-keepalives enables or disables the keepalive connections for upstream endpoint + --upstream-timeout value is the maximum amount of time a dial will wait for a connect to complete (default: 10s) + --upstream-keepalive-timeout value specifies the keep-alive period for an active network connection (default: 10s) + --enable-refresh-tokens enables the handling of the refresh tokens + --secure-cookie enforces the cookie to be secure, default to true + --cookie-access-name value the name of the cookie use to hold the access token (default: "kc-access") + --cookie-refresh-name value the name of the cookie used to hold the encrypted refresh token (default: "kc-state") + --encryption-key value the encryption key used to encrpytion the session state + --no-redirects do not have back redirects when no authentication is present, 401 them + --hostname value a list of hostnames the service will respond to, defaults to all + --enable-proxy-protocol whether to enable proxy protocol + --enable-forwarding enables the forwarding proxy mode, signing outbound request + --forwarding-username value the username to use when logging into the openid provider + --forwarding-password value the password to use when logging into the openid provider + --forwarding-domains value a list of domains which should be signed, anything is just relayed + --tls-cert value the path to a certificate file used for TLS + --tls-private-key value the path to the private key for TLS support + --tls-ca-certificate value the path to the ca certificate used for mutual TLS + --skip-upstream-tls-verify whether to skip the verification of any upstream TLS (defaults to true) + --match-claims value keypair values for matching access token claims e.g. aud=myapp, iss=http://example.* + --add-claims value retrieve extra claims from the token and inject into headers, e.g given_name -> X-Auth-Given-Name + --resource value a list of resources 'uri=/admin|methods=GET|roles=role1,role2' + --headers value Add custom headers to the upstream request, key=value + --signin-page value a custom template displayed for signin + --forbidden-page value a custom template used for access forbidden + --tag value keypair's passed to the templates at render,e.g title='My Page' + --cors-origins value list of origins to add to the CORE origins control (Access-Control-Allow-Origin) + --cors-methods value the method permitted in the access control (Access-Control-Allow-Methods) + --cors-headers value a set of headers to add to the CORS access control (Access-Control-Allow-Headers) + --cors-exposes-headers value set the expose cors headers access control (Access-Control-Expose-Headers) + --cors-max-age value the max age applied to cors headers (Access-Control-Max-Age) (default: 0) + --cors-credentials the credentials access control header (Access-Control-Allow-Credentials) + --enable-security-filter enables the security filter handler + --skip-token-verification TESTING ONLY; bypass token verification, only expiration and roles enforced + --json-logging switch on json logging rather than text (defaults true) + --log-requests switch on logging of all incoming requests (defaults true) + --verbose switch on debug / verbose logging + --help, -h show help + --version, -v print the version ``` +#### **Building** + +Assuming you have make + go, simply run make (or 'make static' for static linking). You can also build via docker container: make docker-build + #### **Configuration** The configuration can come from a yaml/json file and or the command line options (note, command options have a higher priority and will override any options referenced in a config file) @@ -143,10 +155,9 @@ d) Create the various roles under the client or existing clients for authorizati ```YAML discovery-url: https://keycloak.example.com/auth/realms/<REALM_NAME> client-id: <CLIENT_ID> -client-secret: <CLIENT_SECRET> +client-secret: <CLIENT_SECRET> # require for access_type: confidential listen: 127.0.0.1:3000 redirection-url: http://127.0.0.1:3000 -refresh_session: false encryption_key: AgXa7xRcoClDEU0ZDSH4X0XhL5Qy2Z2j upstream-url: http://127.0.0.1:80 @@ -171,7 +182,7 @@ bin/keycloak-proxy \ --client-secret=<SECRET> \ --listen=127.0.0.1:3000 \ --redirection-url=http://127.0.0.1:3000 \ - --refresh-sessions=true \ + --enable-refresh-token=true \ --encryption-key=AgXa7xRcoClDEU0ZDSH4X0XhL5Qy2Z2j \ --upstream-url=http://127.0.0.1:80 \ --resource="uri=/admin|methods=GET|roles=test1,test2" \ @@ -199,6 +210,70 @@ DEBU[0002] resource access permitted: / access=permitted bearer DEBU[0002] resource access permitted: /favicon.ico access=permitted bearer=false expires=57m51.144004098s resource=/ username=gambol99@gmail.com 2016-02-06 13:59:01.856716 I | http: proxy error: dial tcp 127.0.0.1:8081: getsockopt: connection refused ``` + +#### **- Forward Signing Proxy (Experimental)** + +Lets say you have a bunch of services and you want to apply granular access controls, central auditing, authentication and authorization between endpoints. +Incoming is covered as detailed above, but you can also switch on a forwarding proxy. Your application can proxy outbound requests through the proxy; requests +will be signed with an authorization header (i.e. a JWT access token) for the other end to verify. The proxy will then take care of authenticating to the +OpenID service, refreshing the tokens etc. + +Example setup: + +You have selection of applications; lets assume to keep the example only those with a specific role per project for access i.e. Project requires project role claim, +ProjectB requires projectb role claim etc etc. You can setup the + +```YAML +# kubernetes pod example +- name: keycloak-proxy + image: quay.io/gambol99/keycloak-proxy:latest + args: + - --listen=unix:///var/run/keycloak/proxy.sock + - --enable-forwarding=true + - --forwarding-username=projecta + - --forwarding-password=some_password (better to grab from k8s secrets via env or perhaps vault?) + - --forwarding-domains=projectb.svc.cluster.local + - --forwarding-domains=projectc.svc.cluster.local + # Note: if you don't specify any forwarding domains, all domains will be signed; Also the code checks is the + # domain 'contains' the value (it's not a regex) so if you wanted to sign all requests to svc.cluster.local, just use + # svc.cluster.local + volumeMounts: + - name: keycloak-socket + mountPoint: /var/run/keycloak +- name: projecta + image: some_images + # +``` + +Project A can use the /var/run/keycloak/proxy.sock (or you can chunk it on localhost:PORT if you prefer) and setup the application via stanadrd proxy +setting is projects requests + + +#### **- URL Tokenization (in-progress)** +--- + +You can tokenize the url for an authenticated resource, extracting roles from the url itself. Say for example you have an applications where the uri comes in a namespace form, e.g. +/logs/<namespace> i.e. logs/admin/, logs/app1, logs/app2 etc. you could use + +```YAML +resources: +- uri: logs/admin + roles: [ 'admin' ] +- uri: logs/app1 + roles: [ 'app1' ] +- uri: logs/app2 + roles: [ 'app2' ] +``` + +But it could become annoying, creating roles for namespaces, updating there, then updating config here. An easier way would be map a url token to a role name. i.e. + +```YAML +resources: +- uri: logs/%role%/ +``` + +The above will extract role requirement from the url and apply to admission as per usual. /logs/admin will need a admin role, logs/app1 needs the app1 role, etc. + --- #### **- Upstream Headers** @@ -206,7 +281,7 @@ On protected resources the upstream endpoint will receive a number of headers ad ```GO # add the header to the upstream endpoint -cx.Request.Header.Add("X-Auth-UserId", id.id) +cx.Request.Header.Add("X-Auth-Userid", id.id) cx.Request.Header.Add("X-Auth-Subject", id.preferredName) cx.Request.Header.Add("X-Auth-Username", id.name) cx.Request.Header.Add("X-Auth-Email", id.email) @@ -215,8 +290,11 @@ cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) # plus the default -cx.Request.Header.Add("X-Forwarded-For", <CLIENT_IP>) +cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr) cx.Request.Header.Add("X-Forwarded-Proto", <CLIENT_PROTO>) +cx.Request.Header.Set("X-Forwarded-Agent", prog) +cx.Request.Header.Set("X-Forwarded-Agent-Version", version) +cx.Request.Header.Set("X-Forwarded-Host", cx.Request.Host) ``` #### **- Custom Claims** @@ -259,6 +337,12 @@ X-Auth-Subject: rohith.jayawardene In order to remain stateless and not have to rely on a central cache to persist the 'refresh_tokens', the refresh token is encrypted and added as a cookie using *crypto/aes*. Naturally the key must be the same if your running behind a load balancer etc. The key length should either 16 or 32 bytes depending or whether you want AES-128 or AES-256. +#### **- ClientID & Secret** + +Note, the client secret is optional are is only only for setups where the oauth provider is using access_type = confidential; if the provider is 'public' simple add the client id. +Alternatively, you might not need the proxy to perform the oauth authentication flow and instead simply verify the identity token (a potential role permissions), in which case, again +just drop the client secret and use the client id and discovery-url. + #### **- Claim Matching** The proxy supports adding a variable list of claim matches against the presented tokens for additional access control. So for example you can match the 'iss' or 'aud' to the token or custom attributes; @@ -369,7 +453,7 @@ the TLS verification via the --skip-upstream-tls-verify or config option, along * **/oauth/authorize** is authentication endpoint which will generate the openid redirect to the provider * **/oauth/callback** is provider openid callback endpoint * **/oauth/expired** is a helper endpoint to check if a access token has expired, 200 for ok and, 401 for no token and 401 for expired -* **/oauth/health** is the health checking endpoint for the proxy +* **/oauth/health** is the health checking endpoint for the proxy, you can also grab version from headers * **/oauth/login** provides a relay endpoint to login via grant_type=password i.e. POST /oauth/login?username=USERNAME&password=PASSWORD * **/oauth/logout** provides a convenient endpoint to log the user out, it will always attempt to perform a back channel logout of offline tokens * **/oauth/token** is a helper endpoint which will display the current access token for you diff --git a/config.go b/config.go index 819ef7d..6aaa528 100644 --- a/config.go +++ b/config.go @@ -23,6 +23,7 @@ import ( "path/filepath" "regexp" "strings" + "time" "github.com/codegangsta/cli" "gopkg.in/yaml.v2" @@ -31,28 +32,22 @@ import ( // newDefaultConfig returns a initialized config func newDefaultConfig() *Config { return &Config{ - Listen: "127.0.0.1:3000", - RedirectionURL: "http://127.0.0.1:3000", - Upstream: "http://127.0.0.1:8081", - TagData: make(map[string]string, 0), - MatchClaims: make(map[string]string, 0), - Headers: make(map[string]string, 0), - CookieAccessName: cookieAccessToken, - CookieRefreshName: cookieRefreshToken, - SecureCookie: true, - SkipUpstreamTLSVerify: true, - CrossOrigin: CORS{}, + Listen: "127.0.0.1:3000", + TagData: make(map[string]string, 0), + MatchClaims: make(map[string]string, 0), + Headers: make(map[string]string, 0), + UpstreamTimeout: time.Duration(10) * time.Second, + UpstreamKeepaliveTimeout: time.Duration(10) * time.Second, + CookieAccessName: "kc-access", + CookieRefreshName: "kc-state", + SecureCookie: true, + SkipUpstreamTLSVerify: true, + CrossOrigin: CORS{}, } } // isValid validates if the config is valid func (r *Config) isValid() error { - if r.Upstream == "" { - return fmt.Errorf("you have not specified an upstream endpoint to proxy to") - } - if _, err := url.Parse(r.Upstream); err != nil { - return fmt.Errorf("the upstream endpoint is invalid, %s", err) - } if r.Listen == "" { return fmt.Errorf("you have not specified the listening interface") } @@ -71,49 +66,65 @@ func (r *Config) isValid() error { if r.TLSCaCertificate != "" && !fileExists(r.TLSCaCertificate) { return fmt.Errorf("the tls ca certificate file %s does not exist", r.TLSCaCertificate) } - // step: if the skip verification is off, we need the below - if !r.SkipTokenVerification { - if r.DiscoveryURL == "" { - return fmt.Errorf("you have not specified the discovery url") - } + + if r.EnableForwarding { if r.ClientID == "" { return fmt.Errorf("you have not specified the client id") } - if r.ClientSecret == "" { - return fmt.Errorf("you have not specified the client secret") - } - if r.RedirectionURL == "" { - return fmt.Errorf("you have not specified the redirection url") + if r.DiscoveryURL == "" { + return fmt.Errorf("you have not specified the discovery url") } - if strings.HasSuffix(r.RedirectionURL, "/") { - r.RedirectionURL = strings.TrimSuffix(r.RedirectionURL, "/") + if r.ForwardingUsername == "" { + return fmt.Errorf("no forwarding username") } - if r.EnableRefreshTokens && r.EncryptionKey == "" { - return fmt.Errorf("you have not specified a encryption key for encoding the session state") + if r.ForwardingPassword == "" { + return fmt.Errorf("no forwarding password") } - if r.EnableRefreshTokens && (len(r.EncryptionKey) != 16 && len(r.EncryptionKey) != 32) { - return fmt.Errorf("the encryption key (%d) must be either 16 or 32 characters for AES-128/AES-256 selection", len(r.EncryptionKey)) + } else { + if r.Upstream == "" { + return fmt.Errorf("you have not specified an upstream endpoint to proxy to") } - if r.SecureCookie && !strings.HasPrefix(r.RedirectionURL, "https") { - return fmt.Errorf("the cookie is set to secure but your redirection url is non-tls") + if _, err := url.Parse(r.Upstream); err != nil { + return fmt.Errorf("the upstream endpoint is invalid, %s", err) } - if r.StoreURL != "" { - if _, err := url.Parse(r.StoreURL); err != nil { - return fmt.Errorf("the store url is invalid, error: %s", err) + // step: if the skip verification is off, we need the below + if !r.SkipTokenVerification { + if r.ClientID == "" { + return fmt.Errorf("you have not specified the client id") + } + if r.DiscoveryURL == "" { + return fmt.Errorf("you have not specified the discovery url") + } + if strings.HasSuffix(r.RedirectionURL, "/") { + r.RedirectionURL = strings.TrimSuffix(r.RedirectionURL, "/") + } + if r.EnableRefreshTokens && r.EncryptionKey == "" { + return fmt.Errorf("you have not specified a encryption key for encoding the session state") + } + if r.EnableRefreshTokens && (len(r.EncryptionKey) != 16 && len(r.EncryptionKey) != 32) { + return fmt.Errorf("the encryption key (%d) must be either 16 or 32 characters for AES-128/AES-256 selection", len(r.EncryptionKey)) + } + if r.SecureCookie && !strings.HasPrefix(r.RedirectionURL, "https") { + return fmt.Errorf("the cookie is set to secure but your redirection url is non-tls") + } + if r.StoreURL != "" { + if _, err := url.Parse(r.StoreURL); err != nil { + return fmt.Errorf("the store url is invalid, error: %s", err) + } } } - } - // step: valid the resources - for _, resource := range r.Resources { - if err := resource.IsValid(); err != nil { - return err + // step: valid the resources + for _, resource := range r.Resources { + if err := resource.IsValid(); err != nil { + return err + } } - } - // step: validate the claims are validate regex's - for k, claim := range r.MatchClaims { - // step: validate the regex - if _, err := regexp.Compile(claim); err != nil { - return fmt.Errorf("the claim matcher: %s for claim: %s is not a valid regex", claim, k) + // step: validate the claims are validate regex's + for k, claim := range r.MatchClaims { + // step: validate the regex + if _, err := regexp.Compile(claim); err != nil { + return fmt.Errorf("the claim matcher: %s for claim: %s is not a valid regex", claim, k) + } } } @@ -138,7 +149,10 @@ func (r *Config) hasCustomForbiddenPage() bool { return false } +// // readOptions parses the command line options and constructs a config object +// @TODO look for a shorter way of doing this, we're maintaining the same options in multiple places, it's tedious! +// func readOptions(cx *cli.Context, config *Config) (err error) { if cx.IsSet("listen") { config.Listen = cx.String("listen") @@ -161,6 +175,12 @@ func readOptions(cx *cli.Context, config *Config) (err error) { if cx.IsSet("upstream-keepalives") { config.UpstreamKeepalives = cx.Bool("upstream-keepalives") } + if cx.IsSet("upstream-timeout") { + config.UpstreamTimeout = cx.Duration("upstream-timeout") + } + if cx.IsSet("upstream-keepalive-timeout") { + config.UpstreamKeepaliveTimeout = cx.Duration("upstream-keepalive-timeout") + } if cx.IsSet("idle-duration") { config.IdleDuration = cx.Duration("idle-duration") } @@ -206,6 +226,21 @@ func readOptions(cx *cli.Context, config *Config) (err error) { if cx.IsSet("tls-ca-certificate") { config.TLSCaCertificate = cx.String("tls-ca-certificate") } + if cx.IsSet("enable-proxy-protocol") { + config.EnableProxyProtocol = cx.Bool("enable-proxy-protocol") + } + if cx.IsSet("enable-forwarding") { + config.EnableForwarding = cx.Bool("enable-forwarding") + } + if cx.IsSet("forwarding-username") { + config.ForwardingUsername = cx.String("forwarding-username") + } + if cx.IsSet("forwarding-password") { + config.ForwardingPassword = cx.String("forwarding-password") + } + if cx.IsSet("forwarding-domains") { + config.ForwardingDomains = append(config.ForwardingDomains, cx.StringSlice("forwarding-domains")...) + } if cx.IsSet("signin-page") { config.SignInPage = cx.String("signin-page") } @@ -215,9 +250,6 @@ func readOptions(cx *cli.Context, config *Config) (err error) { if cx.IsSet("enable-security-filter") { config.EnableSecurityFilter = true } - if cx.IsSet("proxy-protocol") { - config.ProxyProtocol = cx.Bool("proxy-protocol") - } if cx.IsSet("json-logging") { config.LogJSONFormat = cx.Bool("json-logging") } @@ -320,12 +352,12 @@ func getOptions() []cli.Flag { }, cli.StringFlag{ Name: "client-secret", - Usage: "the client secret used to authenticate to the oauth server", + Usage: "the client secret used to authenticate to the oauth server (access_type: confidential)", EnvVar: "PROXY_CLIENT_SECRET", }, cli.StringFlag{ Name: "client-id", - Usage: "the client id used to authenticate to the oauth serves", + Usage: "the client id used to authenticate to the oauth service", EnvVar: "PROXY_CLIENT_ID", }, cli.StringFlag{ @@ -337,6 +369,10 @@ func getOptions() []cli.Flag { Name: "scope", Usage: "a variable list of scopes requested when authenticating the user", }, + cli.BoolFlag{ + Name: "token-validate-only", + Usage: "validate the token and roles only, no required implement oauth", + }, cli.DurationFlag{ Name: "idle-duration", Usage: "the expiration of the access token cookie, if not used within this time its removed", @@ -346,12 +382,6 @@ func getOptions() []cli.Flag { Usage: fmt.Sprintf("redirection url for the oauth callback url (%s is added)", oauthURL), EnvVar: "PROXY_REDIRECTION_URL", }, - cli.StringFlag{ - Name: "upstream-url", - Usage: "the url for the upstream endpoint you wish to proxy to", - Value: defaults.Upstream, - EnvVar: "PROXY_UPSTREAM_URL", - }, cli.StringFlag{ Name: "revocation-url", Usage: "the url for the revocation endpoint to revoke refresh token", @@ -362,10 +392,26 @@ func getOptions() []cli.Flag { Usage: "url for the storage subsystem, e.g redis://127.0.0.1:6379, file:///etc/tokens.file", EnvVar: "PROXY_STORE_URL", }, + cli.StringFlag{ + Name: "upstream-url", + Usage: "the url for the upstream endpoint you wish to proxy to", + Value: defaults.Upstream, + EnvVar: "PROXY_UPSTREAM_URL", + }, cli.BoolTFlag{ Name: "upstream-keepalives", Usage: "enables or disables the keepalive connections for upstream endpoint", }, + cli.DurationFlag{ + Name: "upstream-timeout", + Usage: "is the maximum amount of time a dial will wait for a connect to complete", + Value: defaults.UpstreamTimeout, + }, + cli.DurationFlag{ + Name: "upstream-keepalive-timeout", + Usage: "specifies the keep-alive period for an active network connection", + Value: defaults.UpstreamKeepaliveTimeout, + }, cli.BoolFlag{ Name: "enable-refresh-tokens", Usage: "enables the handling of the refresh tokens", @@ -396,6 +442,26 @@ func getOptions() []cli.Flag { Name: "hostname", Usage: "a list of hostnames the service will respond to, defaults to all", }, + cli.BoolFlag{ + Name: "enable-proxy-protocol", + Usage: "whether to enable proxy protocol", + }, + cli.BoolFlag{ + Name: "enable-forwarding", + Usage: "enables the forwarding proxy mode, signing outbound request", + }, + cli.StringFlag{ + Name: "forwarding-username", + Usage: "the username to use when logging into the openid provider", + }, + cli.StringFlag{ + Name: "forwarding-password", + Usage: "the password to use when logging into the openid provider", + }, + cli.StringSliceFlag{ + Name: "forwarding-domains", + Usage: "a list of domains which should be signed, anything is just relayed", + }, cli.StringFlag{ Name: "tls-cert", Usage: "the path to a certificate file used for TLS", @@ -424,6 +490,10 @@ func getOptions() []cli.Flag { Name: "resource", Usage: "a list of resources 'uri=/admin|methods=GET|roles=role1,role2'", }, + cli.StringSliceFlag{ + Name: "headers", + Usage: "Add custom headers to the upstream request, key=value", + }, cli.StringFlag{ Name: "signin-page", Usage: "a custom template displayed for signin", @@ -460,17 +530,13 @@ func getOptions() []cli.Flag { Name: "cors-credentials", Usage: "the credentials access control header (Access-Control-Allow-Credentials)", }, - cli.StringSliceFlag{ - Name: "headers", - Usage: "Add custom headers to the upstream request, key=value", - }, cli.BoolFlag{ Name: "enable-security-filter", Usage: "enables the security filter handler", }, cli.BoolFlag{ Name: "skip-token-verification", - Usage: "TESTING ONLY; bypass's token verification, expiration and roles enforced", + Usage: "TESTING ONLY; bypass token verification, only expiration and roles enforced", }, cli.BoolTFlag{ Name: "json-logging", diff --git a/config_sample.yml b/config_sample.yml index 2b7d8e4..55abd18 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -4,7 +4,8 @@ discovery-url: https://keycloak.example.com/auth/realms/commons # the client id for the 'client' application client-id: <CLIENT_ID> -# the secret associated to the 'client' application +# the secret associated to the 'client' application - note the client_secret is optional, required for +# oauth2 access_type=confidential i.e. the client is being verified client-secret: <CLIENT_SECRET> # the interface definition you wish the proxy to listen, all interfaces is specified as ':<port>' listen: 127.0.0.1:3000 diff --git a/cookies_test.go b/cookies_test.go index 278797a..6649da9 100644 --- a/cookies_test.go +++ b/cookies_test.go @@ -28,7 +28,7 @@ func TestDropCookie(t *testing.T) { p.dropCookie(context, "test-cookie", "test-value", 0) assert.Equal(t, context.Writer.Header().Get("Set-Cookie"), - "test-cookie=test-value; Path=/; Domain=127.0.0.1; Secure", + "test-cookie=test-value; Path=/; Domain=127.0.0.1", "we have not set the cookie, headers: %v", context.Writer.Header()) context = newFakeGinContext("GET", "/admin") diff --git a/doc.go b/doc.go index 05b9cb5..d9d6649 100644 --- a/doc.go +++ b/doc.go @@ -20,18 +20,22 @@ import ( "time" ) +var ( + release = "v1.1.0" + gitsha = "no gitsha provided" + version = release + " (git+sha: " + gitsha + ")" +) + const ( prog = "keycloak-proxy" - version = "v1.0.6" author = "Rohith" email = "gambol99@gmail.com" description = "is a proxy using the keycloak service for auth and authorization" headerUpgrade = "Upgrade" - cookieAccessToken = "kc-access" - cookieRefreshToken = "kc-state" userContextName = "identity" authorizationHeader = "Authorization" + versionHeader = "X-Auth-Proxy-Version" oauthURL = "/oauth" authorizationURL = "/authorize" @@ -94,46 +98,41 @@ type CORS struct { // Config is the configuration for the proxy type Config struct { - // LogRequests indicates if we should log all the requests - LogRequests bool `json:"log-requests" yaml:"log-requests"` - // LogFormat is the logging format - LogJSONFormat bool `json:"log-json-format" yaml:"log-json-format"` + // Listen is the binding interface + Listen string `json:"listen" yaml:"listen"` // DiscoveryURL is the url for the keycloak server DiscoveryURL string `json:"discovery-url" yaml:"discovery-url"` // ClientID is the client id ClientID string `json:"client-id" yaml:"client-id"` // ClientSecret is the secret for AS ClientSecret string `json:"client-secret" yaml:"client-secret"` - // RevocationEndpoint is the token revocation endpoint to revoke refresh tokens - RevocationEndpoint string `json:"revocation-url" yaml:"revocation-url"` - // NoRedirects informs we should hand back a 401 not a redirect - NoRedirects bool `json:"no-redirects" yaml:"no-redirects"` // RedirectionURL the redirection url RedirectionURL string `json:"redirection-url" yaml:"redirection-url"` - // EnableSecurityFilter enabled the security handler - EnableSecurityFilter bool `json:"enable-security-filter" yaml:"enable-security-filter"` - // EnableRefreshTokens indicate's you wish to ignore using refresh tokens and re-auth on expireation of access token - EnableRefreshTokens bool `json:"enable-refresh-tokens" yaml:"enable-refresh-tokens"` + // RevocationEndpoint is the token revocation endpoint to revoke refresh tokens + RevocationEndpoint string `json:"revocation-url" yaml:"revocation-url"` + // Scopes is a list of scope we should request + Scopes []string `json:"scopes" yaml:"scopes"` + // Upstream is the upstream endpoint i.e whom were proxying to + Upstream string `json:"upstream-url" yaml:"upstream-url"` + // Resources is a list of protected resources + Resources []*Resource `json:"resources" yaml:"resources"` + // Headers permits adding customs headers across the board + Headers map[string]string `json:"headers" yaml:"headers"` + // CookieAccessName is the name of the access cookie holding the access token CookieAccessName string `json:"cookie-access-name" yaml:"cookie-access-name"` // CookieRefreshName is the name of the refresh cookie CookieRefreshName string `json:"cookie-refresh-name" yaml:"cookie-refresh-name"` // SecureCookie enforces the cookie as secure SecureCookie bool `json:"secure-cookie" yaml:"secure-cookie"` + // IdleDuration is the max amount of time a session can last without being used IdleDuration time.Duration `json:"idle-duration" yaml:"idle-duration"` - // EncryptionKey is the encryption key used to encrypt the refresh token - EncryptionKey string `json:"encryption-key" yaml:"encryption-key"` // MatchClaims is a series of checks, the claims in the token must match those here MatchClaims map[string]string `json:"match-claims" yaml:"match-claims"` // AddClaims is a series of claims that should be added to the auth headers AddClaims []string `json:"add-claims" yaml:"add-claims"` - // UpstreamKeepalives specifies wheather we use keepalives on the upstream - UpstreamKeepalives bool `json:"upstream-keepalives" yaml:"upstream-keepalives"` - // Listen is the binding interface - Listen string `json:"listen" yaml:"listen"` - // ProxyProtocol enables proxy protocol - ProxyProtocol bool `json:"proxy-protocol" yaml:"proxy-protocol"` + // TLSCertificate is the location for a tls certificate TLSCertificate string `json:"tls-cert" yaml:"tls-cert"` // TLSPrivateKey is the location of a tls private key @@ -142,30 +141,56 @@ type Config struct { TLSCaCertificate string `json:"tls-ca-certificate" yaml:"tls-ca-certificate"` // SkipUpstreamTLSVerify skips the verification of any upstream tls SkipUpstreamTLSVerify bool `json:"skip-upstream-tls-verify" yaml:"skip-upstream-tls-verify"` - // Upstream is the upstream endpoint i.e whom were proxying to - Upstream string `json:"upstream-url" yaml:"upstream-url"` - // TagData is passed to the templates - TagData map[string]string `json:"tag-data" yaml:"tag-data"` + // CrossOrigin permits adding headers to the /oauth handlers CrossOrigin CORS `json:"cors" yaml:"cors"` - // Headers permits adding customs headers across the board - Headers map[string]string `json:"headers" yaml:"headers"` - // Scopes is a list of scope we should request - Scopes []string `json:"scopes" yaml:"scopes"` - // Resources is a list of protected resources - Resources []*Resource `json:"resources" yaml:"resources"` - // SignInPage is the relative url for the sign in page - SignInPage string `json:"sign-in-page" yaml:"sign-in-page"` - // ForbiddenPage is a access forbidden page - ForbiddenPage string `json:"forbidden-page" yaml:"forbidden-page"` - // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes - SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification"` - // Verbose switches on debug logging - Verbose bool `json:"verbose" yaml:"verbose"` + // Hostname is a list of hostname's the service should response to Hostnames []string `json:"hostnames" yaml:"hostnames"` + // Store is a url for a store resource, used to hold the refresh tokens StoreURL string `json:"store-url" yaml:"store-url"` + // EncryptionKey is the encryption key used to encrypt the refresh token + EncryptionKey string `json:"encryption-key" yaml:"encryption-key"` + + // EnableSecurityFilter enabled the security handler + EnableSecurityFilter bool `json:"enable-security-filter" yaml:"enable-security-filter"` + // EnableRefreshTokens indicate's you wish to ignore using refresh tokens and re-auth on expiration of access token + EnableRefreshTokens bool `json:"enable-refresh-tokens" yaml:"enable-refresh-tokens"` + // LogRequests indicates if we should log all the requests + LogRequests bool `json:"log-requests" yaml:"log-requests"` + // LogFormat is the logging format + LogJSONFormat bool `json:"log-json-format" yaml:"log-json-format"` + // NoRedirects informs we should hand back a 401 not a redirect + NoRedirects bool `json:"no-redirects" yaml:"no-redirects"` + // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes + SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification"` + // UpstreamKeepalives specifies whether we use keepalives on the upstream + UpstreamKeepalives bool `json:"upstream-keepalives" yaml:"upstream-keepalives"` + // UpstreamTimeout is the maximum amount of time a dial will wait for a connect to complete + UpstreamTimeout time.Duration `json:"upstream-timeout" yaml:"upstream-timeout"` + // UpstreamKeepaliveTimeout + UpstreamKeepaliveTimeout time.Duration `json:"upstream-keepalive-timeout" yaml:"upstream-keepalive-timeout"` + // Verbose switches on debug logging + Verbose bool `json:"verbose" yaml:"verbose"` + // EnableProxyProtocol controls the proxy protocol + EnableProxyProtocol bool `json:"enabled-proxy-protocol" yaml:"enabled-proxy-protocol"` + + // SignInPage is the relative url for the sign in page + SignInPage string `json:"sign-in-page" yaml:"sign-in-page"` + // ForbiddenPage is a access forbidden page + ForbiddenPage string `json:"forbidden-page" yaml:"forbidden-page"` + // TagData is passed to the templates + TagData map[string]string `json:"tag-data" yaml:"tag-data"` + + // EnableForwarding enables the forwarding proxy + EnableForwarding bool `json:"enable-forwarding" yaml:"enable-forwarding"` + // ForwardingUsername is the username to login to the oauth service + ForwardingUsername string `json:"forwarding-username" yaml:"forwarding-username"` + // ForwardingPassword is the password to use for the above + ForwardingPassword string `json:"forwarding-password" yaml:"forwarding-password"` + // ForwardingDomains is a collection of domains to signs + ForwardingDomains []string `json:"forwarding-domains" yaml:"forwarding-domains"` } // store is used to hold the offline refresh token, assuming you don't want to use diff --git a/forwarding.go b/forwarding.go new file mode 100644 index 0000000..2f4a87c --- /dev/null +++ b/forwarding.go @@ -0,0 +1,233 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "net/http" + "time" + + log "github.com/Sirupsen/logrus" + "github.com/coreos/go-oidc/jose" + "github.com/coreos/go-oidc/oidc" + "github.com/gin-gonic/gin" +) + +// +// upstreamReverseProxyHandler is responsible for handles reverse proxy request to the upstream endpoint +// +func (r *oauthProxy) upstreamReverseProxyHandler() gin.HandlerFunc { + return func(cx *gin.Context) { + if cx.IsAborted() { + return + } + + // step: is this connection upgrading? + if isUpgradedConnection(cx.Request) { + log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) + if err := tryUpdateConnection(cx, r.endpoint); err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") + cx.AbortWithStatus(http.StatusInternalServerError) + return + } + cx.Abort() + return + } + /* + By default goproxy only provides a forwarding proxy, thus all requests have to be absolute + and we must update the host headers + */ + cx.Request.URL.Host = r.endpoint.Host + cx.Request.URL.Scheme = r.endpoint.Scheme + cx.Request.Host = r.endpoint.Host + + r.upstream.ServeHTTP(cx.Writer, cx.Request) + } +} + +// +// forwardProxyHandler is responsible for signing outbound requests +// +func (r *oauthProxy) forwardProxyHandler() gin.HandlerFunc { + var token jose.JWT + var identity *oidc.Identity + var refreshToken string + + // step: create oauth client + client, err := r.client.OAuthClient() + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Fatalf("failed to create an oauth, error: %s", err) + } + + // step: create a routine to refresh the access tokens or login on expiration + go func() { + // step: setup a timer to refresh the access token + requireLogin := true + var expires time.Time + + for { + waitingOn := false + + // step: do we have a access token + if requireLogin { + log.WithFields(log.Fields{ + "username": r.config.ForwardingUsername, + }).Debugf("requesting a access token for user") + + // step: login into the service + resp, err := client.UserCredsToken(r.config.ForwardingUsername, r.config.ForwardingPassword) + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Error("failed to login to authentication service") + + // step: backoff and reschedule + <-time.After(time.Duration(5) * time.Second) + continue + } + + // step: decode the token to find the claims + token, err = jose.ParseJWT(resp.AccessToken) + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to parse the access token") + + // step: we should probably hope and reschedule here + <-time.After(time.Duration(5) * time.Second) + continue + } + + claims, err := token.Claims() + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to parse claims in access token") + + <-time.After(time.Duration(5) * time.Second) + continue + } + + // step: parse the identity from the token + identity, err = oidc.IdentityFromClaims(claims) + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to decode the identity of access token") + + // step: reschedule a reattempt in x seconds + <-time.After(time.Duration(5) * time.Second) + continue + } + + // step: print some logging for debug purposes + // step: set the expiration of the access token within a random 85% of + // actual expiration + seconds := int(float64(identity.ExpiresAt.Sub(time.Now()).Seconds()) * 0.85) + expires = time.Now().Add(time.Duration(seconds) * time.Second) + + // step: update the loop state + requireLogin = false + waitingOn = true + refreshToken = resp.RefreshToken + + log.WithFields(log.Fields{ + "subject": identity.ID, + "email": identity.Email, + "expires_on": identity.ExpiresAt.Format(time.RFC822Z), + "renewal": expires.Format(time.RFC822Z), + "duration": expires.Sub(time.Now()).String(), + }).Infof("retrieved the access token for subject") + + } else { + // step: check if the access token is about to expiry + if time.Now().After(expires) { + log.WithFields(log.Fields{ + "subject": identity.ID, + "email": identity.Email, + }).Debugf("access token is about to expiry") + // step: if we do NOT have a refresh token, we need to login again + if refreshToken == "" { + waitingOn = false + requireLogin = true + break + } + } + + log.WithFields(log.Fields{ + "subject": identity.ID, + "email": identity.Email, + "expires_on": identity.ExpiresAt.Format(time.RFC822Z), + }).Debugf("attempting to refresh the access token") + + // step: attempt to refresh the access + renewToken, expiresIn, err := getRefreshedToken(r.client, refreshToken) + if err != nil { + // step: we need to login again + requireLogin = true + // step: has the refresh token expired + switch err { + case ErrRefreshTokenExpired: + log.WithFields(log.Fields{ + "token": token, + }).Warningf("the refresh token has expired, need to login again") + default: + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed to refresh the access token") + } + <-time.After(time.Duration(5) * time.Second) + continue + } + + // step: update the access token + token = renewToken + expires = expiresIn + waitingOn = true + } + + // step: wait for an expiration to come close + if waitingOn { + log.WithFields(log.Fields{ + "expires": expires.String(), + }).Debugf("waiting for expiration of access token") + + <-time.After(expires.Sub(time.Now())) + } + } + }() + + return func(cx *gin.Context) { + hostname := cx.Request.Host + cx.Request.URL.Host = cx.Request.Host + + // step: does the host being signed? + // a) if the forwarding domain set and we are NOT in the list, just forward it + // b) else the list is zero (meaning sign all requests) or we are in the list + if len(r.config.ForwardingDomains) > 0 && !containedIn(hostname, r.config.ForwardingDomains) { + goto PROXY + } + + // step: sign the outbound request with the access token + cx.Request.Header.Set("X-Forwarded-Agent", prog) + cx.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Encode())) + + PROXY: + r.upstream.ServeHTTP(cx.Writer, cx.Request) + } +} diff --git a/handlers.go b/handlers.go index a127439..8a8e240 100644 --- a/handlers.go +++ b/handlers.go @@ -17,12 +17,12 @@ package main import ( "bytes" + "encoding/base64" "fmt" "io/ioutil" "net/http" "net/url" "path" - "strings" "time" log "github.com/Sirupsen/logrus" @@ -32,7 +32,7 @@ import ( // // oauthAuthorizationHandler is responsible for performing the redirection to oauth provider // -func (r oauthProxy) oauthAuthorizationHandler(cx *gin.Context) { +func (r *oauthProxy) oauthAuthorizationHandler(cx *gin.Context) { // step: we can skip all of this if were not verifying the token if r.config.SkipTokenVerification { cx.AbortWithStatus(http.StatusNotAcceptable) @@ -55,13 +55,15 @@ func (r oauthProxy) oauthAuthorizationHandler(cx *gin.Context) { accessType = "offline" } - log.WithFields(log.Fields{ - "client_ip": cx.ClientIP(), - "access_type": accessType, - }).Infof("incoming authorization request from client address: %s", cx.ClientIP()) - + // step: generate the authorization url redirectionURL := client.AuthCodeURL(cx.Query("state"), accessType, "") + log.WithFields(log.Fields{ + "client_ip": cx.ClientIP(), + "access_type": accessType, + "redirection-url": redirectionURL, + }).Debugf("incoming authorization request from client address: %s", cx.ClientIP()) + // step: if we have a custom sign in page, lets display that if r.config.hasCustomSignInPage() { // step: inject any custom tags into the context for the template @@ -81,25 +83,19 @@ func (r oauthProxy) oauthAuthorizationHandler(cx *gin.Context) { // // oauthCallbackHandler is responsible for handling the response from oauth service // -func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { +func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) { // step: is token verification switched on? if r.config.SkipTokenVerification { cx.AbortWithStatus(http.StatusNotAcceptable) return } - code := cx.Request.URL.Query().Get("code") - state := cx.Request.URL.Query().Get("state") - // step: ensure we have a authorization code to exchange + code := cx.Request.URL.Query().Get("code") if code == "" { cx.AbortWithStatus(http.StatusBadRequest) return } - // step: ensure we have a state or default to root / - if state == "" { - state = "/" - } // step: exchange the authorization for a access token response, err := exchangeAuthenticationCode(r.client, code) @@ -180,13 +176,27 @@ func (r oauthProxy) oauthCallbackHandler(cx *gin.Context) { } } + // step: decode the state variable + state := "/" + if cx.Request.URL.Query().Get("state") != "" { + decoded, err := base64.StdEncoding.DecodeString(cx.Request.URL.Query().Get("state")) + if err != nil { + log.WithFields(log.Fields{ + "state": cx.Request.URL.Query().Get("state"), + "error": err.Error(), + }).Warnf("unabe to decode the state parameter") + } else { + state = string(decoded) + } + } + r.redirectToURL(state, cx) } // // loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider // -func (r oauthProxy) loginHandler(cx *gin.Context) { +func (r *oauthProxy) loginHandler(cx *gin.Context) { // step: parse the client credentials username := cx.Request.URL.Query().Get("username") password := cx.Request.URL.Query().Get("password") @@ -239,7 +249,7 @@ func (r oauthProxy) loginHandler(cx *gin.Context) { // - if the user has a refresh token, the token is invalidated by the provider // - optionally, the user can be redirected by to a url // -func (r oauthProxy) logoutHandler(cx *gin.Context) { +func (r *oauthProxy) logoutHandler(cx *gin.Context) { // the user can specify a url to redirect the back to redirectURL := cx.Request.URL.Query().Get("redirect") @@ -261,7 +271,9 @@ func (r oauthProxy) logoutHandler(cx *gin.Context) { if r.useStore() { go func() { if err := r.DeleteRefreshToken(user.token); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("unable to remove the refresh token from store") + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("unable to remove the refresh token from store") } }() } @@ -333,76 +345,6 @@ func (r oauthProxy) logoutHandler(cx *gin.Context) { cx.AbortWithStatus(http.StatusOK) } -// -// proxyHandler is responsible to proxy the requests on to the upstream endpoint -// -func (r oauthProxy) proxyHandler() gin.HandlerFunc { - // step: we don't wanna do this every time, quicker to perform once - customClaims := make(map[string]string) - for _, x := range r.config.AddClaims { - customClaims[x] = fmt.Sprintf("X-Auth-%s", toHeader(x)) - } - - return func(cx *gin.Context) { - // step: double check, if enforce is true and no user context it's a internal error - if _, found := cx.Get(cxEnforce); found { - if _, found := cx.Get(userContextName); !found { - log.Errorf("no user context found for a secure request") - cx.AbortWithStatus(http.StatusInternalServerError) - return - } - } - - // step: retrieve the user context if any - if user, found := cx.Get(userContextName); found { - id := user.(*userContext) - cx.Request.Header.Add("X-Auth-UserId", id.id) - cx.Request.Header.Add("X-Auth-Subject", id.preferredName) - cx.Request.Header.Add("X-Auth-Username", id.name) - cx.Request.Header.Add("X-Auth-Email", id.email) - cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) - cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) - cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) - - // step: inject any custom claims - for claim, header := range customClaims { - if claim, found := id.claims[claim]; found { - cx.Request.Header.Add(header, fmt.Sprintf("%v", claim)) - } - } - } - - // step: add the default headers - cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr) - cx.Request.Header.Set("X-Forwarded-Agent", prog) - cx.Request.Header.Set("X-Forwarded-Agent-Version", version) - - // step: is this connection upgrading? - if isUpgradedConnection(cx.Request) { - log.Debugf("upgrading the connnection to %s", cx.Request.Header.Get(headerUpgrade)) - if err := tryUpdateConnection(cx, r.endpoint); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to upgrade the connection") - - cx.AbortWithStatus(http.StatusInternalServerError) - return - } - cx.Abort() - - return - } - - /* - Issue: https://github.com/golang/go/issues/7618 - - The reverse proxy does not update the Host header of request, as it's assumed the upstream in on the - same domain as the proxy. We could override the Director method, but the latter is easier - */ - cx.Request.Host = r.endpoint.Host - - r.upstream.ServeHTTP(cx.Writer, cx.Request) - } -} - // // expirationHandler checks if the token has expired // @@ -442,13 +384,14 @@ func (r *oauthProxy) tokenHandler(cx *gin.Context) { // healthHandler is a health check handler for the service // func (r *oauthProxy) healthHandler(cx *gin.Context) { - cx.String(http.StatusOK, "OK") + cx.Writer.Header().Set(versionHeader, version) + cx.String(http.StatusOK, "OK\n") } // // retrieveRefreshToken retrieves the refresh token from store or cookie // -func (r oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) { +func (r *oauthProxy) retrieveRefreshToken(cx *gin.Context, user *userContext) (string, error) { var token string var err error diff --git a/handlers_test.go b/handlers_test.go index 14d124e..729611d 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/coreos/go-oidc/jose" + "github.com/stretchr/testify/assert" ) func TestExpirationHandler(t *testing.T) { @@ -76,9 +77,106 @@ func TestExpirationHandler(t *testing.T) { // step: if closure so we need to get the handler each time proxy.expirationHandler(cx) // step: check the content result - if cx.Writer.Status() != c.HTTPCode { - t.Errorf("test case %d should have recieved: %d, but got %d", i, c.HTTPCode, cx.Writer.Status()) + assert.Equal(t, c.HTTPCode, cx.Writer.Status(), "test case %d should have recieved: %d, but got %d", i, + c.HTTPCode, cx.Writer.Status()) + } +} + +func TestAuthorizationURL(t *testing.T) { + _, _, u := newTestProxyService(t, nil) + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("no redirect") + }, + } + cs := []struct { + URL string + ExpectedURL string + ExpectedCode int + }{ + { + URL: "/", + ExpectedCode: http.StatusNotFound, + }, + { + URL: "/admin", + ExpectedURL: "/oauth/authorize?state=L2FkbWlu", + ExpectedCode: http.StatusTemporaryRedirect, + }, + { + URL: "/admin/test", + ExpectedURL: "/oauth/authorize?state=L2FkbWluL3Rlc3Q=", + ExpectedCode: http.StatusTemporaryRedirect, + }, + { + URL: "/admin/../", + ExpectedURL: "/oauth/authorize?state=L2FkbWluLy4uLw==", + ExpectedCode: http.StatusTemporaryRedirect, + }, + { + URL: "/admin?test=yes&test1=test", + ExpectedURL: "/oauth/authorize?state=L2FkbWluP3Rlc3Q9eWVzJnRlc3QxPXRlc3Q=", + ExpectedCode: http.StatusTemporaryRedirect, + }, + } + for i, x := range cs { + resp, _ := client.Get(u + x.URL) + assert.Equal(t, x.ExpectedCode, resp.StatusCode, "case %d, expect: %v, got: %s", i, x.ExpectedCode, resp.StatusCode) + assert.Equal(t, x.ExpectedURL, resp.Header.Get("Location"), "case %d, expect: %v, got: %s", i, x.ExpectedURL, resp.Header.Get("Location")) + } +} + +func TestCallbackURL(t *testing.T) { + _, _, u := newTestProxyService(t, nil) + + cs := []struct { + URL string + ExpectedURL string + }{ + { + URL: "/oauth/authorize?state=L2FkbWlu", + ExpectedURL: "/admin", + }, + { + URL: "/oauth/authorize", + ExpectedURL: "/", + }, + { + URL: "/oauth/authorize?state=L2FkbWluL3Rlc3QxP3Rlc3QxJmhlbGxv", + ExpectedURL: "/admin/test1?test1&hello", + }, + } + for i, x := range cs { + // step: call the authorization endpoint + req, err := http.NewRequest("GET", u+x.URL, nil) + if err != nil { + continue + } + resp, err := http.DefaultTransport.RoundTrip(req) + if !assert.NoError(t, err, "case %d, should not have failed", i) { + continue + } + openIDURL := resp.Header.Get("Location") + if !assert.NotEmpty(t, openIDURL, "case %d, the open id redirection url is empty", i) { + continue } + req, _ = http.NewRequest("GET", openIDURL, nil) + resp, err = http.DefaultTransport.RoundTrip(req) + if !assert.NoError(t, err, "case %d, should not have failed calling the opend id url", i) { + continue + } + callbackURL := resp.Header.Get("Location") + if !assert.NotEmpty(t, callbackURL, "case %d, should have recieved a callback url", i) { + continue + } + // step: call the callback url + req, _ = http.NewRequest("GET", callbackURL, nil) + resp, err = http.DefaultTransport.RoundTrip(req) + if !assert.NoError(t, err, "case %d, unable to call the callback url", i) { + continue + } + // step: check the callback location is as expected + assert.Contains(t, resp.Header.Get("Location"), x.ExpectedURL) } } @@ -86,7 +184,7 @@ func TestHealthHandler(t *testing.T) { proxy := newFakeKeycloakProxy(t) context := newFakeGinContext("GET", healthURL) proxy.healthHandler(context) - if context.Writer.Status() != http.StatusOK { - t.Errorf("we should have recieved a 200 response") - } + assert.Equal(t, http.StatusOK, context.Writer.Status()) + assert.NotEmpty(t, context.Writer.Header().Get(versionHeader)) + assert.Equal(t, version, context.Writer.Header().Get(versionHeader)) } diff --git a/main.go b/main.go index e0d47bc..a6f15fc 100644 --- a/main.go +++ b/main.go @@ -32,42 +32,44 @@ func main() { kc.Version = version kc.Author = author kc.Email = email + kc.UsageText = "keycloak-proxy [options]" kc.Flags = getOptions() - kc.Action = func(cx *cli.Context) { + kc.Action = func(cx *cli.Context) error { // step: do we have a configuration file? if filename := cx.String("config"); filename != "" { if err := readConfigFile(filename, config); err != nil { - printUsage(fmt.Sprintf("unable to read the configuration file: %s, error: %s", filename, err.Error())) + return printError("unable to read the configuration file: %s, error: %s", filename, err.Error()) } } // step: parse the command line options if err := readOptions(cx, config); err != nil { - printUsage(err.Error()) + return printError(err.Error()) } // step: validate the configuration if err := config.isValid(); err != nil { - printUsage(err.Error()) + return printError(err.Error()) } // step: create the proxy proxy, err := newProxy(config) if err != nil { - printUsage(err.Error()) + return printError(err.Error()) } // step: start the service if err := proxy.Run(); err != nil { - printUsage(err.Error()) + return printError(err.Error()) } // step: setup the termination signals signalChannel := make(chan os.Signal) signal.Notify(signalChannel, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) <-signalChannel + + return nil } kc.Run(os.Args) } // printUsage display the command line usage and error -func printUsage(message string) { - fmt.Fprintf(os.Stderr, "\n[error] %s\n", message) - os.Exit(1) +func printError(message string, args ...interface{}) *cli.ExitError { + return cli.NewExitError(fmt.Sprintf("[error] "+message, args...), 1) } diff --git a/middleware.go b/middleware.go index 030052a..152bff0 100644 --- a/middleware.go +++ b/middleware.go @@ -39,6 +39,7 @@ func (r *oauthProxy) loggingHandler() gin.HandlerFunc { return func(cx *gin.Context) { start := time.Now() cx.Next() + latency := time.Now().Sub(start) log.WithFields(log.Fields{ @@ -56,9 +57,6 @@ func (r *oauthProxy) loggingHandler() gin.HandlerFunc { // entryPointHandler checks to see if the request requires authentication // func (r oauthProxy) entryPointHandler() gin.HandlerFunc { - // step: create the proxy handler - proxy := r.proxyHandler() - return func(cx *gin.Context) { if strings.HasPrefix(cx.Request.URL.Path, oauthURL) { cx.Next() @@ -78,18 +76,8 @@ func (r oauthProxy) entryPointHandler() gin.HandlerFunc { break } } - // step: pass into the authentication and admission handlers + // step: pass into the authentication, admission and proxy handlers cx.Next() - - // step: add a custom headers to the request - for k, v := range r.config.Headers { - cx.Request.Header.Add(k, v) - } - - // step: check the request has not been aborted and if not, proxy request - if !cx.IsAborted() { - proxy(cx) - } } } @@ -110,7 +98,7 @@ func (r *oauthProxy) authenticationHandler() gin.HandlerFunc { if err != nil { log.WithFields(log.Fields{ "error": err.Error(), - }).Errorf("failed to get session, redirecting for authorization") + }).Errorf("no session found in request, redirecting for authorization") r.redirectToAuthorization(cx) return @@ -193,7 +181,7 @@ func (r *oauthProxy) authenticationHandler() gin.HandlerFunc { }).Infof("found a refresh token, attempting to refresh access token for user: %s", user.email) // step: attempts to refresh the access token - token, expires, err := refreshToken(r.client, rToken) + token, expires, err := getRefreshedToken(r.client, rToken) if err != nil { // step: has the refresh token expired switch err { @@ -278,7 +266,7 @@ func (r *oauthProxy) admissionHandler() gin.HandlerFunc { user := uc.(*userContext) // step: check the audience for the token is us - if !user.isAudience(r.config.ClientID) { + if r.config.ClientID != "" && !user.isAudience(r.config.ClientID) { log.WithFields(log.Fields{ "username": user.name, "expired_on": user.expiresAt.String(), @@ -384,27 +372,71 @@ func (r *oauthProxy) crossOriginResourceHandler(c CORS) gin.HandlerFunc { } } +// +// upstreamHeadersHandler is responsible for add the authentication headers for the upstream +// +func (r *oauthProxy) upstreamHeadersHandler(custom []string) gin.HandlerFunc { + // step: we don't wanna do this every time, quicker to perform once + customClaims := make(map[string]string) + for _, x := range custom { + customClaims[x] = fmt.Sprintf("X-Auth-%s", toHeader(x)) + } + + return func(cx *gin.Context) { + // step: add a custom headers to the request + for k, v := range r.config.Headers { + cx.Request.Header.Add(k, v) + } + + // step: retrieve the user context if any + if user, found := cx.Get(userContextName); found { + id := user.(*userContext) + cx.Request.Header.Add("X-Auth-Userid", id.name) + cx.Request.Header.Add("X-Auth-Subject", id.id) + cx.Request.Header.Add("X-Auth-Username", id.name) + cx.Request.Header.Add("X-Auth-Email", id.email) + cx.Request.Header.Add("X-Auth-ExpiresIn", id.expiresAt.String()) + cx.Request.Header.Add("X-Auth-Token", id.token.Encode()) + cx.Request.Header.Add("X-Auth-Roles", strings.Join(id.roles, ",")) + cx.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", id.token.Encode())) + + // step: inject any custom claims + for claim, header := range customClaims { + if claim, found := id.claims[claim]; found { + cx.Request.Header.Add(header, fmt.Sprintf("%v", claim)) + } + } + } + // step: add the default headers + cx.Request.Header.Add("X-Forwarded-For", cx.Request.RemoteAddr) + cx.Request.Header.Set("X-Forwarded-Agent", prog) + cx.Request.Header.Set("X-Forwarded-Host", cx.Request.Host) + } +} + // // securityHandler performs numerous security checks on the request // func (r *oauthProxy) securityHandler() gin.HandlerFunc { // step: create the security options secure := secure.New(secure.Options{ - AllowedHosts: r.config.Hostnames, - BrowserXssFilter: true, - ContentTypeNosniff: true, - FrameDeny: true, - STSIncludeSubdomains: true, - STSSeconds: 31536000, + AllowedHosts: r.config.Hostnames, + BrowserXssFilter: true, + ContentTypeNosniff: true, + FrameDeny: true, }) return func(cx *gin.Context) { // step: pass through the security middleware if err := secure.Process(cx.Writer, cx.Request); err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed security middleware") + log.WithFields(log.Fields{ + "error": err.Error(), + }).Errorf("failed security middleware") + cx.Abort() return } + // step: permit the request to continue cx.Next() } diff --git a/middleware_test.go b/middleware_test.go index 5b263c7..26667e1 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -22,6 +22,7 @@ import ( "github.com/coreos/go-oidc/jose" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" ) func TestEntrypointHandlerSecure(t *testing.T) { @@ -176,30 +177,33 @@ func TestEntrypointHandler(t *testing.T) { } } +func TestAuthenticationHandler(t *testing.T) { + +} + func TestSecurityHandler(t *testing.T) { kc := newFakeKeycloakProxy(t) handler := kc.securityHandler() context := newFakeGinContext("GET", "/") handler(context) - if context.Writer.Status() != http.StatusOK { - t.Errorf("we should have received a 200") - } + + assert.Equal(t, http.StatusOK, context.Writer.Status(), + "we should have received a 200 not %d", context.Writer.Status()) kc = newFakeKeycloakProxy(t) kc.config.Hostnames = []string{"127.0.0.1"} handler = kc.securityHandler() handler(context) - if context.Writer.Status() != http.StatusOK { - t.Errorf("we should have received a 200 not %d", context.Writer.Status()) - } + assert.Equal(t, http.StatusOK, context.Writer.Status(), + "we should have received a 200 not %d", context.Writer.Status()) kc = newFakeKeycloakProxy(t) kc.config.Hostnames = []string{"127.0.0.2"} handler = kc.securityHandler() handler(context) - if context.Writer.Status() != http.StatusInternalServerError { - t.Errorf("we should have received a 500 not %d", context.Writer.Status()) - } + + assert.Equal(t, http.StatusInternalServerError, context.Writer.Status(), + "we should have received a 500 not %d", context.Writer.Status()) } func TestCrossSiteHandler(t *testing.T) { @@ -248,6 +252,74 @@ func TestCrossSiteHandler(t *testing.T) { } } +func TestCustomHeadersHandler(t *testing.T) { + p := newFakeKeycloakProxy(t) + + cases := []struct { + Identity *userContext + CustomClaims []string + Expected http.Header + }{ + { + Expected: http.Header{}, + }, + { + Identity: &userContext{ + id: "test-subject", + name: "rohith", + email: "gambol99@gmail.com", + }, + Expected: http.Header{ + "X-Auth-Subject": []string{"test-subject"}, + "X-Auth-Userid": []string{"rohith"}, + "X-Auth-Email": []string{"gambol99@gmail.com"}, + "X-Auth-Username": []string{"rohith"}, + }, + }, + { + + Identity: &userContext{ + roles: []string{"a", "b", "c"}, + }, + Expected: http.Header{ + "X-Auth-Roles": []string{"a,b,c"}, + }, + }, + { + CustomClaims: []string{"given_name", "family_name"}, + Identity: &userContext{ + claims: jose.Claims{ + "email": "gambol99@gmail.com", + "name": "Rohith Jayawardene", + "family_name": "Jayawardene", + "preferred_username": "rjayawardene", + "given_name": "Rohith", + }, + }, + Expected: http.Header{ + "X-Auth-Given-Name": []string{"Rohith"}, + "X-Auth-Family-Name": []string{"Jayawardene"}, + }, + }, + } + for i, x := range cases { + handler := p.upstreamHeadersHandler(x.CustomClaims) + context := newFakeGinContext("GET", "/nothing") + if x.Identity != nil { + context.Set(userContextName, x.Identity) + } + handler(context) + // step: and check we have all the headers + for k := range x.Expected { + assert.Equal(t, + x.Expected.Get(k), + context.Request.Header.Get(k), + "case %d, expected (%s: %s) got: (%s: %s)", + i, k, x.Expected.Get(k), k, context.Request.Header.Get(k)) + } + } +} + func TestAdmissionHandlerRoles(t *testing.T) { proxy := newFakeKeycloakProxyWithResources(t, []*Resource{ { @@ -341,9 +413,8 @@ func TestAdmissionHandlerRoles(t *testing.T) { c.Context.Set(userContextName, c.UserContext) handler(c.Context) - if c.Context.Writer.Status() != c.HTTPCode { - t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) - } + status := c.Context.Writer.Status() + assert.Equal(t, c.HTTPCode, status, "test case %d should have recieved code: %d, got %d", i, c.HTTPCode, status) } } @@ -453,9 +524,7 @@ func TestAdmissionHandlerClaims(t *testing.T) { handler(c.Context) c.Context.Writer.WriteHeaderNow() - - if c.Context.Writer.Status() != c.HTTPCode { - t.Errorf("test case %d should have recieved code: %d, got %d", i, c.HTTPCode, c.Context.Writer.Status()) - } + status := c.Context.Writer.Status() + assert.Equal(t, c.HTTPCode, status, "test case %d should have recieved code: %d, got %d", i, c.HTTPCode, status) } } diff --git a/oauth.go b/oauth.go index bafff75..d76d16d 100644 --- a/oauth.go +++ b/oauth.go @@ -41,9 +41,9 @@ func verifyToken(client *oidc.Client, token jose.JWT) error { } // -// refreshToken attempts to refresh the access token, returning the parsed token and the time it expires or a error +// getRefreshedToken attempts to refresh the access token, returning the parsed token and the time it expires or a error // -func refreshToken(client *oidc.Client, t string) (jose.JWT, time.Time, error) { +func getRefreshedToken(client *oidc.Client, t string) (jose.JWT, time.Time, error) { response, err := getToken(client, oauth2.GrantTypeRefreshToken, t) if err != nil { if strings.Contains(err.Error(), "token expired") { diff --git a/oauth_test.go b/oauth_test.go index 0825256..5bc8079 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -16,15 +16,66 @@ limitations under the License. package main import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "math/rand" "net/http" + "net/http/httptest" + "net/url" + "sync" "testing" + "time" + "github.com/coreos/go-oidc/jose" "github.com/gin-gonic/gin" ) type fakeOAuthServer struct { + sync.Mutex + // the location of the service + location *url.URL + // the private key + privateKey *rsa.PrivateKey + // the jwk key + key jose.JWK + // the signer + signer jose.Signer + // the claims + claims jose.Claims } +const fakePrivateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAxMLIwi//YG6GPdYUPaV0PCXBEXjg2Xhf8/+NMB/1nt+wip4Z +rrAQf14PTCTlN4sbc2QGgRGtYikJBHQyfg/lCthrnasfdgL8c6SErr7Db524SqiD +m+/yKGI680LmBUIPkA0ikCJgb4cYVCiJ3HuYnFZUTsUAeK14SoXgcJdWulj0h6aP +iUIg5VrehuqAG+1RlK+GURgr9DbOmXJ/SYVKX/QArdBzjZ3BiQ1nxWWwBCLHfwv4 +8bWxPJIbDDnUNl6LolpSJkxg4qlp+0I/xgEveK1n1CMEA0mHuXFHeekKO72GDKAk +h89C9qVF2GmpDfo8G0D3lFm2m3jFNyMQTWkSkwIDAQABAoIBADwhOrD9chHKNQQY +tD7SnV70OrhYNH7BJrGuWztlyO4wdgcmobqc263Q1OP0Mohy3oS5ALPY7x+cYsEV +sYiM2vYhhWG9tfOenf/JOzMb4SXvES7fqLiy71IgEtvcieb5dUAUg4eAue/bXTf6 +24ahztWYHFOmKKq4eJZtq1U9KqfvlW1T4bg3mXV70huvfoMhYKwYryTOsQ5yiYCf +Yo4UGUBLfg3capIB5gxQdcqdDk+UTe9be7GQBj+3oziALb1nIhW7cpy0nw/r22A5 +pv1FbRqND2VYKjZCQyUbxnjty5eDIW7fKBIh0Ez9yZHqz4KHb1u/KlFm31NGZpMU +Xs/WN+ECgYEA+kcAi7fTUjagqov5a4Y595ptu2gmU4Cxr+EBhMWadJ0g7enCXjTI +HAFEsVi2awbSRswjxdIG533SiKg8NIXThMntfbTm+Kw3LSb0/++Zyr7OuKJczKvQ +KfjAHvqsV8yJqy1gApYqVOeU4/jMLDs2sMY59/IQNkUVHNncZO09aa8CgYEAyUKG +BUyvxSim++YPk3OznBFZhqJqR75GYtWSu91BgZk/YmgYM4ht2u5q96AIRbJ664Ks +v93varNfqyKN1BN3JPLw8Ph8uX/7k9lMmECXoNp2Tm3A54zlsHyNOGOSvU7axvUg +PfIhpvRZKA0QQK3c1CZDghs94siJeBSIpuzCsl0CgYEA8Z28LCZiT3tHbn5FY4Wo +zp36k7L/VRvn7niVg71U2IGc+bHzoAjqqwaab2/KY9apCAop+t9BJRi2OJHZ1Ybg +5dAfg30ygh2YAvIaEj8YxL+iSGMOndS82Ng5eW7dFMH0ohnjF3wrD96mQdO+IHFl +4hDsg67f8dSNhlXYzGKwKCcCgYEAlAsrKprOcOkGbCU/L+fcJuFcSX0PUNbWT71q +wmZu2TYxOeH4a2/f3zuh06UUcLBpWvQ0vq4yfvqTVP+F9IqdCcDrG1at6IYMOSWP +AjABWYFZpTd2vt0V2EzGVMRqHHb014VYwjhqKLV1H9D8M5ew6R18ayg+zaNV+86e +9qsSTMECgYEA322XUN8yUBTTWBkXY7ipzTHSWkxMuj1Pa0gtBd6Qqqu3v7qI+jMZ +hlWS2akhJ+3e7f3+KCslG8YMItld4VvAK0eHKQbQM/onav/+/iiR6C2oRBm3OwqO +Ka0WPQGKjQJhZRtqDAT3sfnrEEUa34+MkXQeKFCu6Yi0dRFic4iqOYU= +-----END RSA PRIVATE KEY----- +` + type fakeDiscoveryResponse struct { AuthorizationEndpoint string `json:"authorization_endpoint"` EndSessionEndpoint string `json:"end_session_endpoint"` @@ -41,51 +92,96 @@ type fakeDiscoveryResponse struct { UserinfoEndpoint string `json:"userinfo_endpoint"` } -type fakeKeysResponse struct { - Keys []fakeKeyResponse `json:"keys"` -} +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +// +// newFakeOAuthServer simulates a oauth service +// +func newFakeOAuthServer(t *testing.T) *fakeOAuthServer { + // step: load the private key + block, _ := pem.Decode([]byte(fakePrivateKey)) + // step: parse the private key + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + t.Fatalf("failed to parse the private key, error: %s", err) + } -type fakeKeyResponse struct { - Alg string `json:"alg"` - E string `json:"e"` - Kid string `json:"kid"` - Kty string `json:"kty"` - N string `json:"n"` - Use string `json:"use"` -} + service := &fakeOAuthServer{ + claims: jose.Claims{ + "jti": "4ee75b8e-3ee6-4382-92d4-3390b4b4937b", + "exp": int(time.Now().Add(time.Duration(10) * time.Hour).Unix()), + "nbf": 0, + "iat": float64(1450372669), + "aud": "test", + "sub": "1e11e539-8256-4b3b-bda8-cc0d56cddb48", + "typ": "Bearer", + "azp": "clientid", + "session_state": "98f4c3d2-1b8c-4932-b8c4-92ec0ea7e195", + "client_session": "f0105893-369a-46bc-9661-ad8c747b1a69", + "email": "gambol99@gmail.com", + "name": "Rohith Jayawardene", + "family_name": "Jayawardene", + "preferred_username": "rjayawardene", + "given_name": "Rohith", + }, + privateKey: privateKey, + key: jose.JWK{ + ID: "test-kid", + Type: "RSA", + Alg: "RS256", + Use: "sig", + Exponent: privateKey.PublicKey.E, + Modulus: privateKey.PublicKey.N, + Secret: block.Bytes, + }, + signer: jose.NewSignerRSA("test-kid", *privateKey), + } -const ( - fakePublicKey = "ibGNjo_opyEGbeDP3cctILhSW-sGKtG67hCZXxvHx-wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5-FMbHth-TKZiEhm-3EBadc1qgkfnpinfpxCVqHHaF8mFLC5-k3JsINIR0FAmPN9trxryI_npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh_rrLKAs0AdUYwXGAslnYDBACiR8GNrb7Q" + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.GET("auth/realms/hod-test/.well-known/openid-configuration", service.discoveryHandler) + r.GET("auth/realms/hod-test/protocol/openid-connect/certs", service.keysHandler) + r.GET("auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) + r.POST("auth/realms/hod-test/protocol/openid-connect/token", service.tokenHandler) + r.GET("auth/realms/hod-test/protocol/openid-connect/auth", service.authHandler) + + location, err := url.Parse(httptest.NewServer(r).URL) + if err != nil { + t.Fatalf("unable to create fake oauth service, error: %s", err) + } + service.location = location + service.claims["iss"] = service.getLocation() - oauthPublicKey = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQAB" - oauthPrivateKey = "MIIEowIBAAKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABAoIBAGtfMlSmMbUKErpiIZX+uFkYgti8p92CGLOF7CN3RU3H+PgfF1m4xHGqt4xw+2JyhgQFgTY4IiIN1QuPFzI82+6jDvMqBwEi2e0TGj4RKiOX9D8b/qSL9eUSfqQKPqnPZfBymM3sqe5yddY7KVZiMXEEBu1efhhTADluIraKQjYJKgQd0P3CgfqhuUWgCqGjPwIg0BkzXofR0bjdrq8d0ul8JLnT+9ho/x8rahEN/LTHHLIwb6IYUj8X10tDZWPDk2NE5wRIy18peSXYNTeGhY1ThF75ZOAH5c1qgi0ObE+dUSqzwcDWqNDPxFvg2x67KbcMaTO6u87/mGJfuO2ekz0CgYEAwpR+tZdafTzR+MLGg55mxsfVjAWGNxp0AMwWZVTpPx1I+VgdLsMkUY8LpY2Zt8l2yInIGEzYRBNFYPrM73bW5v0bleGl60I6j3KA/Ic6RUaweycbQgMxob5PCWrMm94Jib1bGAxNU1m0Jp9rzxGUzWw3TpSw6LHNLqokwMCKG/cCgYEAtSg1oqeCvvCrIdA6AulzzWR6x2Re/Iv8MYJ5X0fNPRBHSVhwsdb2nLfjMPmLesBOPm55O/LZDFtL8unpOUc+qT8QWKAjvI0/HtYf2sec3sP/dxCYYK18grK1cvD/UAUfiljM0gAsxZRT77VbpOIMCOi9YjHoyeRgCQtxB9CuZjsCgYEAsLNfehLvpwmjeK+QzRf9J4l0AQtHPiU0sUClGfKJOrqieWUuYzftdG9d2UMFFGTNDQIqhv7J6tBBUfeQQep+8BdshKj9Hu7u9TO7tRgsr5qpS71QwJrb6JFFfzzQgL+bk800u1r4obe1pNljcxD5O6+JbkATg81rknQKmkx/XzMCgYARnyqwesjuF+0dqeqqs9jO5vJGiQ3wVRGgI0f5K7vcL8Qvb0nvErEEh6Ky9eNKeoBh9E8YtMPGPu9BXt2P8801m2vUoyc2xSqZrkyE9Jve04P7KgMYjGerMwURfD3po8XwqDisSNYSFh6gF60ledOf3jvl3GL/mJZ66sEA+JyuVwKBgGwef1FWkDTeft6VFo2obHCh8Fc8rsV2rQ0twgmA00nmuckKr5MQgyMiz2YYWanmOS18xLgl7FzvyX56clj1MvRl9xnwhSudtE4fxg6R4rzwf3jaWtAkXEHet+mqVRJgI9m5Bn8E7nVVmjgRlogZsgYq2pF3nL1sgl3ti7gOVVL6" - oauthCertificate = "MIICnzCCAYcCBgFUPZAJhjANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDDAhob2QtdGVzdDAeFw0xNjA0MjIxMDQwMzBaFw0yNjA0MjIxMDQyMTBaMBMxETAPBgNVBAMMCGhvZC10ZXN0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAibGNjo/opyEGbeDP3cctILhSW+sGKtG67hCZXxvHx+wd6n2KUNIPgs2yn0nH8XFJmrMbxnCe5+FMbHth+TKZiEhm+3EBadc1qgkfnpinfpxCVqHHaF8mFLC5+k3JsINIR0FAmPN9trxryI/npHzkDyfMbml2h21AHboZ3IJON3SbS2S1HaKR5b58ER4cl669nest5ixaOQCAgWIGoO7mXx7pR1PX0VEdLMg498jZkSCcCbAty4wBtTlmyLKyLF5iYRJPgL1lYxGCUZd5VlfPVr0efLf1MLtQ4rCjXmjPMwWTlU0rsEIFh/rrLKAs0AdUYwXGAslnYDBACiR8GNrb7QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBFd9T/1s769tGhOMtUspP2tChKy5OWF50HkRVLny1nt12JeQUvuVSD3l7vN17hFRpMm1ktjVCTxBk5PRfPtpOcMCG2zgYbB73hIRYZaKG5X6/r2y3TllZ2UkZh0ndL+jrn1L4I2zxB5OAi3CDTxiFtjcEShAC9smjp04Omxwat53k8IxJLRgnpuC/TMbxUPHLNjuOHLLFeSN7095SuD+qzx0H7fT4sqW3+mAr7Q/kl2yq4vMXfLHt5KkOm7O5px5mRoGS4Asbkw5MQMgP618uQ9k7EQZx37jF2ol4Z7uLQWscePdWA66ajbxAtybCesNPa4uUrb1YVdx6MikWyZ0i7" -) + return service +} -func newFakeOAuthServer(t *testing.T) { - s := new(fakeOAuthServer) - r := gin.New() - r.GET("/auth/realms/hod-test/.well-known/openid-configuration", s.discoveryHandler) - r.GET("/auth/realms/hod-test/protocol/openid-connect/certs", s.keysHandler) - r.POST("/auth/realms/hod-test/protocol/openid-connect/token", s.tokenHandler) - r.POST("/auth/realms/hod-test/protocol/openid-connect/auth", s.authHandler) +func (r *fakeOAuthServer) getLocation() string { + return fmt.Sprintf("%s://%s/auth/realms/hod-test", r.location.Scheme, r.location.Host) +} - if err := r.Run("127.0.0.1:8080"); err != nil { - t.Fatalf("failed to start the fake oauth service, error: %s", err) +func (r *fakeOAuthServer) setUserRealmRoles(roles []string) *fakeOAuthServer { + r.claims["realm_access"] = map[string]interface{}{ + "roles": roles, } + return r } -func (r fakeOAuthServer) discoveryHandler(cx *gin.Context) { +func (r *fakeOAuthServer) setUserExpiration(duration time.Duration) *fakeOAuthServer { + r.claims["exp"] = time.Now().Add(duration).Second() + return r +} + +func (r *fakeOAuthServer) discoveryHandler(cx *gin.Context) { cx.JSON(http.StatusOK, fakeDiscoveryResponse{ IDTokenSigningAlgValuesSupported: []string{"RS256"}, - Issuer: "http://127.0.0.1:8080/auth/realms/hod-test", - AuthorizationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/auth", - TokenEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token", - RegistrationEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/clients-registrations/openid-connect", - TokenIntrospectionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/token/introspect", - UserinfoEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/userinfo", - EndSessionEndpoint: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/logout", - JwksURI: "http://127.0.0.1:8080/auth/realms/hod-test/protocol/openid-connect/certs", + Issuer: fmt.Sprintf("http://%s/auth/realms/hod-test", r.location.Host), + AuthorizationEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/auth", r.location.Host), + TokenEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/token", r.location.Host), + RegistrationEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/clients-registrations/openid-connect", r.location.Host), + TokenIntrospectionEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/token/introspect", r.location.Host), + UserinfoEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/userinfo", r.location.Host), + EndSessionEndpoint: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/logout", r.location.Host), + JwksURI: fmt.Sprintf("http://%s/auth/realms/hod-test/protocol/openid-connect/certs", r.location.Host), GrantTypesSupported: []string{"authorization_code", "implicit", "refresh_token", "password", "client_credentials"}, ResponseModesSupported: []string{"query", "fragment", "form_post"}, ResponseTypesSupported: []string{"code", "none", "id_token", "token", "id_token token", "code id_token", "code token", "code id_token token"}, @@ -93,25 +189,48 @@ func (r fakeOAuthServer) discoveryHandler(cx *gin.Context) { }) } -func (r fakeOAuthServer) keysHandler(cx *gin.Context) { - cx.JSON(http.StatusOK, fakeKeysResponse{ - Keys: []fakeKeyResponse{ - { - Kid: "ing3Hnuj0ciqrHCOxt__-B53jzXcdD1n1iKbX3GsD9s", - Kty: "RSA", - Alg: "RS256", - Use: "sig", - N: fakePublicKey, - E: "AQAB", - }, - }, - }) +func (r *fakeOAuthServer) keysHandler(cx *gin.Context) { + cx.JSON(http.StatusOK, jose.JWKSet{Keys: []jose.JWK{r.key}}) } -func (r fakeOAuthServer) authHandler(cx *gin.Context) { +func (r *fakeOAuthServer) authHandler(cx *gin.Context) { + state := cx.Query("state") + redirect := cx.Query("redirect_uri") + + if redirect == "" { + cx.AbortWithStatus(http.StatusInternalServerError) + return + } + if state == "" { + state = "/" + } + // step: generate a random authentication code + redirectionURL := fmt.Sprintf("%s?state=%s&code=%s", redirect, state, getRandomString(32)) + cx.Redirect(http.StatusTemporaryRedirect, redirectionURL) } -func (r fakeOAuthServer) tokenHandler(cx *gin.Context) { +func (r *fakeOAuthServer) tokenHandler(cx *gin.Context) { + expiration := time.Now().Add(time.Duration(1) * time.Hour) + token, err := jose.NewSignedJWT(r.claims, r.signer) + if err != nil { + cx.AbortWithError(http.StatusInternalServerError, err) + return + } + + cx.JSON(http.StatusOK, tokenResponse{ + IDToken: token.Encode(), + AccessToken: token.Encode(), + RefreshToken: token.Encode(), + ExpiresIn: expiration.Second(), + }) +} + +func getRandomString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) } diff --git a/server.go b/server.go index 8f8ce1b..a6fbd66 100644 --- a/server.go +++ b/server.go @@ -16,23 +16,24 @@ limitations under the License. package main import ( - "crypto/md5" "crypto/tls" "crypto/x509" - "encoding/hex" + "encoding/base64" "fmt" "io/ioutil" "net" "net/http" - "net/http/httputil" "net/url" "path" + "runtime" "strings" "time" log "github.com/Sirupsen/logrus" + "github.com/armon/go-proxyproto" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oidc" + "github.com/elazarl/goproxy" "github.com/gin-gonic/gin" ) @@ -60,112 +61,191 @@ type reverseProxy interface { func init() { // step: ensure all time is in UTC time.LoadLocation("UTC") + // step: set the core + runtime.GOMAXPROCS(runtime.NumCPU()) } // // newProxy create's a new proxy from configuration // -func newProxy(cfg *Config) (*oauthProxy, error) { +func newProxy(config *Config) (*oauthProxy, error) { var err error // step: set the logging level - if cfg.LogJSONFormat { + if config.LogJSONFormat { log.SetFormatter(&log.JSONFormatter{}) } - if cfg.Verbose { + if config.Verbose { log.SetLevel(log.DebugLevel) } log.Infof("starting %s, author: %s, version: %s, ", prog, author, version) - service := &oauthProxy{config: cfg} + service := &oauthProxy{config: config} // step: parse the upstream endpoint - service.endpoint, err = url.Parse(cfg.Upstream) + service.endpoint, err = url.Parse(config.Upstream) if err != nil { return nil, err } // step: initialize the store if any - if cfg.StoreURL != "" { - if service.store, err = newStorage(cfg.StoreURL); err != nil { + if config.StoreURL != "" { + if service.store, err = createStorage(config.StoreURL); err != nil { return nil, err } } - // step: initialize the reverse http proxy - service.upstream, err = service.setupReverseProxy(service.endpoint) - if err != nil { - return nil, err - } - // step: initialize the openid client - if !cfg.SkipTokenVerification { - service.client, service.provider, err = initializeOpenID(cfg) + if !config.SkipTokenVerification { + service.client, service.provider, err = createOpenIDClient(config) if err != nil { return nil, err } } else { - log.Infof("TESTING ONLY CONFIG - the verification of the token have been disabled") + log.Warnf("TESTING ONLY CONFIG - the verification of the token have been disabled") } - // step: initialize the gin router - service.router = gin.New() - - // step: load the templates - if err = service.setupTemplates(); err != nil { - return nil, err + if config.ClientID == "" && config.ClientSecret == "" { + log.Warnf("Note: client credentials are not set, depending on provider (confidential|public) you might be able to auth") } - // step: setup the gin router and add router - if err := service.setupRouter(); err != nil { - return nil, err + + // step: + switch config.EnableForwarding { + case true: + log.Infof("enabled forwarding proxy mode") + if err := createForwardingProxy(config, service); err != nil { + return nil, err + } + default: + if err := createReverseProxy(config, service); err != nil { + return nil, err + } } + + return service, nil +} + +// +// createReverseProxy creates a reverse proxy +// +func createReverseProxy(config *Config, service *oauthProxy) error { + log.Infof("enabled reverse proxy mode, upstream url: %s", config.Upstream) + // step: display the protected resources - for _, resource := range cfg.Resources { + for _, resource := range config.Resources { log.Infof("protecting resources under uri: %s", resource) } - for name, value := range cfg.MatchClaims { + for name, value := range config.MatchClaims { log.Infof("the token must container the claim: %s, required: %s", name, value) } - return service, nil + // step: initialize the reverse http proxy + if err := service.createUpstream(service.endpoint); err != nil { + return err + } + + // step: setup the gin router and add router + if err := service.createEndpoints(); err != nil { + return err + } + + // step: load the templates + if err := service.createTemplates(); err != nil { + return err + } + + return nil +} + +// +// createForwardingProxy creates a forwarding proxy +// +func createForwardingProxy(config *Config, service *oauthProxy) error { + // step: initialize the reverse http proxy + if err := service.createUpstream(service.endpoint); err != nil { + return err + } + + gin.SetMode(gin.ReleaseMode) + // step: enable debugging in verbose more + if config.Verbose { + gin.SetMode(gin.DebugMode) + } + engine := gin.New() + + // step: default to release mode, only go debug on verbose logging + engine.Use(gin.Recovery()) + service.router = engine + + // step: are we logging the traffic? + if config.LogRequests { + engine.Use(service.loggingHandler()) + } + + engine.Use(service.forwardProxyHandler()) + + return nil } // // Run starts the proxy service // -func (r *oauthProxy) Run() error { +func (r *oauthProxy) Run() (err error) { tlsConfig := &tls.Config{} // step: are we doing mutual tls? if r.config.TLSCaCertificate != "" { - log.Infof("enabling mutual tls, reading in the ca: %s", r.config.TLSCaCertificate) - + log.Infof("enabling mutual tls, reading in the signing ca: %s", r.config.TLSCaCertificate) caCert, err := ioutil.ReadFile(r.config.TLSCaCertificate) if err != nil { return err } + caCertPool := x509.NewCertPool() caCertPool.AppendCertsFromPEM(caCert) - tlsConfig.ClientCAs = caCertPool tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } - go func() { - log.Infof("keycloak proxy service starting on %s", r.config.Listen) + server := &http.Server{ + Addr: r.config.Listen, + Handler: r.router, + } + + // step: create the listener + listener, err := net.Listen("tcp", r.config.Listen) + if err != nil { + return err + } - var err error - if r.config.TLSCertificate == "" { - err = r.router.Run(r.config.Listen) - } else { - server := &http.Server{ - Addr: r.config.Listen, - Handler: r.router, - TLSConfig: tlsConfig, + // step: wrap the listen in a proxy protocol + if r.config.EnableProxyProtocol { + log.Infof("enabling the proxy protocol on listener: %s", r.config.Listen) + listener = &proxyproto.Listener{listener} + } + + if r.config.TLSCertificate != "" { + server.TLSConfig = tlsConfig + + config := cloneTLSConfig(server.TLSConfig) + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + if len(config.Certificates) == 0 || r.config.TLSCertificate != "" || r.config.TLSPrivateKey != "" { + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(r.config.TLSCertificate, r.config.TLSPrivateKey) + if err != nil { + return err } - err = server.ListenAndServeTLS(r.config.TLSCertificate, r.config.TLSPrivateKey) } - if err != nil { + + listener = tls.NewListener(listener, config) + } + + go func() { + log.Infof("keycloak proxy service starting on %s", r.config.Listen) + if err = server.Serve(listener); err != nil { log.WithFields(log.Fields{ "error": err.Error(), }).Fatalf("failed to start the service") @@ -176,63 +256,19 @@ func (r *oauthProxy) Run() error { } // -// redirectToURL redirects the user and aborts the context -// -func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) { - cx.Redirect(http.StatusTemporaryRedirect, url) - cx.Abort() -} - -// -// accessForbidden redirects the user to the forbidden page -// -func (r *oauthProxy) accessForbidden(cx *gin.Context) { - if r.config.hasCustomForbiddenPage() { - cx.HTML(http.StatusForbidden, path.Base(r.config.ForbiddenPage), r.config.TagData) - cx.Abort() - return - } - - cx.AbortWithStatus(http.StatusForbidden) -} - -// -// redirectToAuthorization redirects the user to authorization handler +// createUpstream create a reverse http proxy from the upstream // -func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { - if r.config.NoRedirects { - cx.AbortWithStatus(http.StatusUnauthorized) - return - } - - // step: add a state referrer to the authorization page - authQuery := fmt.Sprintf("?state=%s", cx.Request.URL.String()) - - // step: if verification is switched off, we can't authorization - if r.config.SkipTokenVerification { - log.Errorf("refusing to redirection to authorization endpoint, skip token verification switched on") - - cx.AbortWithStatus(http.StatusForbidden) - return - } - - r.redirectToURL(oauthURL+authorizationURL+authQuery, cx) -} - -// -// setupReverseProxy create a reverse http proxy from the upstream -// -func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) { +func (r *oauthProxy) createUpstream(upstream *url.URL) error { // step: create the default dialer dialer := (&net.Dialer{ - KeepAlive: 10 * time.Second, - Timeout: 10 * time.Second, + KeepAlive: r.config.UpstreamKeepaliveTimeout, + Timeout: r.config.UpstreamTimeout, }).Dial // step: are we using a unix socket? if upstream.Scheme == "unix" { - log.Infof("using the unix domain socket: %s for upstream", upstream.Host) - socketPath := upstream.Host + log.Infof("using the unix domain socket: %s%s for upstream", upstream.Host, upstream.Path) + socketPath := fmt.Sprintf("%s%s", upstream.Host, upstream.Path) dialer = func(network, address string) (net.Conn, error) { return net.Dial("unix", socketPath) } @@ -240,11 +276,11 @@ func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) upstream.Host = "domain-sock" upstream.Scheme = "http" } - // step: create the reverse proxy - proxy := httputil.NewSingleHostReverseProxy(upstream) - // step: customize the http transport - proxy.Transport = &http.Transport{ + // step: create the forwarding proxy + proxy := goproxy.NewProxyHttpServer() + // step: update the tls configuration of the reverse proxy + proxy.Tr = &http.Transport{ Dial: dialer, TLSClientConfig: &tls.Config{ InsecureSkipVerify: r.config.SkipUpstreamTLSVerify, @@ -252,24 +288,35 @@ func (r *oauthProxy) setupReverseProxy(upstream *url.URL) (reverseProxy, error) DisableKeepAlives: !r.config.UpstreamKeepalives, } - return proxy, nil + r.upstream = proxy + + return nil } // -// setupRouter sets up the gin routing +// createEndpoints sets up the gin routing // -func (r oauthProxy) setupRouter() error { - r.router.Use(gin.Recovery()) +func (r *oauthProxy) createEndpoints() error { + gin.SetMode(gin.ReleaseMode) + if r.config.Verbose { + gin.SetMode(gin.DebugMode) + } + engine := gin.New() + engine.Use(gin.Recovery()) + // step: are we logging the traffic? if r.config.LogRequests { - r.router.Use(r.loggingHandler()) + engine.Use(r.loggingHandler()) } + // step: enabling the security filter? if r.config.EnableSecurityFilter { - r.router.Use(r.securityHandler()) + engine.Use(r.securityHandler()) } // step: add the routing - oauth := r.router.Group(oauthURL).Use(r.crossOriginResourceHandler(r.config.CrossOrigin)) + oauth := engine.Group(oauthURL).Use( + r.crossOriginResourceHandler(r.config.CrossOrigin), + ) { oauth.GET(authorizationURL, r.oauthAuthorizationHandler) oauth.GET(callbackURL, r.oauthCallbackHandler) @@ -280,15 +327,22 @@ func (r oauthProxy) setupRouter() error { oauth.POST(loginURL, r.loginHandler) } - r.router.Use(r.entryPointHandler(), r.authenticationHandler(), r.admissionHandler()) + engine.Use( + r.entryPointHandler(), + r.authenticationHandler(), + r.admissionHandler(), + r.upstreamHeadersHandler(r.config.AddClaims), + r.upstreamReverseProxyHandler()) + + r.router = engine return nil } // -// setupTemplates loads the custom template +// createTemplates loads the custom template // -func (r *oauthProxy) setupTemplates() error { +func (r *oauthProxy) createTemplates() error { var list []string if r.config.SignInPage != "" { @@ -354,7 +408,9 @@ func (r *oauthProxy) DeleteRefreshToken(token jose.JWT) error { return nil } +// // Close is used to close off any resources +// func (r *oauthProxy) CloseStore() error { if r.store != nil { return r.store.Close() @@ -363,7 +419,46 @@ func (r *oauthProxy) CloseStore() error { return nil } -func getHashKey(token *jose.JWT) string { - hash := md5.Sum([]byte(token.Encode())) - return hex.EncodeToString(hash[:]) +// +// accessForbidden redirects the user to the forbidden page +// +func (r *oauthProxy) accessForbidden(cx *gin.Context) { + if r.config.hasCustomForbiddenPage() { + cx.HTML(http.StatusForbidden, path.Base(r.config.ForbiddenPage), r.config.TagData) + cx.Abort() + return + } + + cx.AbortWithStatus(http.StatusForbidden) +} + +// +// redirectToURL redirects the user and aborts the context +// +func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) { + cx.Redirect(http.StatusTemporaryRedirect, url) + cx.Abort() +} + +// +// redirectToAuthorization redirects the user to authorization handler +// +func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) { + if r.config.NoRedirects { + cx.AbortWithStatus(http.StatusUnauthorized) + return + } + + // step: add a state referrer to the authorization page + authQuery := fmt.Sprintf("?state=%s", base64.StdEncoding.EncodeToString([]byte(cx.Request.URL.RequestURI()))) + + // step: if verification is switched off, we can't authorization + if r.config.SkipTokenVerification { + log.Errorf("refusing to redirection to authorization endpoint, skip token verification switched on") + + cx.AbortWithStatus(http.StatusForbidden) + return + } + + r.redirectToURL(oauthURL+authorizationURL+authQuery, cx) } diff --git a/server_test.go b/server_test.go index 3487901..09e77f2 100644 --- a/server_test.go +++ b/server_test.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httptest" "net/url" "testing" @@ -56,16 +57,16 @@ func newFakeKeycloakProxyWithResources(t *testing.T, resources []*Resource) *oau func newFakeKeycloakConfig(t *testing.T) *Config { return &Config{ - DiscoveryURL: "127.0.0.1:", + DiscoveryURL: "127.0.0.1:8080", ClientID: fakeClientID, ClientSecret: fakeSecret, EncryptionKey: "AgXa7xRcoClDEU0ZDSH4X0XhL5Qy2Z2j", SkipTokenVerification: true, Scopes: []string{}, EnableRefreshTokens: false, - SecureCookie: true, - CookieAccessName: cookieAccessToken, - CookieRefreshName: cookieRefreshToken, + SecureCookie: false, + CookieAccessName: "kc-access", + CookieRefreshName: "kc-state", Resources: []*Resource{ { URL: fakeAdminRoleURL, @@ -112,14 +113,41 @@ func newFakeKeycloakProxy(t *testing.T) *oauthProxy { upstream: new(fakeReverseProxy), endpoint: &url.URL{Host: "127.0.0.1"}, } - kc.router = gin.New() gin.SetMode(gin.ReleaseMode) + kc.router = gin.New() // step: add the gin routing - kc.setupRouter() + kc.createEndpoints() return kc } +func newTestProxyService(t *testing.T, config *Config) (*oauthProxy, *fakeOAuthServer, string) { + auth := newFakeOAuthServer(t) + if config == nil { + config = newFakeKeycloakConfig(t) + } + config.LogRequests = true + config.SkipTokenVerification = false + config.DiscoveryURL = auth.getLocation() + config.Verbose = false + log.SetOutput(ioutil.Discard) + + proxy, err := newProxy(config) + if err != nil { + t.Fatalf("failed to create proxy service, error: %s", err) + } + proxy.upstream = new(fakeReverseProxy) + service := httptest.NewServer(proxy.router) + config.RedirectionURL = service.URL + + proxy.client, proxy.provider, err = createOpenIDClient(config) + if err != nil { + t.Fatalf("failed to recreate the openid client, error: %s", err) + } + + return proxy, auth, service.URL +} + func TestNewKeycloakProxy(t *testing.T) { proxy, err := newProxy(newFakeKeycloakConfig(t)) assert.NoError(t, err) @@ -161,9 +189,9 @@ func TestInitializeReverseProxy(t *testing.T) { proxy := newFakeKeycloakProxy(t) uri, _ := url.Parse("http://127.0.0.1:8080") - reverse, err := proxy.setupReverseProxy(uri) + err := proxy.createUpstream(uri) assert.NoError(t, err) - assert.NotNil(t, reverse) + assert.NotNil(t, proxy.router) } func TestRedirectURL(t *testing.T) { diff --git a/session_test.go b/session_test.go index 2569b4e..a868eb9 100644 --- a/session_test.go +++ b/session_test.go @@ -90,7 +90,7 @@ func TestGetRefreshTokenFromCookie(t *testing.T) { { Cookies: []*http.Cookie{ { - Name: cookieRefreshToken, + Name: "kc-state", Path: "/", Domain: "127.0.0.1", Value: "refresh_token", diff --git a/stores.go b/stores.go index 616253a..1b56284 100644 --- a/stores.go +++ b/stores.go @@ -20,8 +20,8 @@ import ( "net/url" ) -// newStorage creates the store client for use -func newStorage(location string) (storage, error) { +// createStorage creates the store client for use +func createStorage(location string) (storage, error) { var store storage var err error diff --git a/tests/gen_token.go b/tests/gen_token.go index 8e2c4f2..a64c0ee 100644 --- a/tests/gen_token.go +++ b/tests/gen_token.go @@ -59,7 +59,7 @@ func main() { Usage: "a series of keypair claims which should be added to the token", }, } - app.Action = func(cx *cli.Context) { + app.Action = func(cx *cli.Context) error { header := jose.JOSEHeader{ "alg": "RS256", diff --git a/util_test.go b/util_test.go index dc1ef1f..e136dc5 100644 --- a/util_test.go +++ b/util_test.go @@ -28,6 +28,14 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCreateOpenIDClient(t *testing.T) { + client, _, err := createOpenIDClient(&Config{ + DiscoveryURL: newFakeOAuthServer(t).getLocation(), + }) + assert.NoError(t, err) + assert.NotNil(t, client) +} + func TestDecodeKeyPairs(t *testing.T) { testCases := []struct { List []string @@ -194,6 +202,20 @@ func TestContainedIn(t *testing.T) { assert.True(t, containedIn("1", []string{"1", "2", "3", "4"})) } +func TestDialAddress(t *testing.T) { + assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1")), "127.0.0.1:80") + assert.Equal(t, dialAddress(getFakeURL("https://127.0.0.1")), "127.0.0.1:443") + assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1:8080")), "127.0.0.1:8080") +} + +func TestIsUpgradedConnection(t *testing.T) { + header := http.Header{} + header.Add(headerUpgrade, "") + assert.False(t, isUpgradedConnection(&http.Request{Header: header})) + header.Set(headerUpgrade, "set") + assert.True(t, isUpgradedConnection(&http.Request{Header: header})) +} + func TestFileExists(t *testing.T) { if fileExists("no_such_file_exsit_32323232") { t.Errorf("we should have received false") @@ -209,20 +231,6 @@ func TestFileExists(t *testing.T) { } } -func TestDialAddress(t *testing.T) { - assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1")), "127.0.0.1:80") - assert.Equal(t, dialAddress(getFakeURL("https://127.0.0.1")), "127.0.0.1:443") - assert.Equal(t, dialAddress(getFakeURL("http://127.0.0.1:8080")), "127.0.0.1:8080") -} - -func TestIsUpgradedConnection(t *testing.T) { - header := http.Header{} - header.Add(headerUpgrade, "") - assert.False(t, isUpgradedConnection(&http.Request{Header: header})) - header.Set(headerUpgrade, "set") - assert.True(t, isUpgradedConnection(&http.Request{Header: header})) -} - func TestToHeader(t *testing.T) { cases := []struct { Word string diff --git a/utils.go b/utils.go index 7a8e085..e4421de 100644 --- a/utils.go +++ b/utils.go @@ -18,9 +18,11 @@ package main import ( "crypto/aes" "crypto/cipher" + "crypto/md5" "crypto/rand" "crypto/tls" "encoding/base64" + "encoding/hex" "fmt" "io" "net" @@ -35,6 +37,7 @@ import ( "unicode/utf8" log "github.com/Sirupsen/logrus" + "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oidc" "github.com/gin-gonic/gin" ) @@ -122,9 +125,9 @@ func decodeText(state, key string) (string, error) { return string(encoded), nil } -// initializeOpenID initializes the openID configuration, note: the redirection url is deliberately left blank +// createOpenIDClient initializes the openID configuration, note: the redirection url is deliberately left blank // in order to retrieve it from the host header on request -func initializeOpenID(cfg *Config) (*oidc.Client, oidc.ProviderConfig, error) { +func createOpenIDClient(cfg *Config) (*oidc.Client, oidc.ProviderConfig, error) { var err error var providerConfig oidc.ProviderConfig @@ -139,7 +142,7 @@ func initializeOpenID(cfg *Config) (*oidc.Client, oidc.ProviderConfig, error) { if err == nil { goto GOT_CONFIG } - log.Infof("failed to get provider configuration from discovery url: %s, %s", cfg.DiscoveryURL, err) + log.Warnf("failed to get provider configuration from discovery url: %s, %s", cfg.DiscoveryURL, err) time.Sleep(time.Second * 3) } @@ -183,25 +186,40 @@ func decodeKeyPairs(list []string) (map[string]string, error) { } // -// tryDialEndpoint dials the upstream endpoint via plain +// isValidMethod ensure this is a valid http method type // -func tryDialEndpoint(location *url.URL) (net.Conn, error) { - switch dialAddress := dialAddress(location); location.Scheme { - case "http": - return net.Dial("tcp", dialAddress) - default: - return tls.Dial("tcp", dialAddress, &tls.Config{ - Rand: rand.Reader, - InsecureSkipVerify: true, - }) - } +func isValidMethod(method string) bool { + return httpMethodRegex.MatchString(method) } // -// isValidMethod ensure this is a valid http method type +// cloneTLSConfig clones the tls configuration // -func isValidMethod(method string) bool { - return httpMethodRegex.MatchString(method) +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + SessionTicketsDisabled: cfg.SessionTicketsDisabled, + SessionTicketKey: cfg.SessionTicketKey, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } } // @@ -244,33 +262,18 @@ func containedIn(value string, list []string) bool { } // -// dialAddress extracts the dial address from the url -// -func dialAddress(location *url.URL) string { - items := strings.Split(location.Host, ":") - if len(items) != 2 { - switch location.Scheme { - case "http": - return location.Host + ":80" - default: - return location.Host + ":443" - } - } - - return location.Host -} - -// -// findCookie looks for a cookie in a list of cookies +// tryDialEndpoint dials the upstream endpoint via plain // -func findCookie(name string, cookies []*http.Cookie) *http.Cookie { - for _, cookie := range cookies { - if cookie.Name == name { - return cookie - } +func tryDialEndpoint(location *url.URL) (net.Conn, error) { + switch dialAddress := dialAddress(location); location.Scheme { + case "http": + return net.Dial("tcp", dialAddress) + default: + return tls.Dial("tcp", dialAddress, &tls.Config{ + Rand: rand.Reader, + InsecureSkipVerify: true, + }) } - - return nil } // @@ -330,6 +333,36 @@ func tryUpdateConnection(cx *gin.Context, endpoint *url.URL) error { return nil } +// +// dialAddress extracts the dial address from the url +// +func dialAddress(location *url.URL) string { + items := strings.Split(location.Host, ":") + if len(items) != 2 { + switch location.Scheme { + case "http": + return location.Host + ":80" + default: + return location.Host + ":443" + } + } + + return location.Host +} + +// +// findCookie looks for a cookie in a list of cookies +// +func findCookie(name string, cookies []*http.Cookie) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + + return nil +} + // // toHeader is a helper method to play nice in the headers // @@ -366,3 +399,11 @@ func mergeMaps(source, dest map[string]string) map[string]string { return dest } + +// +// getHashKey returns a hash of the encodes jwt token +// +func getHashKey(token *jose.JWT) string { + hash := md5.Sum([]byte(token.Encode())) + return hex.EncodeToString(hash[:]) +} -- GitLab