package main

import (
	"bufio"
	"fmt"
	"golang.org/x/crypto/ssh"
	"golang.org/x/sync/errgroup"
	"io"
	"log"
	"os"
	"path"
	"time"
)

// SshConnection holds the parameters openSshClient uses to dial an SSH server
// over TCP and authenticate with a password.
type SshConnection struct {
	// Username is the remote account name (passed as ssh.ClientConfig.User).
	Username string
	// Password is the secret used for ssh.Password authentication.
	Password string
	// Hostname is the host name or address to dial; it is combined with Port
	// as "Hostname:Port".
	Hostname string
	// Port is the TCP port of the SSH server.
	Port     int
	// Timeout is passed as ssh.ClientConfig.Timeout for every dial attempt.
	Timeout  time.Duration
}

// openSshClient dials connection.Hostname:connection.Port over TCP and
// authenticates as connection.Username with password auth, retrying while the
// server is not yet reachable (e.g. a freshly booted machine).
//
// Up to `attempts` dials are made, `retryDelay` apart. On success the connected
// client is returned and the caller owns it (must Close it). If every attempt
// fails, the error from the final attempt is wrapped and returned.
//
// SECURITY NOTE(review): host keys are not verified (ssh.InsecureIgnoreHostKey),
// which leaves the connection open to man-in-the-middle attacks. This is only
// acceptable against hosts on a trusted network.
func openSshClient(connection *SshConnection) (*ssh.Client, error) {
	const (
		attempts   = 20
		retryDelay = 5 * time.Second
	)

	conf := &ssh.ClientConfig{
		User:            connection.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth: []ssh.AuthMethod{
			ssh.Password(connection.Password),
		},
		Timeout: connection.Timeout,
	}
	addr := fmt.Sprintf("%s:%d", connection.Hostname, connection.Port)

	log.Println("opening ssh connection")
	var lastErr error
	for attempt := 1; attempt <= attempts; attempt++ {
		// A failed ssh.Dial returns a nil client, so there is nothing to close
		// between attempts.
		client, err := ssh.Dial("tcp", addr, conf)
		if err == nil {
			return client, nil
		}
		lastErr = err
		if attempt == attempts {
			// No point sleeping after the final attempt.
			break
		}
		log.Println("waiting on ssh session")
		time.Sleep(retryDelay)
	}
	return nil, fmt.Errorf("could not connect to server via ssh: %w", lastErr)
}

func openSshSession(conn *ssh.Client) (*ssh.Session, error) {
	log.Println("opening ssh session")
	var session *ssh.Session
	session, err := conn.NewSession()
	if err != nil {
		return nil, fmt.Errorf("error opening ssh session %w", err)
	}

	log.Println("connecting ssh outputs to stdout")
	if err := connectShellOutputs(session); err != nil {
		return nil, fmt.Errorf("error connecting ssh outputs to stdout %w", err)
	}

	return session, nil
}

func copyFile(session *ssh.Session, file CopyFile) error {
	handle, err := os.Open(file.Source)
	if err != nil {
		return err
	}
	defer handle.Close()

	stat, err := handle.Stat()
	if err != nil {
		return err
	}

	filedir, filename := path.Split(file.Target)

	group := errgroup.Group{}
	group.Go(func() error {
		stdin, _ := session.StdinPipe()
		defer stdin.Close()
		if _, err := fmt.Fprintf(stdin, "C0%d %d %s\n", file.Mode, stat.Size(), filename); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		if _, err := io.Copy(stdin, handle); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		if _, err := fmt.Fprint(stdin, "\x00"); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		return nil
	})
	group.Go(func() error {
		if err := session.Run(fmt.Sprintf("/usr/bin/scp -t %s", filedir)); err != nil {
			return fmt.Errorf("error receiving file: %w", err)
		}
		return nil
	})
	return group.Wait()
}

func connectReader(label string, reader io.Reader) {
	scanner := bufio.NewScanner(reader)
	for {
		if tkn := scanner.Scan(); tkn {
			rcv := scanner.Bytes()
			raw := make([]byte, len(rcv))
			copy(raw, rcv)
			log.Printf("%s: %s\n", label, string(raw))
		} else {
			if scanner.Err() != nil {
				log.Printf("%s error: %s\n", label, scanner.Err())
			}
			return
		}
	}
}

func connectShellOutputs(session *ssh.Session) error {
	var stdout, stderr io.Reader
	var err error

	stdout, err = session.StdoutPipe()
	if err != nil {
		return err
	}

	stderr, err = session.StderrPipe()
	if err != nil {
		return err
	}

	go connectReader("ssh stdout", stdout)
	go connectReader("ssh stderr", stderr)

	return nil
}