diff --git a/config.sample.yaml b/config.sample.yaml
index 5e037828..9720dc64 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -224,6 +224,20 @@ datastores:
       # capability. MMR may still be responsible for bandwidth charges incurred from going to
       # the bucket directly.
       #publicBaseUrl: "https://mycdn.example.org/"
+      # When set, the public S3 URL will be presigned before redirection. This allows users
+      # to directly download from private S3 buckets as long as the URL is still valid.
+      #presignUrl: false
+      # The time in seconds that a presigned URL will be valid for before expiring.
+      # This value must be between 60 (1 minute) and 604800 (7 days). Ensure it is high
+      # enough that users with slow connections will be able to download the media before it expires.
+      #presignExpiry: 86400
+      # When set, the presigned S3 URLs will be cached and reused as long as they are still valid.
+      # Otherwise, clients will not be able to use their cache, and will redownload the same file.
+      #cachePresignedUrls: true
+      # The time in seconds that a presigned URL will be cached and reused. If the URL is older,
+      # a new URL will be generated and cached in its place. It must be less than presignExpiry;
+      # when unset or out of range it defaults to two-thirds of presignExpiry (57600 for the default).
+      #presignCacheExpiry: 79200
       # Set to `true` to bypass any local cache when `publicBaseUrl` is set. Has no effect
       # when `publicBaseUrl` is unset. Defaults to false (cached media will be served by MMR
       # before redirection if present).
diff --git a/datastores/download.go b/datastores/download.go
index 3470176e..bb976ceb 100644
--- a/datastores/download.go
+++ b/datastores/download.go
@@ -6,12 +6,14 @@ import (
 	"io"
 	"os"
 	"path"
+	"time"
 
 	"github.com/minio/minio-go/v7"
 	"github.com/prometheus/client_golang/prometheus"
 
 	"github.com/t2bot/matrix-media-repo/common/config"
 	"github.com/t2bot/matrix-media-repo/common/rcontext"
 	"github.com/t2bot/matrix-media-repo/metrics"
+	"github.com/t2bot/matrix-media-repo/redislib"
 )
 
@@ -49,12 +51,48 @@ func DownloadOrRedirect(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName string) (io.ReadSeekCloser, error) {
 	if s3c.publicBaseUrl != "" {
 		metrics.S3Operations.With(prometheus.Labels{"operation": "RedirectGetObject"}).Inc()
-		return nil, redirect(fmt.Sprintf("%s%s", s3c.publicBaseUrl, dsFileName))
+		if s3c.presignUrl {
+			presignedUrl, err := PresignURL(ctx, ds, s3c, dsFileName)
+			if err != nil {
+				return nil, err
+			}
+			return nil, redirect(presignedUrl)
+		}
+		return nil, redirect(fmt.Sprintf("%s%s", s3c.publicBaseUrl, dsFileName))
 	}
 
 	return Download(ctx, ds, dsFileName)
 }
 
+// PresignURL returns a time-limited signed GET URL for dsFileName in the datastore's
+// bucket. When cachePresignedUrls is enabled, URLs are cached in Redis (best effort:
+// cache failures are logged at debug level and a fresh URL is generated instead).
+func PresignURL(ctx rcontext.RequestContext, ds config.DatastoreConfig, s3c *s3, dsFileName string) (string, error) {
+	if s3c.cachePresignedUrls {
+		url, err := redislib.TryGetURL(ctx, dsFileName)
+		if err != nil {
+			ctx.Log.Debug("Unable to fetch url from cache due to error: ", err)
+		}
+		if err == nil && len(url) > 0 {
+			ctx.Log.Debug("Using cached presigned url for: ", dsFileName)
+			return url, nil
+		}
+	}
+
+	presignedUrl, err := s3c.client.PresignedGetObject(ctx.Context, s3c.bucket, dsFileName, time.Duration(s3c.presignExpiry)*time.Second, nil)
+	if err != nil {
+		return "", err
+	}
+	presignedUrlStr := presignedUrl.String()
+
+	if s3c.cachePresignedUrls {
+		ctx.Log.Debug("Caching presigned url for: ", dsFileName)
+		if err := redislib.StoreURL(ctx, dsFileName, presignedUrlStr, time.Duration(s3c.presignCacheExpiry)*time.Second); err != nil {
+			ctx.Log.Debug("Not populating url cache due to error: ", err)
+		}
+	}
+	return presignedUrlStr, nil
+}
+
 func WouldRedirectWhenCached(ctx rcontext.RequestContext, ds config.DatastoreConfig) (bool, error) {
 	if ds.Type != "s3" {
 		return false, nil
 	}
diff --git a/datastores/s3.go b/datastores/s3.go
index 99f15ae3..e4c08275 100644
--- a/datastores/s3.go
+++ b/datastores/s3.go
@@ -21,6 +21,10 @@ type s3 struct {
 	storageClass       string
 	bucket             string
 	publicBaseUrl      string
+	presignUrl         bool
+	presignExpiry      int
+	cachePresignedUrls bool
+	presignCacheExpiry int
 	redirectWhenCached bool
 	prefixLength       int
 	multipartUploads   bool
@@ -43,6 +47,10 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
 	storageClass, hasStorageClass := ds.Options["storageClass"]
 	useSslStr, hasSsl := ds.Options["ssl"]
 	publicBaseUrl := ds.Options["publicBaseUrl"]
+	presignUrlStr, hasPresignUrl := ds.Options["presignUrl"]
+	presignExpiryStr, hasPresignExpiry := ds.Options["presignExpiry"]
+	cachePresignedUrlsStr, hasCachePresignedUrls := ds.Options["cachePresignedUrls"]
+	presignCacheExpiryStr, hasPresignCacheExpiry := ds.Options["presignCacheExpiry"]
 	redirectWhenCachedStr, hasRedirectWhenCached := ds.Options["redirectWhenCached"]
 	prefixLengthStr, hasPrefixLength := ds.Options["prefixLength"]
 	useMultipartStr, hasMultipart := ds.Options["multipartUploads"]
@@ -61,6 +69,40 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
 		useMultipart, _ = strconv.ParseBool(useMultipartStr)
 	}
 
+	presignUrl := false
+	if hasPresignUrl && presignUrlStr != "" {
+		presignUrl, _ = strconv.ParseBool(presignUrlStr)
+	}
+
+	// S3 presigned URLs must be valid for between 1 minute and 7 days.
+	presignExpiry := 86400
+	if hasPresignExpiry && presignExpiryStr != "" {
+		presignExpiry, _ = strconv.Atoi(presignExpiryStr)
+		if presignExpiry < 60 {
+			presignExpiry = 60
+		}
+		if presignExpiry > 604800 {
+			presignExpiry = 604800
+		}
+	}
+
+	cachePresignedUrls := true
+	if hasCachePresignedUrls && cachePresignedUrlsStr != "" {
+		cachePresignedUrls, _ = strconv.ParseBool(cachePresignedUrlsStr)
+	}
+
+	// Default to two-thirds of the URL lifetime so a cached URL always leaves the
+	// downloader a usable validity window before the signature expires.
+	presignCacheExpiry := presignExpiry * 2 / 3
+	if hasPresignCacheExpiry && presignCacheExpiryStr != "" {
+		presignCacheExpiry, _ = strconv.Atoi(presignCacheExpiryStr)
+		if presignCacheExpiry <= 0 || presignCacheExpiry >= presignExpiry {
+			// A non-positive value would cache forever (redis treats a zero
+			// expiration as "no expiry") and a value >= presignExpiry would
+			// let cached URLs outlive their signatures, so fall back.
+			presignCacheExpiry = presignExpiry * 2 / 3
+		}
+	}
+
 	redirectWhenCached := false
 	if hasRedirectWhenCached && redirectWhenCachedStr != "" {
 		redirectWhenCached, _ = strconv.ParseBool(redirectWhenCachedStr)
@@ -93,6 +135,10 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
 		storageClass:       storageClass,
 		bucket:             bucket,
 		publicBaseUrl:      publicBaseUrl,
+		presignUrl:         presignUrl,
+		presignExpiry:      presignExpiry,
+		cachePresignedUrls: cachePresignedUrls,
+		presignCacheExpiry: presignCacheExpiry,
 		redirectWhenCached: redirectWhenCached,
 		prefixLength:       prefixLength,
 		multipartUploads:   useMultipart,
diff --git a/redislib/presign_cache.go b/redislib/presign_cache.go
new file mode 100644
index 00000000..70caae79
--- /dev/null
+++ b/redislib/presign_cache.go
@@ -0,0 +1,70 @@
+package redislib
+
+import (
+	"context"
+	"time"
+
+	"github.com/getsentry/sentry-go"
+	"github.com/redis/go-redis/v9"
+
+	"github.com/t2bot/matrix-media-repo/common/rcontext"
+)
+
+const keyPrefix = "s3url:"
+
+// StoreURL caches the presigned URL for dsFileName on every shard, expiring after expiration.
+func StoreURL(ctx rcontext.RequestContext, dsFileName string, url string, expiration time.Duration) error {
+	makeConnection()
+	if ring == nil {
+		return nil
+	}
+
+	if err := ring.ForEachShard(ctx.Context, func(ctx2 context.Context, client *redis.Client) error {
+		return client.Set(ctx2, keyPrefix+dsFileName, url, expiration).Err()
+	}); err != nil {
+		// Best effort: don't leave a partially written key behind. DeleteURL
+		// applies keyPrefix itself, so pass the bare file name here.
+		if delErr := DeleteURL(ctx, dsFileName); delErr != nil {
+			ctx.Log.Warn("Error while attempting to clean up url cache during another error: ", delErr)
+			sentry.CaptureException(delErr)
+		}
+		return err
+	}
+
+	return nil
+}
+
+// TryGetURL returns the cached presigned URL for dsFileName, or "" when not cached.
+func TryGetURL(ctx rcontext.RequestContext, dsFileName string) (string, error) {
+	makeConnection()
+	if ring == nil {
+		return "", nil
+	}
+
+	timeoutCtx, cancel := context.WithTimeout(ctx.Context, 20*time.Second)
+	defer cancel()
+
+	ctx.Log.Debugf("Getting cached s3 url for %s", keyPrefix+dsFileName)
+	s, err := ring.Get(timeoutCtx, keyPrefix+dsFileName).Result()
+	if err != nil {
+		if err == redis.Nil {
+			// A cache miss is not an error; callers generate a fresh URL.
+			return "", nil
+		}
+		return "", err
+	}
+
+	return s, nil
+}
+
+// DeleteURL removes the cached presigned URL for dsFileName from every shard.
+func DeleteURL(ctx rcontext.RequestContext, dsFileName string) error {
+	makeConnection()
+	if ring == nil {
+		return nil
+	}
+
+	return ring.ForEachShard(ctx.Context, func(ctx2 context.Context, client *redis.Client) error {
+		return client.Del(ctx2, keyPrefix+dsFileName).Err()
+	})
+}