Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment with worker redirection #606

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions api/_responses/redirect.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,65 @@
package _responses

import (
"crypto/hmac"
"crypto/sha512"
"encoding/hex"
"net/url"
"strconv"
"time"

"github.com/t2bot/matrix-media-repo/api/_apimeta"
"github.com/t2bot/matrix-media-repo/common/rcontext"
)

type RedirectResponse struct {
ToUrl string
}

func Redirect(url string) *RedirectResponse {
return &RedirectResponse{ToUrl: url}
func Redirect(ctx rcontext.RequestContext, toUrl string, auth _apimeta.AuthContext) *RedirectResponse {
if auth.IsAuthenticated() {
// Figure out who is authenticated here, as that affects the expiration time
var expirationTime time.Time
if auth.Server.ServerName != "" {
expirationTime = time.Now().Add(time.Minute)
} else {
expirationTime = time.Now().Add(time.Minute * 5)
}

// Append the expiration time to the URL
toUrl = appendQueryParam(toUrl, "exp", strconv.FormatInt(expirationTime.UnixMilli(), 10))

// Append a value we expect to survive the round trip that only we know about
// We do this after the expiration value to cover that field as well.
mac := hmac.New(sha512.New, []byte("THIS IS ANOTHER SECRET VALUE")) // TODO: @@ Actual secret key
mac.Write([]byte(toUrl))
requestHmac := mac.Sum(nil)
toUrl = appendQueryParam(toUrl, "request", hex.EncodeToString(requestHmac)+"."+hex.EncodeToString([]byte(toUrl)))

// Prepare our HMAC message contents as a JSON object
hmacMessage := toUrl + "||"
if auth.User.UserId != "" {
hmacMessage += auth.User.AccessToken
}

// Actually do the HMAC
mac = hmac.New(sha512.New, []byte("THIS_IS_A_SECRET_KEY")) // TODO: @@ Actual secret key
mac.Write([]byte(hmacMessage))
verifyHmac := mac.Sum(nil)

// Append the HMAC to the URL
toUrl = appendQueryParam(toUrl, "verify", hex.EncodeToString(verifyHmac))
}
return &RedirectResponse{ToUrl: toUrl}
}

func appendQueryParam(toUrl string, key string, val string) string {
parsedUrl, err := url.Parse(toUrl)
if err != nil {
panic(err) // it wouldn't have worked anyways
}
qs := parsedUrl.Query()
qs.Set(key, val)
turt2live marked this conversation as resolved.
Show resolved Hide resolved
parsedUrl.RawQuery = qs.Encode()
return parsedUrl.String()
}
3 changes: 2 additions & 1 deletion api/_routers/98-use-rcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ beforeParseDownload:
}

if shouldCache {
headers.Set("Cache-Control", "private, max-age=259200") // 3 days
// TODO: @@ Only set `public` for CDNs, otherwise use `private`
headers.Set("Cache-Control", "public, max-age=259200") // 3 days
}

if downloadRes.SizeBytes > 0 {
Expand Down
117 changes: 117 additions & 0 deletions api/custom/byid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package custom

import (
"crypto/hmac"
"crypto/sha512"
"encoding/hex"
"net/http"
"net/url"
"strconv"
"strings"

"github.com/t2bot/matrix-media-repo/api/_apimeta"
"github.com/t2bot/matrix-media-repo/api/_responses"
"github.com/t2bot/matrix-media-repo/api/_routers"
"github.com/t2bot/matrix-media-repo/common/rcontext"
"github.com/t2bot/matrix-media-repo/database"
"github.com/t2bot/matrix-media-repo/datastores"
"github.com/t2bot/matrix-media-repo/pipelines/_steps/download"
"github.com/t2bot/matrix-media-repo/util"
)

func GetMediaById(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
//if !user.IsShared {
// return _responses.AuthFailed()
//}

// TODO: This is beyond dangerous and needs proper filtering

// Parse the `request` to ensure we actually sent this request
requestVal := r.URL.Query().Get("request")
requestValParts := strings.Split(requestVal, ".")
if len(requestValParts) != 2 {
rctx.Log.Error("Need exactly 2 parts for `request`")
return _responses.AuthFailed()
}
verifyMac := requestValParts[0]
toUrlB, err := hex.DecodeString(requestValParts[1])
if err != nil {
rctx.Log.Error("Failed to decode request value:", err)
return _responses.AuthFailed()
}
toUrl := string(toUrlB)
mac := hmac.New(sha512.New, []byte("THIS IS ANOTHER SECRET VALUE")) // TODO: @@ Actual secret key
mac.Write([]byte(toUrl))
expectedMac := hex.EncodeToString(mac.Sum(nil))
if strings.ToLower(verifyMac) != strings.ToLower(expectedMac) {
return _responses.AuthFailed()
}

// Verify the HMAC from the worker too
query := r.URL.Query()
suppliedHmac := query.Get("verify")
query.Del("verify")
r.URL.RawQuery = query.Encode()
r.URL.Host = r.Host // TODO: Why is this unset??
r.URL.Scheme = "https" // TODO: Why is this unset??
mac = hmac.New(sha512.New, []byte("THIS_IS_A_SECRET_KEY")) // TODO: @@ Actual secret key
rctx.Log.Info("URL: ", r.URL.String())
mac.Write([]byte(r.URL.String()))
expectedMac = hex.EncodeToString(mac.Sum(nil))
if strings.ToLower(suppliedHmac) != strings.ToLower(expectedMac) {
rctx.Log.Error("HMAC mismatch")
return _responses.AuthFailed()
}

// Verify that the path for the `request` is the same as our called path
parsedUrl, err := url.Parse(toUrl)
if err != nil {
rctx.Log.Error("Failed to parse URL:", err)
return _responses.AuthFailed()
}
if parsedUrl.Path != r.URL.Path {
rctx.Log.Error("Wrong path or query")
return _responses.AuthFailed()
}

// Verify that the original request isn't expired
expVal := parsedUrl.Query().Get("exp")
if expVal != "" {
exp, err := strconv.ParseInt(expVal, 10, 64)
if err != nil {
rctx.Log.Error("Failed to parse exp:", err)
return _responses.AuthFailed()
}
if exp <= util.NowMillis() {
rctx.Log.Error("Request expired")
return _responses.AuthFailed()
}
}

// ---- request verified - we can now serve the media ----

db := database.GetInstance().Media.Prepare(rctx)
ds, err := datastores.Pick(rctx, datastores.LocalMediaKind)
if err != nil {
panic(err)
}
objectId := _routers.GetParam("objectId", r)
medias, err := db.GetByLocation(ds.Id, objectId)
if err != nil {
panic(err)
}

media := medias[0]
stream, err := download.OpenStream(rctx, media.Locatable)
if err != nil {
panic(err)
}

return &_responses.DownloadResponse{
ContentType: media.ContentType,
Filename: media.UploadName,
SizeBytes: media.SizeBytes,
Data: stream,
TargetDisposition: "infer",
}
}
2 changes: 1 addition & 1 deletion api/r0/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta.
} else if errors.Is(err, common.ErrMediaNotYetUploaded) {
return _responses.NotYetUploaded()
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
return _responses.Redirect(rctx, redirect.RedirectUrl, auth)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
2 changes: 1 addition & 1 deletion api/r0/thumbnail.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta
}
}
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
return _responses.Redirect(rctx, redirect.RedirectUrl, auth)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
2 changes: 2 additions & 0 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
const PrefixMedia = "/_matrix/media"
const PrefixClient = "/_matrix/client"
const PrefixFederation = "/_matrix/federation"
const PrefixMMR = "/_mmr"

func buildRoutes() http.Handler {
counter := &_routers.RequestCounter{}
Expand Down Expand Up @@ -61,6 +62,7 @@ func buildRoutes() http.Handler {
purgeOneRoute := makeRoute(_routers.RequireAccessToken(custom.PurgeIndividualRecord, false), "purge_individual_media", counter)
register([]string{"DELETE"}, PrefixMedia, "download/:server/:mediaId", mxUnstable, router, purgeOneRoute)
register([]string{"GET"}, PrefixMedia, "usage", msc4034, router, makeRoute(_routers.RequireAccessToken(unstable.PublicUsage, false), "usage", counter))
register([]string{"GET"}, PrefixMMR, "byid/:objectId", mxNoVersion, router, makeRoute(_routers.OptionalAccessToken(custom.GetMediaById), "byid", counter))

// Custom and top-level features
router.Handler("GET", fmt.Sprintf("%s/version", PrefixMedia), makeRoute(_routers.OptionalAccessToken(custom.GetVersion), "get_version", counter))
Expand Down
Loading