package main import ( "bufio" "fmt" "golang.org/x/crypto/ssh" "golang.org/x/sync/errgroup" "io" "log" "os" "path" "time" ) type SshConnection struct { Username string Password string Hostname string Port int Timeout time.Duration } func openSshClient(connection *SshConnection) (*ssh.Client, error) { conf := &ssh.ClientConfig{ User: connection.Username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Auth: []ssh.AuthMethod{ ssh.Password(connection.Password), }, Timeout: connection.Timeout, } log.Println("opening ssh connection") var conn *ssh.Client for attempt := 0; attempt < 20; attempt++ { if conn != nil { if err := conn.Close(); err != nil { return nil, fmt.Errorf("could not connect to server via ssh %w", err) } } conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", connection.Hostname, connection.Port), conf) if err == nil { return conn, nil } log.Println("waiting on ssh session") time.Sleep(5 * time.Second) } return nil, fmt.Errorf("could not connect to server via ssh") } func openSshSession(conn *ssh.Client) (*ssh.Session, error) { log.Println("opening ssh session") var session *ssh.Session session, err := conn.NewSession() if err != nil { return nil, fmt.Errorf("error opening ssh session %w", err) } log.Println("connecting ssh outputs to stdout") if err := connectShellOutputs(session); err != nil { return nil, fmt.Errorf("error connecting ssh outputs to stdout %w", err) } return session, nil } func copyFile(session *ssh.Session, file CopyFile) error { handle, err := os.Open(file.Source) if err != nil { return err } defer handle.Close() stat, err := handle.Stat() if err != nil { return err } filedir, filename := path.Split(file.Target) group := errgroup.Group{} group.Go(func() error { stdin, _ := session.StdinPipe() defer stdin.Close() if _, err := fmt.Fprintf(stdin, "C0%d %d %s\n", file.Mode, stat.Size(), filename); err != nil { return fmt.Errorf("error sending file: %w", err) } if _, err := io.Copy(stdin, handle); err != nil { return fmt.Errorf("error sending file: %w", err) } if _, err := fmt.Fprint(stdin, "\x00"); err != nil { return fmt.Errorf("error sending file: %w", err) } return nil }) group.Go(func() error { if err := session.Run(fmt.Sprintf("/usr/bin/scp -t %s", filedir)); err != nil { return fmt.Errorf("error receiving file: %w", err) } return nil }) return group.Wait() } func connectReader(label string, reader io.Reader) { scanner := bufio.NewScanner(reader) for { if tkn := scanner.Scan(); tkn { rcv := scanner.Bytes() raw := make([]byte, len(rcv)) copy(raw, rcv) log.Printf("%s: %s\n", label, string(raw)) } else { if scanner.Err() != nil { log.Printf("%s error: %s\n", label, scanner.Err()) } return } } } func connectShellOutputs(session *ssh.Session) error { var stdout, stderr io.Reader var err error stdout, err = session.StdoutPipe() if err != nil { return err } stderr, err = session.StderrPipe() if err != nil { return err } go connectReader("ssh stdout", stdout) go connectReader("ssh stderr", stderr) return nil }