Skip to content

Commit

Permalink
fixing security issue
Browse files Browse the repository at this point in the history
Signed-off-by: Atul-source <atulprajapati6031@gmail.com>
  • Loading branch information
Atul-source committed Sep 3, 2024
1 parent 4645f1e commit 656c654
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 45 deletions.
66 changes: 54 additions & 12 deletions apis/handlers/restart_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"syscall"
"time"

"net/http"
"net/url"

"github.com/rs/zerolog/log"

Expand Down Expand Up @@ -71,6 +73,46 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc {
statusCode = http.StatusInternalServerError
return
}

match, _ := regexp.MatchString(`^v\d+\.\d+\.\d+$`, t.Version)
if !match {
mesg = fmt.Sprintf("version naming convention is wrong it will like vx.y.z")

Check failure on line 79 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
machineHostname, err := os.Hostname()
if err != nil {
mesg = fmt.Sprintf("failed to get os hostname")

Check failure on line 86 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
if machineHostname != t.HostName {
mesg = fmt.Sprintf("this api request is not for provided host")

Check failure on line 92 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
URL, err := url.Parse(t.ArtifactURL)
if err != nil {
mesg = fmt.Sprintf("url format is wrong")

Check failure on line 99 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
if URL.Scheme != models.HttpScheme && URL.Scheme != models.FileScheme && URL.Scheme != models.HttpsScheme {
mesg = fmt.Sprintf("currently only http,https,file is supported")

Check failure on line 105 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
if strings.Contains(t.ArtifactURL, "..") {
mesg = fmt.Sprintf("bad string")

Check failure on line 111 in apis/handlers/restart_linux.go

View workflow job for this annotation

GitHub Actions / build

unnecessary use of fmt.Sprintf (S1039)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
defer func() {
models.IsReadOnly = false
}()
Expand All @@ -88,14 +130,14 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc {

err, oldCfgPath := restart.ReadSymlink(bpfcfg.HostConfig.BasePath + "/latest/l3afd.cfg")
if err != nil {
mesg = fmt.Sprintf("failed read simlink: %v", err)
mesg = fmt.Sprintf("failed read symlink: %v", err)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
}
err, oldBinPath := restart.ReadSymlink(bpfcfg.HostConfig.BasePath + "/latest/l3afd")
if err != nil {
mesg = fmt.Sprintf("failed to read simlink: %v", err)
mesg = fmt.Sprintf("failed to read symlink: %v", err)
log.Error().Msg(mesg)
statusCode = http.StatusInternalServerError
return
Expand Down Expand Up @@ -179,22 +221,22 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc {
err = cmd.Start()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to start new instance %v", err)
mesg = mesg + fmt.Sprintf("unable to start new instance %v", err)
// write a function a to do cleanup of other process if necessary
err = cmd.Process.Kill()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err)
mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err)
}
err = bpfcfg.StartAllUserProgramsAndProbes()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err)
mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err)
}
err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename)
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err)
mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err)
}
err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig)
if err != nil {
Expand Down Expand Up @@ -242,17 +284,17 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc {
err = cmd.Process.Kill()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err)
mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err)
}
err = bpfcfg.StartAllUserProgramsAndProbes()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err)
mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err)
}
err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename)
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err)
mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err)
}
err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig)
if err != nil {
Expand All @@ -273,17 +315,17 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc {
err = cmd.Process.Kill()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err)
mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err)
}
err = bpfcfg.StartAllUserProgramsAndProbes()
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err)
mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err)
}
err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename)
if err != nil {
log.Error().Msgf("%v", err)
mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err)
mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err)
}
err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig)
if err != nil {
Expand Down
11 changes: 7 additions & 4 deletions bpfprogs/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ func (b *BPF) GetArtifacts(conf *config.Config) error {
tempDir := filepath.Join(conf.BPFDir, b.Program.Name, b.Program.Version)
err = ExtractArtifact(b.Program.Artifact, buf, tempDir)
if err != nil {
return fmt.Errorf("not able to extract artifact %w", err)
return fmt.Errorf("unable to extract artifact %w", err)
}
newDir := strings.Split(b.Program.Artifact, ".")
b.FilePath = filepath.Join(tempDir, newDir[0])
Expand All @@ -646,9 +646,12 @@ func DownloadArtifact(URL *url.URL, timeout time.Duration, buf *bytes.Buffer) er
ResponseHeaderTimeout: timeOut,
}
client := http.Client{Transport: netTransport, Timeout: timeOut}

path := URL.String()
if !strings.HasPrefix(path, models.HttpScheme+"://") {
return fmt.Errorf("malicious path")
}
// Get the data
resp, err := client.Get(URL.String())
resp, err := client.Get(path)

Check failure

Code scanning / CodeQL

Uncontrolled data used in network request Critical

The
URL
of this request depends on a
user-provided value
.
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
Expand Down Expand Up @@ -871,7 +874,7 @@ func (b *BPF) MonitorMaps(ifaceName string, intervals int) error {
_, ok := b.MetricsBpfMaps[mapKey]
if !ok {
if err := b.AddMetricsBPFMap(element.Name, element.Aggregator, element.Key, intervals); err != nil {
return fmt.Errorf("not able to fetch map %s key %d aggregator %s : %w", element.Name, element.Key, element.Aggregator, err)
return fmt.Errorf("unable to fetch map %s key %d aggregator %s : %w", element.Name, element.Key, element.Aggregator, err)
}
}
bpfMap := b.MetricsBpfMaps[mapKey]
Expand Down
2 changes: 1 addition & 1 deletion bpfprogs/bpfdebug.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func SetupBPFDebug(ebpfChainDebugAddr string, BPFConfigs *NFConfigs) {
}
listener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Fatal().Err(err).Msgf("Not able to create net Listen")
log.Fatal().Err(err).Msgf("unable to create net Listen")
}
models.AllNetListeners.Store("debug_http", listener)
}
Expand Down
6 changes: 3 additions & 3 deletions bpfprogs/nfconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error {
prg := b.ProgMapCollection
b.ProgMapCollection = nil
if err := b.LoadBPFProgram(iface); err != nil {
return fmt.Errorf("not able to load probes %w", err)
return fmt.Errorf("unable to load probes %w", err)
}
b.Program.EntryFunctionName = ef
if b.ProgMapCollection != nil {
Expand Down Expand Up @@ -1667,7 +1667,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error {
prg := b.ProgMapCollection
b.ProgMapCollection = nil
if err := b.LoadBPFProgram(iface); err != nil {
return fmt.Errorf("not able to load probes %w", err)
return fmt.Errorf("unable to load probes %w", err)
}
b.Program.EntryFunctionName = ef
if b.ProgMapCollection != nil {
Expand Down Expand Up @@ -1711,7 +1711,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error {
prg := b.ProgMapCollection
b.ProgMapCollection = nil
if err := b.LoadBPFProgram(iface); err != nil {
return fmt.Errorf("not able to load probes %w", err)
return fmt.Errorf("unable to load probes %w", err)
}
b.Program.EntryFunctionName = ef
if b.ProgMapCollection != nil {
Expand Down
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,14 @@ func setupForRestart(ctx context.Context, conf *config.Config) error {
models.IsReadOnly = true
// Now you need to write client side code
conn, err := net.Dial("unix", models.HostSock)
HandleErr(err, "not able to dial unix domain socket")
HandleErr(err, "unable to dial unix domain socket")
defer conn.Close()
decoder := gob.NewDecoder(conn)
var t models.L3AFALLHOSTDATA
err = decoder.Decode(&t)
HandleErr(err, "not able to decode")
HandleErr(err, "unable to decode")
machineHostname, err := os.Hostname()
HandleErr(err, "not able to fetch the hostname")
HandleErr(err, "unable to fetch the hostname")

l, err := restart.Getnetlistener(3, "stat_server")
HandleErr(err, "getting stat_server listener failed")
Expand Down
48 changes: 27 additions & 21 deletions restart/restart.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func Convert(ctx context.Context, t models.L3AFALLHOSTDATA, hostconfig *config.C
for _, r := range v {
f, err := DeserilazeProgram(ctx, r, hostconfig, k)
if err != nil {
log.Err(err).Msg("Deserilization failed for xdpingress")
log.Err(err).Msg("Deserialization failed for xdpingress")
return nil, err
}
l.PushBack(f)
Expand All @@ -223,7 +223,7 @@ func Convert(ctx context.Context, t models.L3AFALLHOSTDATA, hostconfig *config.C
for _, r := range v {
f, err := DeserilazeProgram(ctx, r, hostconfig, k)
if err != nil {
log.Err(err).Msg("Deserilization failed for tcingress")
log.Err(err).Msg("Deserialization failed for tcingress")
return nil, err
}
l.PushBack(f)
Expand All @@ -237,7 +237,7 @@ func Convert(ctx context.Context, t models.L3AFALLHOSTDATA, hostconfig *config.C
for _, r := range v {
f, err := DeserilazeProgram(ctx, r, hostconfig, k)
if err != nil {
log.Err(err).Msg("Deserilization failed for tcegress")
log.Err(err).Msg("Deserialization failed for tcegress")
return nil, err
}
l.PushBack(f)
Expand Down Expand Up @@ -277,7 +277,7 @@ func Getnetlistener(fd int, fname string) (*net.TCPListener, error) {
}
lf, e := l.(*net.TCPListener)
if !e {
return nil, fmt.Errorf("not able to covert to tcp listner")
return nil, fmt.Errorf("unable to covert to tcp listner")
}
file.Close()
return lf, nil
Expand Down Expand Up @@ -306,13 +306,16 @@ func GetNewVersion(urlpath string, oldVersion, newVersion string, conf *config.C
return nil
}
newVersionPath := conf.BasePath + "/" + newVersion
if !strings.HasPrefix(conf.BasePath, filepath.Clean(newVersionPath)+string(os.PathSeparator)) {
return fmt.Errorf("malicious input given to the restart api")
}
err := os.RemoveAll(newVersionPath)
if err != nil {
return fmt.Errorf("Error while deleting directory: %w", err)
return fmt.Errorf("error while deleting directory: %w", err)
}
err = os.MkdirAll(newVersionPath, 0750)
if err != nil {
return fmt.Errorf("Error while creating directory: %w", err)
return fmt.Errorf("error while creating directory: %w", err)
}
// now I need to download artifacts
URL, err := url.Parse(urlpath)
Expand All @@ -322,35 +325,35 @@ func GetNewVersion(urlpath string, oldVersion, newVersion string, conf *config.C
buf := &bytes.Buffer{}
err = bpfprogs.DownloadArtifact(URL, conf.HttpClientTimeout, buf)
if err != nil {
return fmt.Errorf("not able to download artifacts %w", err)
return fmt.Errorf("unable to download artifacts %w", err)
}
sp := strings.Split(urlpath, "/")
artifactName := sp[len(sp)-1]
err = bpfprogs.ExtractArtifact(artifactName, buf, newVersionPath)
dir := strings.Split(artifactName, ".")[0]
if err != nil {
return fmt.Errorf("not able to extract artifacts %w", err)
return fmt.Errorf("unable to extract artifacts %w", err)
}
// you need to store the old path for rollback purposes
// we will remove simlink
// we will remove symlink
err = RemoveSymlink(conf.BasePath + "/latest/l3afd")
if err != nil {
return fmt.Errorf("not able to remove simlink %w", err)
return fmt.Errorf("unable to remove symlink %w", err)
}
err = RemoveSymlink(conf.BasePath + "/latest/l3afd.cfg")
if err != nil {
return fmt.Errorf("not able to remove simlink %w", err)
return fmt.Errorf("unable to remove symlink %w", err)
}
// add new simlink
// add new symlink

err = AddSymlink(newVersionPath+"/"+dir+"/l3afd", conf.BasePath+"/latest/l3afd")
if err != nil {
return fmt.Errorf("not able to add simlink %w", err)
return fmt.Errorf("unable to add symlink %w", err)
}

err = AddSymlink(newVersionPath+"/"+dir+"/l3afd.cfg", conf.BasePath+"/latest/l3afd.cfg")
if err != nil {
return fmt.Errorf("not able to add simlink %w", err)
return fmt.Errorf("unable to add symlink %w", err)
}
// now we are good
return nil
Expand All @@ -361,31 +364,34 @@ func RollBackSymlink(oldCfgPath, oldBinPath string, oldVersion, newVersion strin
return nil
}
// you need to store the old path for rollback purposes
// we will remove simlink
// we will remove symlink
err := RemoveSymlink(conf.BasePath + "/latest/l3afd")
if err != nil {
return fmt.Errorf("not able to remove simlink %w", err)
return fmt.Errorf("unable to remove symlink %w", err)
}
err = RemoveSymlink(conf.BasePath + "/latest/l3afd.cfg")
if err != nil {
return fmt.Errorf("not able to remove simlink %w", err)
return fmt.Errorf("unable to remove symlink %w", err)
}
// add new simlink
// add new symlink

err = AddSymlink(oldBinPath, conf.BasePath+"/latest/l3afd")
if err != nil {
return fmt.Errorf("not able to add simlink %w", err)
return fmt.Errorf("unable to add symlink %w", err)
}

err = AddSymlink(oldCfgPath, conf.BasePath+"/latest/l3afd.cfg")
if err != nil {
return fmt.Errorf("not able to add simlink %w", err)
return fmt.Errorf("unable to add symlink %w", err)
}

newVersionPath := conf.BasePath + "/" + newVersion
if strings.Contains(newVersionPath, "..") {
return fmt.Errorf("malicious path")
}
err = os.RemoveAll(newVersionPath)
if err != nil {
return fmt.Errorf("Error while deleting directory: %w", err)
return fmt.Errorf("error while deleting directory: %w", err)
}
return nil
}
2 changes: 1 addition & 1 deletion stats/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func SetupMetrics(hostname, daemonName, metricsAddr string) {
}
listener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Fatal().Err(err).Msgf("Not able to create net Listen")
log.Fatal().Err(err).Msgf("unable to create net Listen")
}
models.AllNetListeners.Store("stat_http", listener)
}
Expand Down

0 comments on commit 656c654

Please sign in to comment.