package main import ( "crypto/sha256" "fmt" "hash" "io" "os" "path" "regexp" "rss-cache-proxy/util" "time" ) type Cache struct { cacheDir string writerFactory func(file io.WriteCloser, digest hash.Hash) io.WriteCloser } type CacheEntry struct { Checksum string LastModified time.Time Content io.ReadCloser } func (cacheEntry CacheEntry) LastModifiedString() string { return cacheEntry.LastModified.Format("Mon, 02 Jan 2006 15:04:05 GMT") } func NewCache(cacheDir string) *Cache { return &Cache{ cacheDir: cacheDir, writerFactory: func(file io.WriteCloser, digest hash.Hash) io.WriteCloser { return util.MultiWriteCloser(file, digest) }, } } func NewCacheWithFilter(cacheDir string, pattern *regexp.Regexp) *Cache { return &Cache{ cacheDir: cacheDir, writerFactory: func(file io.WriteCloser, digest hash.Hash) io.WriteCloser { return util.MultiWriteCloser(file, util.NewFilteredWriter(digest, pattern)) }, } } func (c *Cache) writeWithHash(writer io.WriteCloser, reader io.ReadCloser) (string, error) { hasher := sha256.New() multiWriter := c.writerFactory(writer, hasher) if _, err := io.Copy(multiWriter, reader); err != nil { return "", err } if err := multiWriter.Close(); err != nil { return "", err } if err := reader.Close(); err != nil { return "", err } return fmt.Sprintf("%x", hasher.Sum(nil)), nil } func (c *Cache) LastModified() (time.Time, error) { indexFile, err := os.Stat(path.Join(c.cacheDir, "current")) if err != nil { return time.UnixMicro(0), err } return indexFile.ModTime(), nil } func (c *Cache) Write(content io.ReadCloser) error { contentFile, err := os.CreateTemp(c.cacheDir, "tmp-content") if err != nil { return err } hasher := sha256.New() multiWriter := c.writerFactory(contentFile, hasher) if _, err = io.Copy(multiWriter, content); err != nil { return err } _ = multiWriter.Close() _ = content.Close() checksum := fmt.Sprintf("%x", hasher.Sum(nil)) filePath := path.Join(c.cacheDir, checksum) if _, err := os.Stat(filePath); os.IsNotExist(err) { err = os.Rename(contentFile.Name(), filePath) if err != nil { return err } } else { err = os.Remove(contentFile.Name()) if err != nil { return err } } indexFile, err := os.CreateTemp(c.cacheDir, "tmp-index") if err != nil { return err } _, err = io.WriteString(indexFile, checksum) _ = indexFile.Close() if err != nil { return err } indexPath := path.Join(c.cacheDir, "current") err = os.Rename(indexFile.Name(), indexPath) if err != nil { return err } return nil } func (c *Cache) Read() (CacheEntry, error) { indexFile, err := os.Open(path.Join(c.cacheDir, "current")) if err != nil { return CacheEntry{}, err } indexContent, err := io.ReadAll(indexFile) _ = indexFile.Close() if err != nil { return CacheEntry{}, err } currentVersion := string(indexContent) currentFile, err := os.Open(path.Join(c.cacheDir, currentVersion)) if err != nil { return CacheEntry{}, err } currentInfo, err := currentFile.Stat() if err != nil { _ = currentFile.Close() return CacheEntry{}, err } return CacheEntry{ Checksum: currentVersion, LastModified: currentInfo.ModTime(), Content: currentFile, }, nil }