diff --git a/api/_responses/redirect.go b/api/_responses/redirect.go index 1e8c66c4..a5451d23 100644 --- a/api/_responses/redirect.go +++ b/api/_responses/redirect.go @@ -1,9 +1,65 @@ package _responses +import ( + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "net/url" + "strconv" + "time" + + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/common/rcontext" +) + type RedirectResponse struct { ToUrl string } -func Redirect(url string) *RedirectResponse { - return &RedirectResponse{ToUrl: url} +func Redirect(ctx rcontext.RequestContext, toUrl string, auth _apimeta.AuthContext) *RedirectResponse { + if auth.IsAuthenticated() { + // Figure out who is authenticated here, as that affects the expiration time + var expirationTime time.Time + if auth.Server.ServerName != "" { + expirationTime = time.Now().Add(time.Minute) + } else { + expirationTime = time.Now().Add(time.Minute * 5) + } + + // Append the expiration time to the URL + toUrl = appendQueryParam(toUrl, "exp", strconv.FormatInt(expirationTime.UnixMilli(), 10)) + + // Append a value we expect to survive the round trip that only we know about + // We do this after the expiration value to cover that field as well. + mac := hmac.New(sha512.New, []byte("THIS IS ANOTHER SECRET VALUE")) // TODO: @@ Actual secret key + mac.Write([]byte(toUrl)) + requestHmac := mac.Sum(nil) + toUrl = appendQueryParam(toUrl, "request", hex.EncodeToString(requestHmac)+"."+hex.EncodeToString([]byte(toUrl))) + + // Prepare our HMAC message contents as a JSON object + hmacMessage := toUrl + "||" + if auth.User.UserId != "" { + hmacMessage += auth.User.AccessToken + } + + // Actually do the HMAC + mac = hmac.New(sha512.New, []byte("THIS_IS_A_SECRET_KEY")) // TODO: @@ Actual secret key + mac.Write([]byte(hmacMessage)) + verifyHmac := mac.Sum(nil) + + // Append the HMAC to the URL + toUrl = appendQueryParam(toUrl, "verify", hex.EncodeToString(verifyHmac)) + } + return &RedirectResponse{ToUrl: toUrl} +} + +func appendQueryParam(toUrl string, key string, val string) string { + parsedUrl, err := url.Parse(toUrl) + if err != nil { + panic(err) // it wouldn't have worked anyways + } + qs := parsedUrl.Query() + qs.Set(key, val) + parsedUrl.RawQuery = qs.Encode() + return parsedUrl.String() } diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index 07523153..5bf49a22 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -129,7 +129,8 @@ beforeParseDownload: } if shouldCache { - headers.Set("Cache-Control", "private, max-age=259200") // 3 days + // TODO: @@ Only set `public` for CDNs, otherwise use `private` + headers.Set("Cache-Control", "public, max-age=259200") // 3 days } if downloadRes.SizeBytes > 0 { diff --git a/api/custom/byid.go b/api/custom/byid.go new file mode 100644 index 00000000..2af1657b --- /dev/null +++ b/api/custom/byid.go @@ -0,0 +1,117 @@ +package custom + +import ( + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/api/_routers" + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/datastores" + "github.com/t2bot/matrix-media-repo/pipelines/_steps/download" + "github.com/t2bot/matrix-media-repo/util" +) + +func GetMediaById(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + //if !user.IsShared { + // return _responses.AuthFailed() + //} + + // TODO: This is beyond dangerous and needs proper filtering + + // Parse the `request` to ensure we actually sent this request + requestVal := r.URL.Query().Get("request") + requestValParts := strings.Split(requestVal, ".") + if len(requestValParts) != 2 { + rctx.Log.Error("Need exactly 2 parts for `request`") + return _responses.AuthFailed() + } + verifyMac := requestValParts[0] + toUrlB, err := hex.DecodeString(requestValParts[1]) + if err != nil { + rctx.Log.Error("Failed to decode request value:", err) + return _responses.AuthFailed() + } + toUrl := string(toUrlB) + mac := hmac.New(sha512.New, []byte("THIS IS ANOTHER SECRET VALUE")) // TODO: @@ Actual secret key + mac.Write([]byte(toUrl)) + expectedMac := hex.EncodeToString(mac.Sum(nil)) + if strings.ToLower(verifyMac) != strings.ToLower(expectedMac) { + return _responses.AuthFailed() + } + + // Verify the HMAC from the worker too + query := r.URL.Query() + suppliedHmac := query.Get("verify") + query.Del("verify") + r.URL.RawQuery = query.Encode() + r.URL.Host = r.Host // TODO: Why is this unset?? + r.URL.Scheme = "https" // TODO: Why is this unset?? + mac = hmac.New(sha512.New, []byte("THIS_IS_A_SECRET_KEY")) // TODO: @@ Actual secret key + rctx.Log.Info("URL: ", r.URL.String()) + mac.Write([]byte(r.URL.String())) + expectedMac = hex.EncodeToString(mac.Sum(nil)) + if strings.ToLower(suppliedHmac) != strings.ToLower(expectedMac) { + rctx.Log.Error("HMAC mismatch") + return _responses.AuthFailed() + } + + // Verify that the path for the `request` is the same as our called path + parsedUrl, err := url.Parse(toUrl) + if err != nil { + rctx.Log.Error("Failed to parse URL:", err) + return _responses.AuthFailed() + } + if parsedUrl.Path != r.URL.Path { + rctx.Log.Error("Wrong path or query") + return _responses.AuthFailed() + } + + // Verify that the original request isn't expired + expVal := parsedUrl.Query().Get("exp") + if expVal != "" { + exp, err := strconv.ParseInt(expVal, 10, 64) + if err != nil { + rctx.Log.Error("Failed to parse exp:", err) + return _responses.AuthFailed() + } + if exp <= util.NowMillis() { + rctx.Log.Error("Request expired") + return _responses.AuthFailed() + } + } + + // ---- request verified - we can now serve the media ---- + + db := database.GetInstance().Media.Prepare(rctx) + ds, err := datastores.Pick(rctx, datastores.LocalMediaKind) + if err != nil { + panic(err) + } + objectId := _routers.GetParam("objectId", r) + medias, err := db.GetByLocation(ds.Id, objectId) + if err != nil { + panic(err) + } + + media := medias[0] + stream, err := download.OpenStream(rctx, media.Locatable) + if err != nil { + panic(err) + } + + return &_responses.DownloadResponse{ + ContentType: media.ContentType, + Filename: media.UploadName, + SizeBytes: media.SizeBytes, + Data: stream, + TargetDisposition: "infer", + } +} diff --git a/api/r0/download.go b/api/r0/download.go index a4908025..1ddf9460 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -112,7 +112,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta. } else if errors.Is(err, common.ErrMediaNotYetUploaded) { return _responses.NotYetUploaded() } else if errors.As(err, &redirect) { - return _responses.Redirect(redirect.RedirectUrl) + return _responses.Redirect(rctx, redirect.RedirectUrl, auth) } rctx.Log.Error("Unexpected error locating media: ", err) sentry.CaptureException(err) diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 322b3ea8..e343ae87 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -184,7 +184,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta } } } else if errors.As(err, &redirect) { - return _responses.Redirect(redirect.RedirectUrl) + return _responses.Redirect(rctx, redirect.RedirectUrl, auth) } rctx.Log.Error("Unexpected error locating media: ", err) sentry.CaptureException(err) diff --git a/api/routes.go b/api/routes.go index c5fd4cdb..f1dad9cd 100644 --- a/api/routes.go +++ b/api/routes.go @@ -19,6 +19,7 @@ import ( const PrefixMedia = "/_matrix/media" const PrefixClient = "/_matrix/client" const PrefixFederation = "/_matrix/federation" +const PrefixMMR = "/_mmr" func buildRoutes() http.Handler { counter := &_routers.RequestCounter{} @@ -61,6 +62,7 @@ func buildRoutes() http.Handler { purgeOneRoute := makeRoute(_routers.RequireAccessToken(custom.PurgeIndividualRecord, false), "purge_individual_media", counter) register([]string{"DELETE"}, PrefixMedia, "download/:server/:mediaId", mxUnstable, router, purgeOneRoute) register([]string{"GET"}, PrefixMedia, "usage", msc4034, router, makeRoute(_routers.RequireAccessToken(unstable.PublicUsage, false), "usage", counter)) + register([]string{"GET"}, PrefixMMR, "byid/:objectId", mxNoVersion, router, makeRoute(_routers.OptionalAccessToken(custom.GetMediaById), "byid", counter)) // Custom and top-level features router.Handler("GET", fmt.Sprintf("%s/version", PrefixMedia), makeRoute(_routers.OptionalAccessToken(custom.GetVersion), "get_version", counter))