From 64f2e55134946b33e144775e94b521bcf3612c4d Mon Sep 17 00:00:00 2001 From: AgustinSRG Date: Fri, 24 Jan 2025 21:47:34 +0100 Subject: [PATCH] Improve id validation --- rtmp_server.go | 41 +++++++++++++++++++++++++++++------------ rtmp_session.go | 4 ++-- rtmp_utils.go | 17 ++--------------- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/rtmp_server.go b/rtmp_server.go index 7ccac66..e68f8de 100644 --- a/rtmp_server.go +++ b/rtmp_server.go @@ -44,8 +44,10 @@ type RTMPServer struct { sessions map[uint64]*RTMPSession // Active sessions channels map[string]*RTMPChannel // Active streaming channels - ip_limit uint32 // Max number of active sessions - ip_count map[string]uint32 // Mapping IP -> Number of active sessions + streamIdMaxLength int // Max length for stream IDs, rooms and keys + + ipLimit uint32 // Max number of active sessions + ipCount map[string]uint32 // Mapping IP -> Number of active sessions ip_mutex *sync.Mutex // Mutex for the IP count mapping @@ -57,6 +59,10 @@ type RTMPServer struct { closed bool // True if the server is closed } +const STREAM_ID_DEFAULT_MAX_LENGTH = 128 +const GOP_CACHE_DEFAULT_LIMIT = 256 * 1024 * 1024 +const IP_DEFAULT_LIMIT = 4 + // Creates a RTMP server using the configuration from the environment variables func CreateRTMPServer() *RTMPServer { server := RTMPServer{ @@ -70,17 +76,18 @@ func CreateRTMPServer() *RTMPServer { channels: make(map[string]*RTMPChannel), next_session_id: 1, closed: false, - ip_count: make(map[string]uint32), - ip_limit: 4, - gopCacheLimit: 256 * 1024 * 1024, + ipCount: make(map[string]uint32), + ipLimit: IP_DEFAULT_LIMIT, + gopCacheLimit: GOP_CACHE_DEFAULT_LIMIT, websocketControlConnection: nil, + streamIdMaxLength: STREAM_ID_DEFAULT_MAX_LENGTH, } custom_ip_limit := os.Getenv("MAX_IP_CONCURRENT_CONNECTIONS") if custom_ip_limit != "" { cil, e := strconv.Atoi(custom_ip_limit) if e != nil { - server.ip_limit = uint32(cil) + server.ipLimit = uint32(cil) } } @@ -180,6 +187,16 @@ func CreateRTMPServer() *RTMPServer { } } + idCustomMaxLength := os.Getenv("ID_MAX_LENGTH") + + if idCustomMaxLength != "" { + var e error + idMaxLen, e := strconv.Atoi(idCustomMaxLength) + if e == nil && idMaxLen > 0 { + server.streamIdMaxLength = idMaxLen + } + } + if os.Getenv("CONTROL_USE") == "YES" { server.websocketControlConnection = &ControlServerConnection{} } @@ -194,13 +211,13 @@ func (server *RTMPServer) AddIP(ip string) bool { server.ip_mutex.Lock() defer server.ip_mutex.Unlock() - c := server.ip_count[ip] + c := server.ipCount[ip] - if c >= server.ip_limit { + if c >= server.ipLimit { return false } - server.ip_count[ip] = c + 1 + server.ipCount[ip] = c + 1 return true } @@ -246,12 +263,12 @@ func (server *RTMPServer) RemoveIP(ip string) { server.ip_mutex.Lock() defer server.ip_mutex.Unlock() - c := server.ip_count[ip] + c := server.ipCount[ip] if c <= 1 { - delete(server.ip_count, ip) + delete(server.ipCount, ip) } else { - server.ip_count[ip] = c - 1 + server.ipCount[ip] = c - 1 } } diff --git a/rtmp_session.go b/rtmp_session.go index 0b64120..a0682ab 100644 --- a/rtmp_session.go +++ b/rtmp_session.go @@ -541,7 +541,7 @@ func (s *RTMPSession) HandleConnect(cmd *RTMPCommand) bool { s.channel = cmd.GetArg("cmdObj").GetProperty("app").GetString() // Validate channel - if !validateStreamIDString(s.channel) { + if !validateStreamIDString(s.channel, s.server.streamIdMaxLength) { LogRequest(s.id, s.ip, "INVALID CHANNEL '"+s.channel+"'") return false } @@ -587,7 +587,7 @@ func (s *RTMPSession) HandlePublish(cmd *RTMPCommand, packet *RTMPPacket) bool { } // Validate key - if !validateStreamIDString(s.key) { + if !validateStreamIDString(s.key, s.server.streamIdMaxLength) { s.SendStatusMessage(s.publishStreamId, "error", "NetStream.Publish.BadName", "Invalid stream key provided") return false } diff --git a/rtmp_utils.go b/rtmp_utils.go index 80f247b..aaeb293 100644 --- a/rtmp_utils.go +++ b/rtmp_utils.go @@ -3,9 +3,7 @@ package main import ( - "os" "regexp" - "strconv" "strings" ) @@ -272,19 +270,8 @@ func decodeRTMPData(data []byte) RTMPData { // Validates stream ID // str - Stream ID // Returns true only if valid -func validateStreamIDString(str string) bool { - var ID_MAX_LENGTH = 128 - idCustomMaxLength := os.Getenv("ID_MAX_LENGTH") - - if idCustomMaxLength != "" { - var e error - ID_MAX_LENGTH, e = strconv.Atoi(idCustomMaxLength) - if e != nil { - ID_MAX_LENGTH = 128 - } - } - - if len(str) > ID_MAX_LENGTH { +func validateStreamIDString(str string, maxLength int) bool { + if len(str) > maxLength { return false }