Skip to content

Commit

Permalink
fix code scanning warning: uncontrolled data used in path expression.
Browse files Browse the repository at this point in the history
  • Loading branch information
suzp1984 committed Oct 30, 2024
1 parent d1ffc2a commit 6eae8fe
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 155 deletions.
20 changes: 8 additions & 12 deletions platform/dvr-local-disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -535,21 +536,16 @@ func (v *RecordWorker) Handle(ctx context.Context, handler *http.ServeMux) error
return nil
}

func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error {
func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error {
// Copy the ts file to temporary cache dir.
tsid := uuid.NewString()
tsfile := path.Join("record", fmt.Sprintf("%v.ts", tsid))

// Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/
// Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS.
if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil {
return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile)
}

// Get the file size.
stats, err := os.Stat(msg.File)
if err != nil {
return errors.Wrapf(err, "stat file %v", msg.File)
if file, err := os.Create(tsfile); err != nil {
return errors.Wrapf(err, "create file %v error", tsfile)
} else {
defer file.Close()
io.Copy(file, bytes.NewReader(data))
}

// Create a local ts file object.
Expand All @@ -558,7 +554,7 @@ func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage)
URL: msg.URL,
SeqNo: msg.SeqNo,
Duration: msg.Duration,
Size: uint64(stats.Size()),
Size: uint64(len(data)),
File: tsfile,
}

Expand Down
22 changes: 9 additions & 13 deletions platform/dvr-tencent-cos.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"strings"
"sync"
Expand Down Expand Up @@ -231,7 +232,7 @@ func (v *DvrWorker) Handle(ctx context.Context, handler *http.ServeMux) error {
return nil
}

func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error {
func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error {
// Ignore for Tencent Cloud credentials not ready.
if !v.ready() {
return nil
Expand All @@ -241,16 +242,11 @@ func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er
tsid := uuid.NewString()
tsfile := path.Join("dvr", fmt.Sprintf("%v.ts", tsid))

// Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/
// Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS.
if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil {
return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile)
}

// Get the file size.
stats, err := os.Stat(msg.File)
if err != nil {
return errors.Wrapf(err, "stat file %v", msg.File)
if file, err := os.Create(tsfile); err != nil {
return errors.Wrapf(err, "create file %v error", tsfile)
} else {
defer file.Close()
io.Copy(file, bytes.NewReader(data))
}

// Create a local ts file object.
Expand All @@ -259,7 +255,7 @@ func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er
URL: msg.URL,
SeqNo: msg.SeqNo,
Duration: msg.Duration,
Size: uint64(stats.Size()),
Size: uint64(len(data)),
File: tsfile,
}

Expand Down
22 changes: 9 additions & 13 deletions platform/dvr-tencent-vod.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -321,7 +322,7 @@ func (v *VodWorker) Handle(ctx context.Context, handler *http.ServeMux) error {
return nil
}

func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error {
func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error {
// Ignore for Tencent Cloud credentials not ready.
if !v.ready() {
return nil
Expand All @@ -331,16 +332,11 @@ func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er
tsid := uuid.NewString()
tsfile := path.Join("vod", fmt.Sprintf("%v.ts", tsid))

// Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/
// Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS.
if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil {
return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile)
}

// Get the file size.
stats, err := os.Stat(msg.File)
if err != nil {
return errors.Wrapf(err, "stat file %v", msg.File)
if file, err := os.Create(tsfile); err != nil {
return errors.Wrapf(err, "create file %v error", tsfile)
} else {
defer file.Close()
io.Copy(file, bytes.NewReader(data))
}

// Create a local ts file object.
Expand All @@ -349,7 +345,7 @@ func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er
URL: msg.URL,
SeqNo: msg.SeqNo,
Duration: msg.Duration,
Size: uint64(stats.Size()),
Size: uint64(len(data)),
File: tsfile,
}

Expand Down
51 changes: 9 additions & 42 deletions platform/ocr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
Expand All @@ -21,6 +22,7 @@ import (
"github.com/ossrs/go-oryx-lib/errors"
ohttp "github.com/ossrs/go-oryx-lib/http"
"github.com/ossrs/go-oryx-lib/logger"

// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
Expand All @@ -39,17 +41,12 @@ type OCRWorker struct {
// The global OCR task, only support one OCR task.
task *OCRTask

// Use async goroutine to process on_hls messages.
msgs chan *SrsOnHlsMessage

// Got message from SRS, a new TS segment file is generated.
tsfiles chan *SrsOnHlsObject
}

func NewOCRWorker() *OCRWorker {
v := &OCRWorker{
// Message on_hls.
msgs: make(chan *SrsOnHlsMessage, 1024),
// TS files.
tsfiles: make(chan *SrsOnHlsObject, 1024),
}
Expand Down Expand Up @@ -547,16 +544,7 @@ func (v *OCRWorker) Enabled() bool {
return v.task.enabled()
}

func (v *OCRWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error {
select {
case <-ctx.Done():
case v.msgs <- msg:
}

return nil
}

func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage) error {
func (v *OCRWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error {
// Ignore if not natch the task config.
if !v.task.match(msg) {
return nil
Expand All @@ -566,16 +554,11 @@ func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage
tsid := fmt.Sprintf("%v-org-%v", msg.SeqNo, uuid.NewString())
tsfile := path.Join("ocr", fmt.Sprintf("%v.ts", tsid))

// Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/
// Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS.
if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil {
return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile)
}

// Get the file size.
stats, err := os.Stat(msg.File)
if err != nil {
return errors.Wrapf(err, "stat file %v", msg.File)
if file, err := os.Create(tsfile); err != nil {
return errors.Wrapf(err, "create file %v error", tsfile)
} else {
defer file.Close()
io.Copy(file, bytes.NewReader(data))
}

// Create a local ts file object.
Expand All @@ -584,7 +567,7 @@ func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage
URL: msg.URL,
SeqNo: msg.SeqNo,
Duration: msg.Duration,
Size: uint64(stats.Size()),
Size: uint64(len(data)),
File: tsfile,
}

Expand Down Expand Up @@ -659,22 +642,6 @@ func (v *OCRWorker) Start(ctx context.Context) error {
}
}()

// Consume all on_hls messages.
wg.Add(1)
go func() {
defer wg.Done()

for ctx.Err() == nil {
select {
case <-ctx.Done():
case msg := <-v.msgs:
if err := v.OnHlsTsMessageImpl(ctx, msg); err != nil {
logger.Wf(ctx, "ocr: handle on hls message %v err %+v", msg.String(), err)
}
}
}
}()

// Consume all ts files by task.
wg.Add(1)
go func() {
Expand Down
78 changes: 46 additions & 32 deletions platform/srs-hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,46 +730,60 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error {
return errors.Errorf("invalid action=%v", msg.Action)
}

if _, err := os.Stat(msg.File); err != nil {
logger.Tf(ctx, "invalid ts file %v", msg.File)
// below stupied code is used to resolve the code scanning error:
// Uncontrolled data used in path expression
allowedDir, err := os.Getwd()
if err != nil {
return errors.Wrapf(err, "can not get current working directory")
}

if err := os.MkdirAll(filepath.Dir(msg.File), 0755); err != nil {
return errors.Wrapf(err, "failed to create ts file directory %v", filepath.Dir(msg.File))
}
safePath := filepath.Join(allowedDir, filepath.Clean(msg.File))
logger.Tf(ctx, "safePath is %v", safePath)
absPath, err := filepath.Abs(safePath)
if err != nil {
return errors.Wrapf(err, "can not get absolute path from %v", safePath)
}

if tsFile, err := os.Create(msg.File); err != nil {
return errors.Wrapf(err, "failed to create ts file %v", msg.File)
} else {
tsUrl := "http://" + os.Getenv("SRS_HOST") + ":" + os.Getenv("SRS_HTTP_STREAM_PORT") + "/" + msg.URL
logger.Tf(ctx, "download ts from %v", tsUrl)
client := http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
}
if !strings.HasPrefix(absPath, allowedDir) {
return errors.Errorf("Access denied, %v is outside allowed directory", absPath)
}

resp, err := client.Get(tsUrl)
if err != nil {
return errors.Wrapf(err, "http error to get url %v", tsUrl)
}
defer resp.Body.Close()
var data []byte
if _, err := os.Stat(absPath); err != nil {
logger.Tf(ctx, "invalid ts file %v", absPath)
tsUrl := "http://" + os.Getenv("SRS_HOST") + ":" + os.Getenv("SRS_HTTP_STREAM_PORT") + "/" + msg.URL
logger.Tf(ctx, "download ts from %v", tsUrl)
client := http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
}

size, err := io.Copy(tsFile, resp.Body)
if err != nil {
return errors.Wrapf(err, "copy http resp to file %v", tsFile)
}
defer tsFile.Close()
logger.Tf(ctx, "Download ts file %s with size %d", tsUrl, size)
res, err := client.Get(tsUrl)
if err != nil {
return errors.Wrapf(err, "http error to get url %v", tsUrl)
}
defer res.Body.Close()

if b, err := io.ReadAll(res.Body); err != nil {
return errors.Wrapf(err, "read http response error")
} else {
data = b
}
logger.Tf(ctx, "Download ts file %s with size %d", tsUrl, len(data))
} else if b, err := os.ReadFile(absPath); err != nil {
return errors.Wrapf(err, "read %v error", absPath)
} else {
data = b
}
logger.Tf(ctx, "on_hls ok, %v", string(b))

// Handle TS file by Record task if enabled.
if recordAll, err := rdb.HGet(ctx, SRS_RECORD_PATTERNS, "all").Result(); err != nil && err != redis.Nil {
return errors.Wrapf(err, "hget %v all", SRS_RECORD_PATTERNS)
} else if recordAll == "true" {
if err = recordWorker.OnHlsTsMessage(ctx, &msg); err != nil {
if err = recordWorker.OnHlsTsMessage(ctx, &msg, data); err != nil {
return errors.Wrapf(err, "feed %v", msg.String())
}
logger.Tf(ctx, "record %v", msg.String())
Expand All @@ -779,7 +793,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error {
if dvrAll, err := rdb.HGet(ctx, SRS_DVR_PATTERNS, "all").Result(); err != nil && err != redis.Nil {
return errors.Wrapf(err, "hget %v all", SRS_DVR_PATTERNS)
} else if dvrAll == "true" {
if err = dvrWorker.OnHlsTsMessage(ctx, &msg); err != nil {
if err = dvrWorker.OnHlsTsMessage(ctx, &msg, data); err != nil {
return errors.Wrapf(err, "feed %v", msg.String())
}
logger.Tf(ctx, "dvr %v", msg.String())
Expand All @@ -789,23 +803,23 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error {
if vodAll, err := rdb.HGet(ctx, SRS_VOD_PATTERNS, "all").Result(); err != nil && err != redis.Nil {
return errors.Wrapf(err, "hget %v all", SRS_VOD_PATTERNS)
} else if vodAll == "true" {
if err = vodWorker.OnHlsTsMessage(ctx, &msg); err != nil {
if err = vodWorker.OnHlsTsMessage(ctx, &msg, data); err != nil {
return errors.Wrapf(err, "feed %v", msg.String())
}
logger.Tf(ctx, "vod %v", msg.String())
}

// Handle TS file by Transcript task if enabled.
if transcriptWorker.Enabled() {
if err = transcriptWorker.OnHlsTsMessage(ctx, &msg); err != nil {
if err = transcriptWorker.OnHlsTsMessage(ctx, &msg, data); err != nil {
return errors.Wrapf(err, "feed %v", msg.String())
}
logger.Tf(ctx, "transcript %v", msg.String())
}

// Handle TS file by OCR task if enabled.
if ocrWorker.Enabled() {
if err = ocrWorker.OnHlsTsMessage(ctx, &msg); err != nil {
if err = ocrWorker.OnHlsTsMessage(ctx, &msg, data); err != nil {
return errors.Wrapf(err, "feed %v", msg.String())
}
logger.Tf(ctx, "ocr %v", msg.String())
Expand Down
Loading

0 comments on commit 6eae8fe

Please sign in to comment.