package main

import (
	"bufio"
	"crypto/tls"
	"encoding/binary"
	"flag"
	"fmt"
	"net"
)

// Protocol probe constants. The client sends ProtocolMagic OR'ed with the
// feature bits it wants. The same feature bits are tested against the first
// byte of the server's reply.
const (
	// ProtocolMagic is the fixed opening word of the client's probe. Its low
	// byte is zero so that feature flags can be OR'ed into it.
	ProtocolMagic uint32 = 0x42b33f00

	// ProtocolFeatureTls is the feature bit for TLS.
	ProtocolFeatureTls uint8 = 0x01

	// ProtocolFeatureCompression is the feature bit for compression.
	ProtocolFeatureCompression uint8 = 0x02

	// ProtocolDatastream is the protocol identifier the client lists as
	// supported.
	ProtocolDatastream uint32 = 0x02
)

// connection is a TCP client connection that can be upgraded to TLS in place.
// All reads and writes go through readWriter, which wraps socket until withTLS
// is called and tlsSocket afterwards.
type connection struct {
	hostname   string            // address given to makeConnection; also used as the TLS ServerName
	socket     net.Conn          // underlying TCP connection returned by net.Dial
	tlsSocket  *tls.Conn         // nil until withTLS is called
	readWriter *bufio.ReadWriter // buffered I/O over socket, or over tlsSocket after withTLS
	buffer     []byte            // 4-byte scratch space used by write to encode a uint32
}

// makeConnection dials a plain TCP connection to address:port and wraps it in
// a buffered reader/writer. address may be a host name, an IPv4 literal, or an
// IPv6 literal, with or without surrounding brackets. Panics if the dial fails.
func makeConnection(address string, port int) connection {
	host := address
	// Accept an already-bracketed IPv6 literal ("[::1]"), which the previous
	// "%s:%d" formatting required, without bracketing it twice below.
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	// net.JoinHostPort adds the brackets an IPv6 literal needs. Plain
	// "%s:%d" formatting would yield an unparseable "::1:4242".
	socket, err := net.Dial("tcp", net.JoinHostPort(host, fmt.Sprint(port)))
	if err != nil {
		panic(err.Error())
	}

	return connection{
		hostname: address,
		socket:   socket,
		readWriter: bufio.NewReadWriter(
			bufio.NewReader(socket),
			bufio.NewWriter(socket),
		),
		// Scratch space for write, which encodes one uint32 at a time.
		buffer: make([]byte, 4),
	}
}

// write encodes data as a big-endian uint32 and appends it to the buffered
// output. Nothing reaches the peer until flush is called or the buffer fills.
// Panics if the buffered write fails.
func (c *connection) write(data uint32) {
	scratch := c.buffer
	binary.BigEndian.PutUint32(scratch, data)
	if _, err := c.readWriter.Write(scratch); err != nil {
		panic(err.Error())
	}
}

// read blocks until exactly n bytes have been received and returns them.
// Panics if the stream fails or ends before n bytes arrive.
func (c *connection) read(n int) []byte {
	buffer := make([]byte, n)
	// A single Read may legitimately return fewer bytes than requested (for
	// example when a TCP segment boundary falls mid-message). Keep reading
	// until the buffer is full, as io.ReadFull does, instead of returning a
	// partially zero-filled slice.
	for filled := 0; filled < n; {
		got, err := c.readWriter.Read(buffer[filled:])
		filled += got
		if err != nil && filled < n {
			panic(err.Error())
		}
	}
	return buffer
}

// flush pushes everything buffered by write out to the peer, panicking on
// failure.
func (c *connection) flush() {
	if err := c.readWriter.Flush(); err != nil {
		panic(err.Error())
	}
}

// withTLS upgrades the connection in place by running a TLS client handshake
// over the existing socket, using hostname as the ServerName. When verify is
// false, InsecureSkipVerify is set and the server's certificate chain and host
// name are NOT checked. Afterwards all reads and writes go through the TLS
// connection. Panics on any error.
func (c *connection) withTLS(verify bool) {
	// Push any plaintext still held by the current buffered writer onto the
	// wire. Once readWriter is replaced below it would otherwise be lost.
	if err := c.readWriter.Flush(); err != nil {
		panic(err.Error())
	}
	// tls.Client reads straight from c.socket. Bytes the current buffered
	// reader already pulled off the socket would be silently dropped, and the
	// handshake would see a corrupted stream.
	if n := c.readWriter.Reader.Buffered(); n != 0 {
		panic(fmt.Sprintf("%d unread plaintext bytes buffered before TLS handshake", n))
	}

	config := &tls.Config{
		ServerName:         c.hostname,
		InsecureSkipVerify: !verify,
	}
	c.tlsSocket = tls.Client(c.socket, config)
	c.readWriter = bufio.NewReadWriter(
		bufio.NewReader(c.tlsSocket),
		bufio.NewWriter(c.tlsSocket),
	)
	if err := c.tlsSocket.Handshake(); err != nil {
		panic(err.Error())
	}
}

// tlsState returns a snapshot of the negotiated TLS parameters, including the
// peer certificate chain. It returns nil if withTLS has not been called.
func (c *connection) tlsState() *tls.ConnectionState {
	if c.tlsSocket == nil {
		return nil
	}
	state := c.tlsSocket.ConnectionState()
	return &state
}

// close flushes pending output and closes the connection. It is best-effort:
// errors are deliberately discarded because the connection is being torn down
// anyway.
func (c *connection) close() {
	_ = c.readWriter.Flush()
	// tlsSocket is only set by withTLS. Calling Close on a nil *tls.Conn
	// dereferences nil and panics, so only close it when TLS was negotiated.
	if c.tlsSocket != nil {
		_ = c.tlsSocket.Close()
	}
	// Closing the raw socket is harmless even if the TLS close above already
	// shut it down; any resulting error is ignored.
	_ = c.socket.Close()
}

// protocolInfo is the decoded form of the server's 4-byte probe reply.
type protocolInfo struct {
	flagTLS         bool   // ProtocolFeatureTls bit of byte 0 is set
	flagCompression bool   // ProtocolFeatureCompression bit of byte 0 is set
	data            uint16 // bytes 1-2, big-endian; not interpreted here
	version         uint8  // byte 3
}

// parseProtocolInfo decodes the server's probe reply. The layout is:
//
//	byte 0     feature flags
//	bytes 1-2  data (big-endian)
//	byte 3     version
//
// Bytes beyond the fourth are ignored. It panics with a descriptive message
// if fewer than 4 bytes are supplied.
func parseProtocolInfo(data []byte) protocolInfo {
	if len(data) < 4 {
		panic(fmt.Sprintf("protocol info: need 4 bytes, got %d", len(data)))
	}
	features := data[0]
	// Keyed fields keep this literal correct even if protocolInfo's fields
	// are ever reordered.
	return protocolInfo{
		flagTLS:         features&ProtocolFeatureTls != 0,
		flagCompression: features&ProtocolFeatureCompression != 0,
		data:            binary.BigEndian.Uint16(data[1:3]),
		version:         data[3],
	}
}

// main probes a server. It sends the protocol magic with the TLS feature bit
// set and lists the supported protocols. It then reads the server's 4-byte
// reply and upgrades to TLS if the server advertises it. Finally it prints the
// NotAfter time of every certificate the server presented, one per line, on
// stdout.
func main() {
	hostname := flag.String("hostname", "", "address of server to connect to")
	port := flag.Int("port", 4242, "port of server to connect to")
	flag.Parse()

	conn := makeConnection(*hostname, *port)
	// Deferred so the connection is also released when a later step panics.
	defer conn.close()

	conn.write(ProtocolMagic | uint32(ProtocolFeatureTls))
	supportedProtocols := []uint32{
		ProtocolDatastream,
	}
	for _, protocol := range supportedProtocols {
		conn.write(protocol)
	}
	// Sent after the protocol list, presumably as its terminator (high bit
	// set). NOTE(review): confirm against the server's protocol spec.
	const protocolListEnd uint32 = 1 << 31
	conn.write(protocolListEnd)
	conn.flush()

	// Named info rather than protocolInfo so the local does not shadow the
	// protocolInfo type.
	info := parseProtocolInfo(conn.read(4))
	if info.flagTLS {
		// NOTE(review): certificate verification is disabled, presumably so
		// that expired or self-signed certificates can still be inspected.
		// Confirm this is intended.
		conn.withTLS(false)
	}
	if state := conn.tlsState(); state != nil {
		for _, cert := range state.PeerCertificates {
			// fmt.Println writes to stdout. The builtin println wrote to
			// stderr and is not guaranteed to remain in the language.
			fmt.Println(cert.NotAfter.String())
		}
	}
}