From 9c6eca4bda9496059f47edd8edf2ae281dbaa8c7 Mon Sep 17 00:00:00 2001 From: Janne Mareike Koschinski <janne@kuschku.de> Date: Fri, 28 Feb 2025 12:55:24 +0100 Subject: [PATCH] feat: token refresh --- api/joinroom.go | 4 +-- api/markread.go | 4 +-- api/refresh.go | 37 +++++++++++++++++++++++++ api/sendmessage.go | 4 +-- api/setpusher.go | 4 +-- api/token.go | 9 +++++++ main.go | 28 +++++++++++-------- matrixbot.go | 67 ++++++++++++++++++++++++++++++++++------------ 8 files changed, 121 insertions(+), 36 deletions(-) create mode 100644 api/refresh.go create mode 100644 api/token.go diff --git a/api/joinroom.go b/api/joinroom.go index 2767b39..5826d78 100644 --- a/api/joinroom.go +++ b/api/joinroom.go @@ -5,7 +5,7 @@ import ( "net/http" ) -func JoinRoom(userData LoginResponse, roomId string) error { +func JoinRoom(token Token, roomId string) error { request, err := http.NewRequest( http.MethodPost, fmt.Sprintf("https://matrix-client.matrix.org/_matrix/client/v3/rooms/%s/join", roomId), @@ -14,7 +14,7 @@ func JoinRoom(userData LoginResponse, roomId string) error { if err != nil { panic(err) } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", userData.AccessToken)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) _, err = http.DefaultClient.Do(request) return err } diff --git a/api/markread.go b/api/markread.go index 2a522b1..7bbdaa7 100644 --- a/api/markread.go +++ b/api/markread.go @@ -15,7 +15,7 @@ type ReadReceiptRequest struct { ReadPrivate string `json:"m.read.private"` } -func SetReadReceipt(userData LoginResponse, roomId string, messageId string) error { +func SetReadReceipt(token Token, roomId string, messageId string) error { body, err := json.Marshal(ReadReceiptRequest{ FullyRead: messageId, Read: messageId, @@ -32,7 +32,7 @@ func SetReadReceipt(userData LoginResponse, roomId string, messageId string) err if err != nil { panic(err) } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", userData.AccessToken)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) response, err := http.DefaultClient.Do(request) if err != nil { return err diff --git a/api/refresh.go b/api/refresh.go new file mode 100644 index 0000000..1dcc5b7 --- /dev/null +++ b/api/refresh.go @@ -0,0 +1,37 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" +) + +type RefreshRequest struct { + RefreshToken string `json:"refresh_token"` +} + +type RefreshResponse struct { + AccessToken string `json:"access_token"` + ExpiresInMs int `json:"expires_in_ms"` + RefreshToken string `json:"refresh_token"` +} + +func Refresh(refreshToken string) (RefreshResponse, error) { + body, err := json.Marshal(RefreshRequest{ + RefreshToken: refreshToken, + }) + if err != nil { + return RefreshResponse{}, err + } + resp, err := http.Post("https://matrix-client.matrix.org/_matrix/client/v3/refresh", "application/json", bytes.NewReader(body)) + if err != nil { + return RefreshResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return RefreshResponse{}, fmt.Errorf("request failed %d: %s", resp.StatusCode, resp.Status) + } + var responseBody RefreshResponse + err = json.NewDecoder(resp.Body).Decode(&responseBody) + return responseBody, err +} diff --git a/api/sendmessage.go b/api/sendmessage.go index 761d73d..70e480d 100644 --- a/api/sendmessage.go +++ b/api/sendmessage.go @@ -8,7 +8,7 @@ import ( "net/http" ) -func SendMessage(userData LoginResponse, roomId string, content interface{}) error { +func SendMessage(token Token, roomId string, content interface{}) error { transactionId, err := uuid.NewRandom() if err != nil { return err @@ -29,7 +29,7 @@ func SendMessage(userData LoginResponse, roomId string, content interface{}) err if err != nil { panic(err) } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", userData.AccessToken)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) response, err := http.DefaultClient.Do(request) if err != nil { return err diff --git a/api/setpusher.go b/api/setpusher.go index 2766037..c223f86 100644 --- a/api/setpusher.go +++ b/api/setpusher.go @@ -26,7 +26,7 @@ type PusherData struct { Url string `json:"url"` } -func SetPusher(userData LoginResponse, url string) error { +func SetPusher(token Token, url string) error { body, err := json.Marshal(PusherRequest{ AppDisplayName: "webhook", AppId: "de.justjanne.webhook", @@ -52,7 +52,7 @@ func SetPusher(userData LoginResponse, url string) error { if err != nil { panic(err) } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", userData.AccessToken)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) response, err := http.DefaultClient.Do(request) if err != nil { return err diff --git a/api/token.go b/api/token.go new file mode 100644 index 0000000..fa9c559 --- /dev/null +++ b/api/token.go @@ -0,0 +1,9 @@ +package api + +import "time" + +type Token struct { + AccessToken string + Expires time.Time + RefreshToken string +} diff --git a/main.go b/main.go index 1e1d5bb..80556d9 100644 --- a/main.go +++ b/main.go @@ -13,15 +13,11 @@ import ( ) func main() { - bot, err := NewMatrixBot( - os.Getenv("BOT_USERNAME"), - os.Getenv("BOT_PASSWORD"), - os.Getenv("BOT_DEVICEID"), - os.Getenv("BOT_PUSHURL"), - ) - if err != nil { - panic(err) - } + var err error + + bot := NewMatrixBot() + + // !8ball handler bot.HandleFunc("!8ball", func(bot *MatrixBot, notification api.Notification) error { positive := []string{ "It is certain", @@ -64,12 +60,14 @@ func main() { } answer := answers[rand.IntN(len(answers))] - err = api.SendMessage(bot.userData, notification.RoomId, api.MessageContent{ + err = api.SendMessage(*bot.token, notification.RoomId, api.MessageContent{ Body: answer, MsgType: "m.text", }) return nil }) + + // !trains handler bahnTpl, err := template.New("bahn").Parse(`{{- /*gotype: bahn.Timetable*/ -}} <b>{{.Station}}</b> Abfahrten {{- range .Stops -}} @@ -103,12 +101,20 @@ func main() { if err != nil { return err } - err = api.SendMessage(bot.userData, notification.RoomId, api.MessageContent{ + err = api.SendMessage(*bot.token, notification.RoomId, api.MessageContent{ FormattedBody: buf.String(), Format: "org.matrix.custom.html", MsgType: "m.text", }) return nil }) + err = bot.Login(os.Getenv("BOT_USERNAME"), os.Getenv("BOT_PASSWORD"), os.Getenv("BOT_DEVICEID")) + if err != nil { + panic(err) + } + err = bot.RegisterPusher(os.Getenv("BOT_PUSHURL")) + if err != nil { + panic(err) + } bot.Serve(":8080") } diff --git a/matrixbot.go b/matrixbot.go index 6538da0..1758061 100644 --- a/matrixbot.go +++ b/matrixbot.go @@ -1,36 +1,69 @@ package main import ( + "fmt" "git.kuschku.de/justjanne/stateless-matrix-bot/api" "io" "log" "net/http" "strings" + "time" ) type MatrixBot struct { - userData api.LoginResponse + token *api.Token handlers map[string]func(bot *MatrixBot, notification api.Notification) error } -func NewMatrixBot( - username string, - password string, - deviceId string, - url string, -) (*MatrixBot, error) { - userData, err := api.Login(username, password, deviceId) +func NewMatrixBot() *MatrixBot { + return &MatrixBot{ + token: nil, + handlers: make(map[string]func(bot *MatrixBot, notification api.Notification) error), + } +} + +func (bot *MatrixBot) RefreshToken() error { + if bot.token == nil { + return fmt.Errorf("no refresh token available") + } + userData, err := api.Refresh(bot.token.RefreshToken) if err != nil { - return nil, err + return err + } + bot.token = &api.Token{ + AccessToken: userData.AccessToken, + Expires: time.Now().Add(time.Duration(userData.ExpiresInMs) / 2 * time.Millisecond), + RefreshToken: userData.RefreshToken, } - err = api.SetPusher(userData, url) + return nil +} + +func (bot *MatrixBot) Login(username string, password string, deviceId string) error { + userData, err := api.Login(username, password, deviceId) if err != nil { - return nil, err + return err } - return &MatrixBot{ - userData: userData, - handlers: make(map[string]func(bot *MatrixBot, notification api.Notification) error), - }, nil + bot.token = &api.Token{ + AccessToken: userData.AccessToken, + Expires: time.Now().Add(time.Duration(userData.ExpiresInMs) / 2 * time.Millisecond), + RefreshToken: userData.RefreshToken, + } + return nil +} + +func (bot *MatrixBot) RefreshTask() { + for true { + if bot.token != nil && time.Now().After(bot.token.Expires) { + if err := bot.RefreshToken(); err != nil { + log.Printf("error refresh token: %s\n", err.Error()) + } + } + time.Sleep(1 * time.Second) + } +} + +func (bot *MatrixBot) RegisterPusher(url string) error { + return api.SetPusher(*bot.token, url) } func (bot *MatrixBot) HandleFunc(command string, handler func(bot *MatrixBot, notification api.Notification) error) { @@ -39,7 +72,7 @@ func (bot *MatrixBot) HandleFunc(command string, handler func(bot *MatrixBot, no func (bot *MatrixBot) Serve(endpoint string) { http.HandleFunc("/healthz", func(writer http.ResponseWriter, request *http.Request) { - io.WriteString(writer, "OK\n") + _, _ = io.WriteString(writer, "OK\n") }) http.HandleFunc("/_matrix/push/v1/notify", func(writer http.ResponseWriter, request *http.Request) { notification, err := api.ParseNotification(request.Body) @@ -50,7 +83,7 @@ func (bot *MatrixBot) Serve(endpoint string) { if notification.EventId == "" { return } - err = api.SetReadReceipt(bot.userData, notification.RoomId, notification.EventId) + err = api.SetReadReceipt(*bot.token, notification.RoomId, notification.EventId) if err != nil { log.Println(err.Error()) return -- GitLab