Skip to content

Commit

Permalink
Improve id validation
Browse files Browse the repository at this point in the history
  • Loading branch information
AgustinSRG committed Jan 24, 2025
1 parent a8911b9 commit 64f2e55
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 29 deletions.
41 changes: 29 additions & 12 deletions rtmp_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ type RTMPServer struct {
sessions map[uint64]*RTMPSession // Active sessions
channels map[string]*RTMPChannel // Active streaming channels

ip_limit uint32 // Max number of active sessions
ip_count map[string]uint32 // Mapping IP -> Number of active sessions
streamIdMaxLength int // Max length for stream IDs, rooms and keys

ipLimit uint32 // Max number of active sessions
ipCount map[string]uint32 // Mapping IP -> Number of active sessions

ip_mutex *sync.Mutex // Mutex for the IP count mapping

Expand All @@ -57,6 +59,10 @@ type RTMPServer struct {
closed bool // True if the server is closed
}

const STREAM_ID_DEFAULT_MAX_LENGTH = 128
const GOP_CACHE_DEFAULT_LIMIT = 256 * 1024 * 1024
const IP_DEFAULT_LIMIT = 4

// Creates a RTMP server using the configuration from the environment variables
func CreateRTMPServer() *RTMPServer {
server := RTMPServer{
Expand All @@ -70,17 +76,18 @@ func CreateRTMPServer() *RTMPServer {
channels: make(map[string]*RTMPChannel),
next_session_id: 1,
closed: false,
ip_count: make(map[string]uint32),
ip_limit: 4,
gopCacheLimit: 256 * 1024 * 1024,
ipCount: make(map[string]uint32),
ipLimit: IP_DEFAULT_LIMIT,
gopCacheLimit: GOP_CACHE_DEFAULT_LIMIT,
websocketControlConnection: nil,
streamIdMaxLength: STREAM_ID_DEFAULT_MAX_LENGTH,
}

custom_ip_limit := os.Getenv("MAX_IP_CONCURRENT_CONNECTIONS")
if custom_ip_limit != "" {
cil, e := strconv.Atoi(custom_ip_limit)
if e != nil {
server.ip_limit = uint32(cil)
server.ipLimit = uint32(cil)
}
}

Expand Down Expand Up @@ -180,6 +187,16 @@ func CreateRTMPServer() *RTMPServer {
}
}

idCustomMaxLength := os.Getenv("ID_MAX_LENGTH")

if idCustomMaxLength != "" {
var e error
idMaxLen, e := strconv.Atoi(idCustomMaxLength)
if e == nil && idMaxLen > 0 {
server.streamIdMaxLength = idMaxLen
}
}

if os.Getenv("CONTROL_USE") == "YES" {
server.websocketControlConnection = &ControlServerConnection{}
}
Expand All @@ -194,13 +211,13 @@ func (server *RTMPServer) AddIP(ip string) bool {
server.ip_mutex.Lock()
defer server.ip_mutex.Unlock()

c := server.ip_count[ip]
c := server.ipCount[ip]

if c >= server.ip_limit {
if c >= server.ipLimit {
return false
}

server.ip_count[ip] = c + 1
server.ipCount[ip] = c + 1

return true
}
Expand Down Expand Up @@ -246,12 +263,12 @@ func (server *RTMPServer) RemoveIP(ip string) {
server.ip_mutex.Lock()
defer server.ip_mutex.Unlock()

c := server.ip_count[ip]
c := server.ipCount[ip]

if c <= 1 {
delete(server.ip_count, ip)
delete(server.ipCount, ip)
} else {
server.ip_count[ip] = c - 1
server.ipCount[ip] = c - 1
}
}

Expand Down
4 changes: 2 additions & 2 deletions rtmp_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ func (s *RTMPSession) HandleConnect(cmd *RTMPCommand) bool {
s.channel = cmd.GetArg("cmdObj").GetProperty("app").GetString()

// Validate channel
if !validateStreamIDString(s.channel) {
if !validateStreamIDString(s.channel, s.server.streamIdMaxLength) {
LogRequest(s.id, s.ip, "INVALID CHANNEL '"+s.channel+"'")
return false
}
Expand Down Expand Up @@ -587,7 +587,7 @@ func (s *RTMPSession) HandlePublish(cmd *RTMPCommand, packet *RTMPPacket) bool {
}

// Validate key
if !validateStreamIDString(s.key) {
if !validateStreamIDString(s.key, s.server.streamIdMaxLength) {
s.SendStatusMessage(s.publishStreamId, "error", "NetStream.Publish.BadName", "Invalid stream key provided")
return false
}
Expand Down
17 changes: 2 additions & 15 deletions rtmp_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
package main

import (
"os"
"regexp"
"strconv"
"strings"
)

Expand Down Expand Up @@ -272,19 +270,8 @@ func decodeRTMPData(data []byte) RTMPData {
// Validates stream ID
// str - Stream ID
// Returns true only if valid
func validateStreamIDString(str string) bool {
var ID_MAX_LENGTH = 128
idCustomMaxLength := os.Getenv("ID_MAX_LENGTH")

if idCustomMaxLength != "" {
var e error
ID_MAX_LENGTH, e = strconv.Atoi(idCustomMaxLength)
if e != nil {
ID_MAX_LENGTH = 128
}
}

if len(str) > ID_MAX_LENGTH {
func validateStreamIDString(str string, maxLength int) bool {
if len(str) > maxLength {
return false
}

Expand Down

0 comments on commit 64f2e55

Please sign in to comment.