package main

import (
	"crypto/sha256"
	"fmt"
	"hash"
	"io"
	"os"
	"path"
	"regexp"
	"rss-cache-proxy/util"
	"time"
)

// Cache is an on-disk store for a single "current" document.
//
// Every stored body lives in cacheDir in a file named after the hex SHA-256
// checksum produced while copying it through writerFactory. A small index file
// named "current" holds the checksum of the most recently written body.
//
// Fields:
//   - cacheDir: directory holding the checksum-named content files, the
//     "current" index, and transient "tmp-*" files created during Write.
//   - writerFactory: combines the destination file and the hash digest into a
//     single io.WriteCloser that cached bytes are copied through. NewCache
//     feeds the digest directly; NewCacheWithFilter wraps it in
//     util.NewFilteredWriter so the checksum is taken over a filtered stream.
type Cache struct {
	cacheDir      string
	writerFactory func(file io.WriteCloser, digest hash.Hash) io.WriteCloser
}

// CacheEntry is the current cache content as returned by Cache.Read.
type CacheEntry struct {
	// Checksum is the hex checksum that names the content file.
	Checksum string

	// LastModified is the modification time of the content file.
	LastModified time.Time

	// Content is the open content file; the caller must close it.
	Content io.ReadCloser
}

// LastModifiedString formats LastModified as an HTTP date
// (e.g. "Tue, 02 Jan 2024 02:04:05 GMT"), suitable for a Last-Modified header.
//
// The layout hard-codes the "GMT" zone name, so the time is converted to UTC
// first. File modification times are reported in local time, and formatting
// them directly would label a local clock reading as GMT.
func (cacheEntry CacheEntry) LastModifiedString() string {
	return cacheEntry.LastModified.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT")
}

// NewCache returns a Cache rooted at cacheDir whose checksum is computed over
// the same stream that is written to the content file. The file and the digest
// are combined with util.MultiWriteCloser, which is assumed to forward each
// write to both.
func NewCache(cacheDir string) *Cache {
	c := new(Cache)
	c.cacheDir = cacheDir
	c.writerFactory = func(file io.WriteCloser, digest hash.Hash) io.WriteCloser {
		return util.MultiWriteCloser(file, digest)
	}
	return c
}

// NewCacheWithFilter returns a Cache rooted at cacheDir whose checksum is taken
// over a filtered view of the stream.
//
// The content file is handed to util.MultiWriteCloser unfiltered. The digest is
// first wrapped in util.NewFilteredWriter(digest, pattern). Presumably the
// filter drops text matching pattern, so that volatile parts of a document do
// not change its checksum. NOTE(review): confirm against util.NewFilteredWriter.
func NewCacheWithFilter(cacheDir string, pattern *regexp.Regexp) *Cache {
	factory := func(file io.WriteCloser, digest hash.Hash) io.WriteCloser {
		filteredDigest := util.NewFilteredWriter(digest, pattern)
		return util.MultiWriteCloser(file, filteredDigest)
	}
	return &Cache{
		writerFactory: factory,
		cacheDir:      cacheDir,
	}
}

// writeWithHash copies reader into writer through c.writerFactory. It returns
// the hex-encoded SHA-256 of the bytes the factory routes to the digest. For a
// filtered cache these may differ from the bytes written to writer.
//
// The writer wrapper (which is expected to close writer) and reader are closed
// on every path, including after a failed copy, so the underlying file and
// stream are never leaked. The digest is read only after the wrapper is closed,
// so anything a wrapper flushes on Close is included. The first error
// encountered is returned, with an empty checksum.
func (c *Cache) writeWithHash(writer io.WriteCloser, reader io.ReadCloser) (string, error) {
	hasher := sha256.New()
	sink := c.writerFactory(writer, hasher)

	_, copyErr := io.Copy(sink, reader)
	// Close both ends unconditionally; previously a failed copy leaked them.
	sinkErr := sink.Close()
	readerErr := reader.Close()
	for _, err := range []error{copyErr, sinkErr, readerErr} {
		if err != nil {
			return "", err
		}
	}
	return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}

// LastModified returns the modification time of the "current" index file,
// which Write replaces on every successful call. This is not the same as
// CacheEntry.LastModified from Read, which is the content file's time.
// If the index cannot be stat'ed, it returns the Unix epoch together with the
// os.Stat error.
func (c *Cache) LastModified() (time.Time, error) {
	indexPath := path.Join(c.cacheDir, "current")
	info, statErr := os.Stat(indexPath)
	if statErr == nil {
		return info.ModTime(), nil
	}
	return time.UnixMicro(0), statErr
}

// Write stores content in the cache and makes it the current entry.
//
// The body is streamed through c.writeWithHash into a temporary file in
// cacheDir. content is handed to writeWithHash, which is responsible for
// closing it. The temporary file is then renamed to its checksum. If a file
// with that checksum already exists, it is kept and the temporary copy is
// discarded, so repeated identical content is stored once. Finally, the
// "current" index is rewritten via a temporary file and a rename. Temporary
// files are removed when any step fails.
func (c *Cache) Write(content io.ReadCloser) error {
	checksum, err := c.storeContent(content)
	if err != nil {
		return err
	}
	return c.publishIndex(checksum)
}

// storeContent streams content into cacheDir under its checksum and returns that checksum.
func (c *Cache) storeContent(content io.ReadCloser) (string, error) {
	tmp, err := os.CreateTemp(c.cacheDir, "tmp-content")
	if err != nil {
		// Nothing else will consume content; release it.
		_ = content.Close()
		return "", err
	}
	tmpName := tmp.Name()

	checksum, err := c.writeWithHash(tmp, content)
	if err != nil {
		// The writer wrapper may already have closed tmp; a second Close just
		// returns an ignorable error. Close before Remove so removal also works
		// on platforms that refuse to delete open files.
		_ = tmp.Close()
		_ = os.Remove(tmpName)
		return "", err
	}

	finalPath := path.Join(c.cacheDir, checksum)
	_, statErr := os.Stat(finalPath)
	switch {
	case statErr == nil:
		// Identical content is already cached: keep the existing file.
		if err := os.Remove(tmpName); err != nil {
			return "", err
		}
	case os.IsNotExist(statErr):
		if err := os.Rename(tmpName, finalPath); err != nil {
			_ = os.Remove(tmpName)
			return "", err
		}
	default:
		// Any other stat failure means we cannot tell whether a usable file
		// exists; do not point the index at it.
		_ = os.Remove(tmpName)
		return "", statErr
	}
	return checksum, nil
}

// publishIndex replaces the "current" index with one containing checksum.
func (c *Cache) publishIndex(checksum string) error {
	tmp, err := os.CreateTemp(c.cacheDir, "tmp-index")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()

	_, writeErr := io.WriteString(tmp, checksum)
	if closeErr := tmp.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		_ = os.Remove(tmpName)
		return writeErr
	}

	// Rename over the old index so readers see either the old or the new
	// checksum, never a partially written file (rename is atomic on POSIX).
	if err := os.Rename(tmpName, path.Join(c.cacheDir, "current")); err != nil {
		_ = os.Remove(tmpName)
		return err
	}
	return nil
}

// Read opens the current cache entry.
//
// The "current" index holds the checksum that names the content file in
// cacheDir. The returned entry carries that checksum, the content file's
// modification time, and the open content file, which the caller must close.
// A missing index or content file is reported as the underlying os error, so
// os.IsNotExist still applies. An empty index, or one that names something
// other than a regular file, is reported as an error. Without this check,
// path.Join(cacheDir, "") would resolve to the cache directory itself, and the
// directory would be returned as "content".
func (c *Cache) Read() (CacheEntry, error) {
	indexPath := path.Join(c.cacheDir, "current")
	index, err := os.ReadFile(indexPath)
	if err != nil {
		return CacheEntry{}, err
	}
	checksum := string(index)
	if checksum == "" {
		return CacheEntry{}, fmt.Errorf("cache index %q is empty", indexPath)
	}

	contentFile, err := os.Open(path.Join(c.cacheDir, checksum))
	if err != nil {
		return CacheEntry{}, err
	}
	info, err := contentFile.Stat()
	if err != nil {
		_ = contentFile.Close()
		return CacheEntry{}, err
	}
	if !info.Mode().IsRegular() {
		_ = contentFile.Close()
		return CacheEntry{}, fmt.Errorf("cache entry %q is not a regular file", contentFile.Name())
	}

	return CacheEntry{
		Checksum:     checksum,
		LastModified: info.ModTime(),
		Content:      contentFile,
	}, nil
}