From fdec9939c892eb3758a9932601fbdb3f228b8250 Mon Sep 17 00:00:00 2001
From: Rohith <gambol99@gmail.com>
Date: Sun, 7 Jan 2018 12:06:43 +0000
Subject: [PATCH] HTTP Hijack Check

Ensure the http.Hijack method is implemented by the http.ResponseWriter before attempting to grab the client connection
---
 utils.go | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/utils.go b/utils.go
index ac1b5ab..e90a78a 100644
--- a/utils.go
+++ b/utils.go
@@ -261,32 +261,34 @@ func transferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, er
 // tryUpdateConnection attempt to upgrade the connection to a http pdy stream
 func tryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint *url.URL) error {
 	// step: dial the endpoint
-	tlsConn, err := tryDialEndpoint(endpoint)
+	server, err := tryDialEndpoint(endpoint)
 	if err != nil {
 		return err
 	}
-	defer tlsConn.Close()
+	defer server.Close()
 
-	// step: we need to hijack the underlining client connection
-	clientConn, ok, err := writer.(http.Hijacker).Hijack()
-	if !ok {
+	// @check the the response writer implements the Hijack method
+	if _, ok := writer.(http.Hijacker); !ok {
 		return errors.New("writer does not implement http.Hijacker method")
 	}
+
+	// @step: get the client connection object
+	client, _, err := writer.(http.Hijacker).Hijack()
 	if err != nil {
 		return err
 	}
-	defer clientConn.Close()
+	defer client.Close()
 
 	// step: write the request to upstream
-	if err = req.Write(tlsConn); err != nil {
+	if err = req.Write(server); err != nil {
 		return err
 	}
 
-	// step: copy the date between client and upstream endpoint
+	// @step: copy the data between client and upstream endpoint
 	var wg sync.WaitGroup
 	wg.Add(2)
-	go transferBytes(tlsConn, clientConn, &wg)
-	go transferBytes(clientConn, tlsConn, &wg)
+	go transferBytes(server, client, &wg)
+	go transferBytes(client, server, &wg)
 	wg.Wait()
 
 	return nil
-- 
GitLab