Skip to content
Snippets Groups Projects
Select Git revision
  • main default protected
1 result

ssh.go

Blame
  • ssh.go 3.12 KiB
    package main
    
    import (
    	"bufio"
    	"fmt"
    	"golang.org/x/crypto/ssh"
    	"golang.org/x/sync/errgroup"
    	"io"
    	"log"
    	"os"
    	"path"
    	"time"
    )
    
    // SshConnection describes how to reach and log in to an SSH server.
    // openSshClient dials Hostname:Port over TCP and authenticates with the
    // Username/Password pair.
    type SshConnection struct {
    	// Username and Password are the credentials for password authentication;
    	// Hostname is the server address to dial.
    	Username, Password, Hostname string
    	// Port is the TCP port of the SSH server.
    	Port int
    	// Timeout is passed through unchanged as ssh.ClientConfig.Timeout.
    	Timeout time.Duration
    }
    
    // openSshClient dials connection.Hostname:connection.Port over TCP and
    // authenticates with connection.Username / connection.Password.
    //
    // The dial is retried up to 20 times with a 5 second pause between attempts
    // (presumably to wait for a freshly started host to accept SSH — confirm).
    // No pause follows the final attempt. If every attempt fails, the returned
    // error wraps the error from the last attempt so callers can see the cause.
    func openSshClient(connection *SshConnection) (*ssh.Client, error) {
    	const (
    		maxAttempts = 20
    		retryDelay  = 5 * time.Second
    	)

    	conf := &ssh.ClientConfig{
    		User: connection.Username,
    		// SECURITY(review): InsecureIgnoreHostKey accepts any host key, which
    		// leaves the connection open to man-in-the-middle attacks. Only
    		// acceptable for throwaway/test hosts; prefer ssh.FixedHostKey or the
    		// golang.org/x/crypto/ssh/knownhosts package.
    		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    		Auth: []ssh.AuthMethod{
    			ssh.Password(connection.Password),
    		},
    		Timeout: connection.Timeout,
    	}
    	addr := fmt.Sprintf("%s:%d", connection.Hostname, connection.Port)

    	log.Println("opening ssh connection")
    	var lastErr error
    	for attempt := 1; attempt <= maxAttempts; attempt++ {
    		// A failed ssh.Dial returns a nil client, so there is nothing to close
    		// between attempts.
    		conn, err := ssh.Dial("tcp", addr, conf)
    		if err == nil {
    			return conn, nil
    		}
    		lastErr = err
    		if attempt == maxAttempts {
    			break
    		}
    		log.Println("waiting on ssh session")
    		time.Sleep(retryDelay)
    	}
    	return nil, fmt.Errorf("could not connect to server via ssh after %d attempts: %w", maxAttempts, lastErr)
    }
    
    // openSshSession opens a new session on conn and wires the session's stdout
    // and stderr to the log via connectShellOutputs.
    //
    // If wiring the outputs fails, the freshly opened session is closed before
    // returning so it does not leak a channel on conn.
    func openSshSession(conn *ssh.Client) (*ssh.Session, error) {
    	log.Println("opening ssh session")
    	session, err := conn.NewSession()
    	if err != nil {
    		return nil, fmt.Errorf("error opening ssh session %w", err)
    	}

    	log.Println("connecting ssh outputs to stdout")
    	if err := connectShellOutputs(session); err != nil {
    		// Best-effort cleanup: the wiring error is the one worth reporting.
    		_ = session.Close()
    		return nil, fmt.Errorf("error connecting ssh outputs to stdout %w", err)
    	}

    	return session, nil
    }
    
    // copyFile uploads the local file file.Source to file.Target on the remote
    // host. It runs `/usr/bin/scp -t <dir>` on session (scp's receiving mode)
    // and, concurrently, writes one "C" record (mode, size, name), the file
    // contents and a terminating NUL byte to the command's stdin.
    //
    // session.Run is called on session, so the session cannot be reused for
    // another command afterwards. The returned error is the first failure from
    // either the sending or the receiving side.
    //
    // NOTE(review): the "C0%d" header is only correct if file.Mode holds the
    // octal permission digits written as a decimal number (e.g. 644). If
    // CopyFile.Mode is an os.FileMode, it should be formatted as "C%04o" —
    // confirm against the CopyFile definition.
    func copyFile(session *ssh.Session, file CopyFile) error {
    	handle, err := os.Open(file.Source)
    	if err != nil {
    		return err
    	}
    	defer handle.Close()

    	stat, err := handle.Stat()
    	if err != nil {
    		return err
    	}
    	// The C record announces an exact byte count; only regular files have a
    	// size that matches what io.Copy will stream.
    	if !stat.Mode().IsRegular() {
    		return fmt.Errorf("error sending file: %s is not a regular file", file.Source)
    	}

    	filedir, filename := path.Split(file.Target)
    	if filedir == "" {
    		// A bare file name means "relative to the remote working directory";
    		// scp -t rejects an empty target argument.
    		filedir = "."
    	}

    	// StdinPipe must be obtained before the command starts (it fails once the
    	// session is running), so it is requested here rather than inside the
    	// writer goroutine, where it would race with session.Run.
    	stdin, err := session.StdinPipe()
    	if err != nil {
    		return fmt.Errorf("error opening ssh stdin: %w", err)
    	}

    	group := errgroup.Group{}
    	group.Go(func() error {
    		// Closing stdin signals EOF to the remote scp once the record is sent.
    		defer stdin.Close()
    		if _, err := fmt.Fprintf(stdin, "C0%d %d %s\n", file.Mode, stat.Size(), filename); err != nil {
    			return fmt.Errorf("error sending file: %w", err)
    		}
    		if _, err := io.Copy(stdin, handle); err != nil {
    			return fmt.Errorf("error sending file: %w", err)
    		}
    		if _, err := fmt.Fprint(stdin, "\x00"); err != nil {
    			return fmt.Errorf("error sending file: %w", err)
    		}
    		return nil
    	})
    	group.Go(func() error {
    		if err := session.Run(fmt.Sprintf("/usr/bin/scp -t %s", filedir)); err != nil {
    			return fmt.Errorf("error receiving file: %w", err)
    		}
    		return nil
    	})
    	return group.Wait()
    }
    
    // connectReader logs every line read from reader as "<label>: <line>" until
    // reader is exhausted. It blocks, so callers run it in its own goroutine.
    //
    // If scanning stops on an error (for example bufio.ErrTooLong for a line
    // longer than the scanner's 64 KiB buffer), the error is logged and the rest
    // of the stream is read and discarded rather than abandoned. For ssh.Session
    // output pipes, an unread pipe can eventually block the remote command.
    func connectReader(label string, reader io.Reader) {
    	scanner := bufio.NewScanner(reader)
    	for scanner.Scan() {
    		// Text returns a fresh string, so no manual copy of Bytes is needed.
    		log.Printf("%s: %s\n", label, scanner.Text())
    	}
    	if err := scanner.Err(); err != nil {
    		log.Printf("%s error: %s\n", label, err)
    		// Keep draining so the writer on the other end is never blocked.
    		if _, err := io.Copy(io.Discard, reader); err != nil {
    			log.Printf("%s error: %s\n", label, err)
    		}
    	}
    }
    
    // connectShellOutputs forwards the session's stdout and stderr to the log,
    // each through connectReader in its own goroutine. Both pipes are requested
    // before any goroutine starts, so a failure leaves no reader running; the
    // pipe error is returned unwrapped.
    func connectShellOutputs(session *ssh.Session) error {
    	outPipe, err := session.StdoutPipe()
    	if err != nil {
    		return err
    	}
    	errPipe, err := session.StderrPipe()
    	if err != nil {
    		return err
    	}

    	streams := []struct {
    		label string
    		src   io.Reader
    	}{
    		{label: "ssh stdout", src: outPipe},
    		{label: "ssh stderr", src: errPipe},
    	}
    	for _, s := range streams {
    		// Arguments are evaluated at the go statement, so s is not captured.
    		go connectReader(s.label, s.src)
    	}
    	return nil
    }