Skip to content
Snippets Groups Projects
main.go 3.09 KiB
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"strings"
	"time"

	arg "github.com/alexflint/go-arg"
	"rss-cache-proxy/util"
)

// Options is the command-line configuration, populated by arg.MustParse in
// main. Struct tags are go-arg annotations: `arg:"required"` marks mandatory
// options and `default:"…"` supplies the value used when an option is omitted.
type Options struct {
	// CacheDir is the cache location handed to NewCache / NewCacheWithFilter.
	CacheDir string `arg:"required"`

	// Url is the upstream feed that is fetched (HTTP GET) on every timer tick.
	Url string `arg:"required"`

	// UserAgent is sent as the User-Agent header on every upstream request.
	UserAgent string `default:"Mozilla/5.0 (compatible; rss-cache-proxy/0.1.0; +https://git.kuschku.de/justjanne/rss-cache-proxy)"`

	// Interval is the refresh period passed to util.IntervalTimer (default 60m).
	Interval time.Duration `default:"60m"`

	// Pattern, when non-empty, is compiled with regexp.MustCompile (an invalid
	// expression aborts startup) and passed to NewCacheWithFilter.
	// NOTE(review): what the filter does to cached content is defined by the
	// cache implementation, which is not visible here — confirm there.
	Pattern string `default:""`
}

// respondError logs err and writes a plain-text error response for request.
//
// Errors matching os.ErrNotExist (including ones wrapped with %w) are reported
// as 404 Not Found. Any other non-nil error is reported as 500 Internal Server
// Error, with the error message echoed in the body. A nil err writes nothing,
// leaving the response untouched.
//
// NOTE(review): echoing err to the client may disclose internal details such as
// filesystem paths; consider dropping it from the body since it is logged anyway.
func respondError(writer http.ResponseWriter, request *http.Request, err error) {
	if err == nil {
		return
	}
	log.Printf("error processing request for %s: %s\n", request.URL, err.Error())

	// errors.Is, unlike os.IsNotExist, also recognises not-exist errors that
	// have been wrapped on their way up from the cache layer.
	if errors.Is(err, os.ErrNotExist) {
		http.Error(writer, "Not Found", http.StatusNotFound)
		return
	}
	// The status code now agrees with the body text. http.Error appends the
	// trailing newline and sets a text/plain Content-Type.
	http.Error(writer, fmt.Sprintf("Internal Server Error: %s", err), http.StatusInternalServerError)
}

// main wires the proxy together. It:
//   - parses Options from the command line,
//   - starts a timer that periodically downloads options.Url into the cache,
//   - serves the cached copy over HTTP on :8080.
//
// Routes:
//   - /rss.xml returns the cached content with Last-Modified and ETag
//     validators, and answers matching conditional requests with 304.
//   - /healthz always answers "ok".
func main() {
	var options Options
	arg.MustParse(&options)

	// A non-empty pattern is compiled once at startup (MustCompile aborts on an
	// invalid expression) and handed to the cache as a filter.
	var cache *Cache
	if len(options.Pattern) > 0 {
		cache = NewCacheWithFilter(options.CacheDir, regexp.MustCompile(options.Pattern))
	} else {
		cache = NewCache(options.CacheDir)
	}

	// Seed the timer with the last cache update time, falling back to the Unix
	// epoch when it cannot be determined.
	// NOTE(review): presumably this lets util.IntervalTimer fetch immediately
	// when the cache is missing or stale — confirm against its implementation.
	lastModified, err := cache.LastModified()
	if err != nil {
		lastModified = time.UnixMicro(0)
	}
	stopTimer := util.IntervalTimer(context.Background(), lastModified, options.Interval, func(ctx context.Context) {
		request, err := http.NewRequestWithContext(ctx, "GET", options.Url, nil)
		if err != nil {
			log.Printf("error building request for %s: %s\n", options.Url, err.Error())
			return
		}
		request.Header.Set("User-Agent", options.UserAgent)
		log.Printf("fetching content for %s…\n", options.Url)
		response, err := http.DefaultClient.Do(request)
		if err != nil {
			log.Printf("error fetching content for %s: %s\n", options.Url, err.Error())
			return
		}
		// Release the connection on every path out of this callback.
		defer func() { _ = response.Body.Close() }()

		// Only a 200 carries the feed. Error pages (404, 5xx, rate-limit notices)
		// must not be written into the cache in place of the feed.
		if response.StatusCode != http.StatusOK {
			log.Printf("error fetching content for %s: unexpected status %s\n", options.Url, response.Status)
			return
		}
		if err := cache.Write(response.Body); err != nil {
			log.Printf("error updating cache for %s: %s\n", options.Url, err.Error())
			return
		}
		log.Printf("finished updating cache for %s\n", options.Url)
	})

	http.HandleFunc("/rss.xml", func(writer http.ResponseWriter, request *http.Request) {
		cacheEntry, err := cache.Read()
		if err != nil {
			respondError(writer, request, err)
			return
		}
		defer func() { _ = cacheEntry.Content.Close() }()

		entryModified := cacheEntry.LastModifiedString()
		// Entity tags are quoted strings (RFC 9110 §8.8.3). Checksum is compared
		// unquoted below, so it is assumed to contain no quotes itself.
		etag := `"` + cacheEntry.Checksum + `"`
		// Set the validators before any status is written so that 304 responses
		// carry the ETag too, as RFC 9110 §15.4.5 requires.
		writer.Header().Set("Last-Modified", entryModified)
		writer.Header().Set("ETag", etag)

		notModified := func() bool {
			// When If-None-Match is present, If-Modified-Since must be ignored
			// (RFC 9110 §13.1.3).
			if tags := request.Header.Values("If-None-Match"); len(tags) > 0 {
				for _, value := range tags {
					for _, tag := range strings.Split(value, ",") {
						tag = strings.TrimSpace(tag)
						if tag == "*" {
							return true
						}
						// If-None-Match uses weak comparison, so a W/ prefix is ignored.
						tag = strings.Trim(strings.TrimPrefix(tag, "W/"), "\"")
						if tag != "" && tag == cacheEntry.Checksum {
							return true
						}
					}
				}
				return false
			}
			// Exact textual match against the Last-Modified value we emit.
			for _, value := range request.Header.Values("If-Modified-Since") {
				if value == entryModified {
					return true
				}
			}
			return false
		}
		if notModified() {
			writer.WriteHeader(http.StatusNotModified)
			return
		}
		// A copy error means the client went away mid-response. The status is
		// already sent, so there is nothing useful left to do.
		_, _ = io.Copy(writer, cacheEntry.Content)
	})
	http.HandleFunc("/healthz", func(writer http.ResponseWriter, request *http.Request) {
		_, _ = fmt.Fprintf(writer, "ok\n")
	})

	// ListenAndServe only returns on failure. log.Fatal exits without running
	// deferred calls, so the refresh timer is stopped explicitly first.
	serveErr := http.ListenAndServe(":8080", nil)
	stopTimer()
	log.Fatal(serveErr)
}