Select Git revision
-
Janne Mareike Koschinski authored
ssh.go 3.12 KiB
package main
import (
	"bufio"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"path"
	"strconv"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/sync/errgroup"
)
// SshConnection describes how to reach an SSH server that accepts password
// authentication.
type SshConnection struct {
	// Username and Password are the login credentials; Hostname is the host
	// name or address of the server.
	Username, Password, Hostname string
	// Port is the TCP port the SSH server listens on.
	Port int
	// Timeout is handed to ssh.ClientConfig.Timeout and therefore limits each
	// individual dial attempt.
	Timeout time.Duration
}
// openSshClient dials the SSH server described by connection using password
// authentication. It makes up to 20 attempts, pausing 5 seconds between
// failed attempts, and returns the first client that connects. If every
// attempt fails, the error from the last attempt is wrapped and returned.
//
// NOTE(review): host keys are not verified (ssh.InsecureIgnoreHostKey), so
// the connection is open to man-in-the-middle attacks. Confirm this is only
// used against trusted or throwaway hosts.
func openSshClient(connection *SshConnection) (*ssh.Client, error) {
	const (
		maxAttempts = 20
		retryDelay  = 5 * time.Second
	)
	conf := &ssh.ClientConfig{
		User:            connection.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth: []ssh.AuthMethod{
			ssh.Password(connection.Password),
		},
		Timeout: connection.Timeout,
	}
	// JoinHostPort brackets IPv6 literals ("[::1]:22"); a plain "%s:%d" would
	// produce an unparsable address for them.
	addr := net.JoinHostPort(connection.Hostname, strconv.Itoa(connection.Port))

	log.Println("opening ssh connection")
	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		// ssh.Dial returns a nil client on error, so there is nothing to close
		// between attempts.
		conn, err := ssh.Dial("tcp", addr, conf)
		if err == nil {
			return conn, nil
		}
		lastErr = err
		if attempt == maxAttempts {
			// No point sleeping when no further attempt will be made.
			break
		}
		log.Println("waiting on ssh session")
		time.Sleep(retryDelay)
	}
	return nil, fmt.Errorf("could not connect to server via ssh: %w", lastErr)
}
// openSshSession opens a new session on conn and hooks the session's stdout
// and stderr up to connectShellOutputs, which logs them line by line.
// If the output forwarding cannot be set up, the freshly opened session is
// closed before the error is returned, so it is not leaked.
func openSshSession(conn *ssh.Client) (*ssh.Session, error) {
	log.Println("opening ssh session")
	session, err := conn.NewSession()
	if err != nil {
		return nil, fmt.Errorf("error opening ssh session: %w", err)
	}
	log.Println("connecting ssh outputs to stdout")
	if err := connectShellOutputs(session); err != nil {
		// The caller never receives the session, so release it here; the
		// close error is secondary to the one being reported.
		_ = session.Close()
		return nil, fmt.Errorf("error connecting ssh outputs to stdout: %w", err)
	}
	return session, nil
}
// copyFile uploads the local file file.Source to file.Target on the remote
// host over session, speaking the scp "sink" protocol. It runs
// "/usr/bin/scp -t <dir>" remotely and writes three things to its stdin: one
// "C" record (mode, size, name), the file contents, and a terminating NUL byte.
//
// session.Run is called, so the session cannot be reused afterwards. The
// sending and receiving sides run concurrently in an errgroup, and the first
// error from either is returned.
func copyFile(session *ssh.Session, file CopyFile) error {
	handle, err := os.Open(file.Source)
	if err != nil {
		return err
	}
	defer handle.Close()
	stat, err := handle.Stat()
	if err != nil {
		return err
	}
	filedir, filename := path.Split(file.Target)
	if filedir == "" {
		// A bare file name has no directory part, and "scp -t" with no target
		// argument fails, so address the remote working directory explicitly.
		filedir = "."
	}
	// StdinPipe must be requested before session.Run starts the remote
	// command. When it was requested from the writer goroutine, it raced with
	// Run. If it lost, the ignored error left stdin nil and the deferred Close
	// panicked.
	stdin, err := session.StdinPipe()
	if err != nil {
		return fmt.Errorf("error sending file: %w", err)
	}

	var group errgroup.Group
	group.Go(func() error {
		// Closing stdin delivers EOF to the remote scp so it can finish.
		defer stdin.Close()
		// NOTE(review): "C0%d" only yields a valid octal scp mode if file.Mode
		// holds the permission digits written in decimal (e.g. 644). If it is
		// an os.FileMode (0644 == 420), the wrong mode is sent. Confirm
		// against CopyFile.
		if _, err := fmt.Fprintf(stdin, "C0%d %d %s\n", file.Mode, stat.Size(), filename); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		if _, err := io.Copy(stdin, handle); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		// A single NUL byte tells scp the file data is complete.
		if _, err := fmt.Fprint(stdin, "\x00"); err != nil {
			return fmt.Errorf("error sending file: %w", err)
		}
		return nil
	})
	group.Go(func() error {
		// NOTE(review): filedir is interpolated into the remote shell command
		// unquoted. Paths containing spaces or shell metacharacters will break
		// or be reinterpreted.
		if err := session.Run(fmt.Sprintf("/usr/bin/scp -t %s", filedir)); err != nil {
			return fmt.Errorf("error receiving file: %w", err)
		}
		return nil
	})
	return group.Wait()
}
// connectReader logs every line read from reader as "<label>: <line>" until
// reader is exhausted.
//
// If scanning stops with an error, the error is logged as
// "<label> error: <err>" and the rest of reader is discarded. A line longer
// than bufio.MaxScanTokenSize is one cause of such an error. Draining matters
// because reader is typically an ssh session pipe, and an unserviced pipe can
// block the remote command.
func connectReader(label string, reader io.Reader) {
	scanner := bufio.NewScanner(reader)
	for scanner.Scan() {
		// Text returns a fresh string, so the scanner may reuse its buffer.
		log.Printf("%s: %s\n", label, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		log.Printf("%s error: %s\n", label, err)
		// Keep consuming the stream so the writer is never left blocked; any
		// further read error is irrelevant at this point.
		_, _ = io.Copy(io.Discard, reader)
	}
}
// connectShellOutputs grabs the session's stdout and stderr pipes, in that
// order, and spawns one connectReader goroutine per pipe to log its lines.
// It returns the first error from obtaining a pipe, in which case no
// goroutine is started.
func connectShellOutputs(session *ssh.Session) error {
	out, err := session.StdoutPipe()
	if err != nil {
		return err
	}
	errOut, err := session.StderrPipe()
	if err != nil {
		return err
	}
	streams := []struct {
		label string
		src   io.Reader
	}{
		{"ssh stdout", out},
		{"ssh stderr", errOut},
	}
	for _, s := range streams {
		// Arguments of a go statement are evaluated immediately, so each
		// goroutine gets its own label/reader pair.
		go connectReader(s.label, s.src)
	}
	return nil
}