diff --git a/commands/decode.go b/commands/decode.go index 4c63468..cc85fee 100644 --- a/commands/decode.go +++ b/commands/decode.go @@ -8,6 +8,8 @@ import ( "github.com/disgoorg/disgo/discord" "github.com/disgoorg/disgo/handler" "github.com/disgoorg/json" + + "github.com/lavalink-devs/lavalink-bot/internal/trackdecode" ) var decode = discord.SlashCommandCreate{ @@ -19,11 +21,36 @@ var decode = discord.SlashCommandCreate{ Description: "The base64 encoded lavalink track", Required: true, }, + discord.ApplicationCommandOptionBool{ + Name: "lavalink", + Description: "If the Lavalink server should be used for decoding", + Required: false, + }, }, } func (c *Commands) Decode(e *handler.CommandEvent) error { track := e.SlashCommandInteractionData().String("track") + lavalink := e.SlashCommandInteractionData().Bool("lavalink") + + if !lavalink { + var content string + decoded, version, err := trackdecode.DecodeString(track) + if err != nil { + content += fmt.Sprintf("error while decoding track: %s\n", err.Error()) + } + if version > 0 { + content += fmt.Sprintf("track was encoded with version: `%d`\n", version) + } + if decoded != nil { + data, _ := json.MarshalIndent(decoded, "", " ") + content += fmt.Sprintf("```json\n%s\n```", data) + } + + return e.CreateMessage(discord.MessageCreate{ + Content: content, + }) + } if err := e.DeferCreateMessage(false); err != nil { return err @@ -31,7 +58,7 @@ func (c *Commands) Decode(e *handler.CommandEvent) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - decoded, err := c.Lavalink.BestNode().Rest().DecodeTracks(ctx, []string{track}) + decoded, err := c.Lavalink.BestNode().Rest().DecodeTrack(ctx, track) if err != nil { _, err = e.UpdateInteractionResponse(discord.MessageUpdate{ Content: json.Ptr(fmt.Sprintf("failed to decode track: %s", err)), @@ -39,7 +66,7 @@ func (c *Commands) Decode(e *handler.CommandEvent) error { return err } - data, err := json.MarshalIndent(decoded[0], "", " ") + data, err := json.MarshalIndent(decoded, "", " ") if err != nil { _, err = e.UpdateInteractionResponse(discord.MessageUpdate{ Content: json.Ptr(fmt.Sprintf("failed to decode track: %s", err)), diff --git a/internal/trackdecode/decode.go b/internal/trackdecode/decode.go new file mode 100644 index 0000000..396a75c --- /dev/null +++ b/internal/trackdecode/decode.go @@ -0,0 +1,239 @@ +package trackdecode + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/disgoorg/disgolink/v3/lavalink" + "github.com/disgoorg/json" + "github.com/disgoorg/lavasrc-plugin" +) + +const trackInfoVersioned int32 = 1 + +func decodeLavaSrcExtendedFields(track *lavalink.Track, r io.Reader) error { + var info lavasrc.TrackInfo + defer func() { + raw, _ := json.Marshal(info) + track.PluginInfo = raw + }() + + albumName, err := readNullableString(r) + if err != nil { + return fmt.Errorf("failed to read track album name: %w", err) + } + if albumName != nil { + info.AlbumName = *albumName + } + + albumURL, err := readNullableString(r) + if err != nil { + return fmt.Errorf("failed to read track album url: %w", err) + } + if albumURL != nil { + info.ArtistURL = *albumURL + } + + artistURL, err := readNullableString(r) + if err != nil { + return fmt.Errorf("failed to read track artist url: %w", err) + } + if artistURL != nil { + info.ArtistURL = *artistURL + } + + artistArtworkURL, err := readNullableString(r) + if err != nil { + return fmt.Errorf("failed to read track artist artwork url: %w", err) + } + if artistArtworkURL != nil { + info.ArtistArtworkURL = *artistArtworkURL + } + + previewURL, err := readNullableString(r) + if err != nil { + return fmt.Errorf("failed to read track preview url: %w", err) + } + if previewURL != nil { + info.PreviewURL = *previewURL + } + + info.IsPreview, err = readBool(r) + if err != nil { + return fmt.Errorf("failed to read track is preview: %w", err) + } + + return nil +} + +type probeInfo struct { + ProbeInfo string `json:"probeInfo"` +} + +func decodeProbeInfo(track *lavalink.Track, r io.Reader) error { + info, err := readString(r) + if err != nil { + return fmt.Errorf("failed to read track probe info: %w", err) + } + + raw, _ := json.Marshal(probeInfo{ + ProbeInfo: info, + }) + track.PluginInfo = raw + return nil +} + +var customTrackDecoders = map[string]func(track *lavalink.Track, r io.Reader) error{ + "deezer": func(track *lavalink.Track, r io.Reader) error { + return decodeLavaSrcExtendedFields(track, r) + }, + "spotify": func(track *lavalink.Track, r io.Reader) error { + return decodeLavaSrcExtendedFields(track, r) + }, + "applemusic": func(track *lavalink.Track, r io.Reader) error { + return decodeLavaSrcExtendedFields(track, r) + }, + "http": func(track *lavalink.Track, r io.Reader) error { + return decodeProbeInfo(track, r) + }, + "local": func(track *lavalink.Track, r io.Reader) error { + return decodeProbeInfo(track, r) + }, +} + +func DecodeString(encoded string) (*lavalink.Track, int, error) { + data, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, 0, fmt.Errorf("invalid base64: %w", err) + } + + var ( + r = bytes.NewReader(data) + track = &lavalink.Track{ + Encoded: encoded, + PluginInfo: []byte("{}"), + UserData: []byte("{}"), + } + value int32 + ) + + if value, err = readInt32(r); err != nil { + return nil, 0, fmt.Errorf("failed to read track flags: %w", err) + } + + flags := int32(int64(value) & 0xC0000000 >> 30) + messageSize := int64(value) & 0x3FFFFFFF + if messageSize == 0 { + return nil, 0, errors.New("message size: 0") + } + + var version int + if flags&trackInfoVersioned == 0 { + version = 1 + } else { + var v uint8 + if v, err = readUInt8(r); err != nil { + return nil, 0, fmt.Errorf("failed to read track version: %w", err) + } + version = int(v) + } + + if track.Info.Title, err = readString(r); err != nil { + return track, version, fmt.Errorf("failed to read track title: %w", err) + } + if track.Info.Author, err = readString(r); err != nil { + return track, version, fmt.Errorf("failed to read track author: %w", err) + } + + var length int64 + if length, err = readInt64(r); err != nil { + return track, version, fmt.Errorf("failed to read track length: %w", err) + } + track.Info.Length = lavalink.Duration(length) + + if track.Info.Identifier, err = readString(r); err != nil { + return track, version, fmt.Errorf("failed to read track identifier: %w", err) + } + if track.Info.IsStream, err = readBool(r); err != nil { + return track, version, fmt.Errorf("failed to read track is stream: %w", err) + } + if version >= 2 { + if track.Info.URI, err = readNullableString(r); err != nil { + return track, version, fmt.Errorf("failed to read track uri: %w", err) + } + } + if version >= 3 { + if track.Info.ArtworkURL, err = readNullableString(r); err != nil { + return track, version, fmt.Errorf("failed to read track artwork url: %w", err) + } + if track.Info.ISRC, err = readNullableString(r); err != nil { + return track, version, fmt.Errorf("failed to read track isrc: %w", err) + } + } + if track.Info.SourceName, err = readString(r); err != nil { + return track, version, fmt.Errorf("failed to read track source name: %w", err) + } + + if decoder, ok := customTrackDecoders[track.Info.SourceName]; ok { + if err = decoder(track, r); err != nil { + return track, version, fmt.Errorf("failed to read track source fields: %w", err) + } + } + + var position int64 + if position, err = readInt64(r); err != nil { + return track, version, fmt.Errorf("failed to read track position: %w", err) + } + track.Info.Position = lavalink.Duration(position) + + return track, version, nil +} + +func readInt64(r io.Reader) (i int64, err error) { + return i, binary.Read(r, binary.BigEndian, &i) +} + +func readInt32(r io.Reader) (i int32, err error) { + return i, binary.Read(r, binary.BigEndian, &i) +} + +func readInt16(r io.Reader) (i int16, err error) { + return i, binary.Read(r, binary.BigEndian, &i) +} + +func readUInt8(r io.Reader) (i uint8, err error) { + return i, binary.Read(r, binary.BigEndian, &i) +} + +func readBool(r io.Reader) (b bool, err error) { + return b, binary.Read(r, binary.BigEndian, &b) +} + +func readString(r io.Reader) (string, error) { + size, err := readInt16(r) + if err != nil { + return "", err + } + b := make([]byte, size) + if err = binary.Read(r, binary.BigEndian, &b); err != nil { + return "", err + } + return string(b), nil +} + +func readNullableString(r io.Reader) (*string, error) { + b, err := readBool(r) + if err != nil || !b { + return nil, err + } + + s, err := readString(r) + if err != nil { + return nil, err + } + return &s, nil +}