diff --git a/apis/handlers/restart_linux.go b/apis/handlers/restart_linux.go index 27d04a1b..c8d48866 100644 --- a/apis/handlers/restart_linux.go +++ b/apis/handlers/restart_linux.go @@ -11,12 +11,14 @@ import ( "net" "os" "os/exec" + "regexp" "strconv" "strings" "syscall" "time" "net/http" + "net/url" "github.com/rs/zerolog/log" @@ -71,6 +73,46 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc { statusCode = http.StatusInternalServerError return } + + match, _ := regexp.MatchString(`^v\d+\.\d+\.\d+$`, t.Version) + if !match { + mesg = fmt.Sprintf("version naming convention is wrong it will like vx.y.z") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } + machineHostname, err := os.Hostname() + if err != nil { + mesg = fmt.Sprintf("failed to get os hostname") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } + if machineHostname != t.HostName { + mesg = fmt.Sprintf("this api request is not for provided host") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } + URL, err := url.Parse(t.ArtifactURL) + if err != nil { + mesg = fmt.Sprintf("url format is wrong") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } + if URL.Scheme != models.HttpScheme && URL.Scheme != models.FileScheme && URL.Scheme != models.HttpsScheme { + mesg = fmt.Sprintf("currently only http,https,file is supported") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } + if strings.Contains(t.ArtifactURL, "..") { + mesg = fmt.Sprintf("bad string") + log.Error().Msg(mesg) + statusCode = http.StatusInternalServerError + return + } defer func() { models.IsReadOnly = false }() @@ -88,14 +130,14 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc { err, oldCfgPath := restart.ReadSymlink(bpfcfg.HostConfig.BasePath + "/latest/l3afd.cfg") if err != nil { - mesg = fmt.Sprintf("failed read simlink: %v", err) + mesg = fmt.Sprintf("failed read symlink: %v", err) log.Error().Msg(mesg) statusCode = http.StatusInternalServerError return } err, oldBinPath := restart.ReadSymlink(bpfcfg.HostConfig.BasePath + "/latest/l3afd") if err != nil { - mesg = fmt.Sprintf("failed to read simlink: %v", err) + mesg = fmt.Sprintf("failed to read symlink: %v", err) log.Error().Msg(mesg) statusCode = http.StatusInternalServerError return @@ -179,22 +221,22 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc { err = cmd.Start() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to start new instance %v", err) + mesg = mesg + fmt.Sprintf("unable to start new instance %v", err) // write a function a to do cleanup of other process if necessary err = cmd.Process.Kill() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err) + mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err) } err = bpfcfg.StartAllUserProgramsAndProbes() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err) + mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err) } err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename) if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err) + mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err) } err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig) if err != nil { @@ -242,17 +284,17 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc { err = cmd.Process.Kill() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err) + mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err) } err = bpfcfg.StartAllUserProgramsAndProbes() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err) + mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err) } err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename) if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err) + mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err) } err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig) if err != nil { @@ -273,17 +315,17 @@ func HandleRestart(bpfcfg *bpfprogs.NFConfigs) http.HandlerFunc { err = cmd.Process.Kill() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to kill the new instance %v", err) + mesg = mesg + fmt.Sprintf("unable to kill the new instance %v", err) } err = bpfcfg.StartAllUserProgramsAndProbes() if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to start all userprograms and probes: %v", err) + mesg = mesg + fmt.Sprintf("unable to start all userprograms and probes: %v", err) } err = pidfile.CreatePID(bpfcfg.HostConfig.PIDFilename) if err != nil { log.Error().Msgf("%v", err) - mesg = mesg + fmt.Sprintf("not able to create pid file: %v", err) + mesg = mesg + fmt.Sprintf("unable to create pid file: %v", err) } err = restart.RollBackSymlink(oldCfgPath, oldBinPath, oldVersion, t.Version, bpfcfg.HostConfig) if err != nil { diff --git a/bpfprogs/bpf.go b/bpfprogs/bpf.go index 4f0ab31d..7bba8ba1 100644 --- a/bpfprogs/bpf.go +++ b/bpfprogs/bpf.go @@ -630,7 +630,7 @@ func (b *BPF) GetArtifacts(conf *config.Config) error { tempDir := filepath.Join(conf.BPFDir, b.Program.Name, b.Program.Version) err = ExtractArtifact(b.Program.Artifact, buf, tempDir) if err != nil { - return fmt.Errorf("not able to extract artifact %w", err) + return fmt.Errorf("unable to extract artifact %w", err) } newDir := strings.Split(b.Program.Artifact, ".") b.FilePath = filepath.Join(tempDir, newDir[0]) @@ -871,7 +871,7 @@ func (b *BPF) MonitorMaps(ifaceName string, intervals int) error { _, ok := b.MetricsBpfMaps[mapKey] if !ok { if err := b.AddMetricsBPFMap(element.Name, element.Aggregator, element.Key, intervals); err != nil { - return fmt.Errorf("not able to fetch map %s key %d aggregator %s : %w", element.Name, element.Key, element.Aggregator, err) + return fmt.Errorf("unable to fetch map %s key %d aggregator %s : %w", element.Name, element.Key, element.Aggregator, err) } } bpfMap := b.MetricsBpfMaps[mapKey] diff --git a/bpfprogs/bpfdebug.go b/bpfprogs/bpfdebug.go index eeb5d3d9..1384ccfb 100644 --- a/bpfprogs/bpfdebug.go +++ b/bpfprogs/bpfdebug.go @@ -28,7 +28,7 @@ func SetupBPFDebug(ebpfChainDebugAddr string, BPFConfigs *NFConfigs) { } listener, err := net.ListenTCP("tcp", tcpAddr) if err != nil { - log.Fatal().Err(err).Msgf("Not able to create net Listen") + log.Fatal().Err(err).Msgf("unable to create net Listen") } models.AllNetListeners.Store("debug_http", listener) } diff --git a/bpfprogs/nfconfig.go b/bpfprogs/nfconfig.go index aaba97c1..199e79db 100644 --- a/bpfprogs/nfconfig.go +++ b/bpfprogs/nfconfig.go @@ -1623,7 +1623,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error { prg := b.ProgMapCollection b.ProgMapCollection = nil if err := b.LoadBPFProgram(iface); err != nil { - return fmt.Errorf("not able to load probes %w", err) + return fmt.Errorf("unable to load probes %w", err) } b.Program.EntryFunctionName = ef if b.ProgMapCollection != nil { @@ -1667,7 +1667,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error { prg := b.ProgMapCollection b.ProgMapCollection = nil if err := b.LoadBPFProgram(iface); err != nil { - return fmt.Errorf("not able to load probes %w", err) + return fmt.Errorf("unable to load probes %w", err) } b.Program.EntryFunctionName = ef if b.ProgMapCollection != nil { @@ -1711,7 +1711,7 @@ func (c *NFConfigs) StartAllUserProgramsAndProbes() error { prg := b.ProgMapCollection b.ProgMapCollection = nil if err := b.LoadBPFProgram(iface); err != nil { - return fmt.Errorf("not able to load probes %w", err) + return fmt.Errorf("unable to load probes %w", err) } b.Program.EntryFunctionName = ef if b.ProgMapCollection != nil { diff --git a/main.go b/main.go index 23b37f96..cf63e56c 100644 --- a/main.go +++ b/main.go @@ -263,14 +263,14 @@ func setupForRestart(ctx context.Context, conf *config.Config) error { models.IsReadOnly = true // Now you need to write client side code conn, err := net.Dial("unix", models.HostSock) - HandleErr(err, "not able to dial unix domain socket") + HandleErr(err, "unable to dial unix domain socket") defer conn.Close() decoder := gob.NewDecoder(conn) var t models.L3AFALLHOSTDATA err = decoder.Decode(&t) - HandleErr(err, "not able to decode") + HandleErr(err, "unable to decode") machineHostname, err := os.Hostname() - HandleErr(err, "not able to fetch the hostname") + HandleErr(err, "unable to fetch the hostname") l, err := restart.Getnetlistener(3, "stat_server") HandleErr(err, "getting stat_server listener failed") diff --git a/restart/restart.go b/restart/restart.go index 9d86f1b2..2044b84d 100644 --- a/restart/restart.go +++ b/restart/restart.go @@ -277,7 +277,7 @@ func Getnetlistener(fd int, fname string) (*net.TCPListener, error) { } lf, e := l.(*net.TCPListener) if !e { - return nil, fmt.Errorf("not able to covert to tcp listner") + return nil, fmt.Errorf("unable to covert to tcp listner") } file.Close() return lf, nil @@ -306,6 +306,9 @@ func GetNewVersion(urlpath string, oldVersion, newVersion string, conf *config.C return nil } newVersionPath := conf.BasePath + "/" + newVersion + if !strings.HasPrefix(conf.BasePath, filepath.Clean(newVersionPath)+string(os.PathSeparator)) { + return fmt.Errorf("malicious input given to the restart api") + } err := os.RemoveAll(newVersionPath) if err != nil { return fmt.Errorf("Error while deleting directory: %w", err) @@ -319,38 +322,44 @@ func GetNewVersion(urlpath string, oldVersion, newVersion string, conf *config.C if err != nil { return fmt.Errorf("unknown url format : %w", err) } + if URL.Scheme == models.HttpScheme || URL.Scheme == models.HttpsScheme { + URL, err = url.ParseRequestURI(urlpath) + if err != nil { + return fmt.Errorf("unknown http/https url format : %w", err) + } + } buf := &bytes.Buffer{} err = bpfprogs.DownloadArtifact(URL, conf.HttpClientTimeout, buf) if err != nil { - return fmt.Errorf("not able to download artifacts %w", err) + return fmt.Errorf("unable to download artifacts %w", err) } sp := strings.Split(urlpath, "/") artifactName := sp[len(sp)-1] err = bpfprogs.ExtractArtifact(artifactName, buf, newVersionPath) dir := strings.Split(artifactName, ".")[0] if err != nil { - return fmt.Errorf("not able to extract artifacts %w", err) + return fmt.Errorf("unable to extract artifacts %w", err) } // you need to store the old path for rollback purposes - // we will remove simlink + // we will remove symlink err = RemoveSymlink(conf.BasePath + "/latest/l3afd") if err != nil { - return fmt.Errorf("not able to remove simlink %w", err) + return fmt.Errorf("unable to remove symlink %w", err) } err = RemoveSymlink(conf.BasePath + "/latest/l3afd.cfg") if err != nil { - return fmt.Errorf("not able to remove simlink %w", err) + return fmt.Errorf("unable to remove symlink %w", err) } - // add new simlink + // add new symlink err = AddSymlink(newVersionPath+"/"+dir+"/l3afd", conf.BasePath+"/latest/l3afd") if err != nil { - return fmt.Errorf("not able to add simlink %w", err) + return fmt.Errorf("unable to add symlink %w", err) } err = AddSymlink(newVersionPath+"/"+dir+"/l3afd.cfg", conf.BasePath+"/latest/l3afd.cfg") if err != nil { - return fmt.Errorf("not able to add simlink %w", err) + return fmt.Errorf("unable to add symlink %w", err) } // now we are good return nil @@ -361,28 +370,31 @@ func RollBackSymlink(oldCfgPath, oldBinPath string, oldVersion, newVersion strin return nil } // you need to store the old path for rollback purposes - // we will remove simlink + // we will remove symlink err := RemoveSymlink(conf.BasePath + "/latest/l3afd") if err != nil { - return fmt.Errorf("not able to remove simlink %w", err) + return fmt.Errorf("unable to remove symlink %w", err) } err = RemoveSymlink(conf.BasePath + "/latest/l3afd.cfg") if err != nil { - return fmt.Errorf("not able to remove simlink %w", err) + return fmt.Errorf("unable to remove symlink %w", err) } - // add new simlink + // add new symlink err = AddSymlink(oldBinPath, conf.BasePath+"/latest/l3afd") if err != nil { - return fmt.Errorf("not able to add simlink %w", err) + return fmt.Errorf("unable to add symlink %w", err) } err = AddSymlink(oldCfgPath, conf.BasePath+"/latest/l3afd.cfg") if err != nil { - return fmt.Errorf("not able to add simlink %w", err) + return fmt.Errorf("unable to add symlink %w", err) } newVersionPath := conf.BasePath + "/" + newVersion + if !strings.HasPrefix(conf.BasePath, filepath.Clean(newVersionPath)+string(os.PathSeparator)) { + return fmt.Errorf("malicious input given to the restart api") + } err = os.RemoveAll(newVersionPath) if err != nil { return fmt.Errorf("Error while deleting directory: %w", err) diff --git a/stats/metrics.go b/stats/metrics.go index 99ac560b..d49ecec4 100644 --- a/stats/metrics.go +++ b/stats/metrics.go @@ -130,7 +130,7 @@ func SetupMetrics(hostname, daemonName, metricsAddr string) { } listener, err := net.ListenTCP("tcp", tcpAddr) if err != nil { - log.Fatal().Err(err).Msgf("Not able to create net Listen") + log.Fatal().Err(err).Msgf("unable to create net Listen") } models.AllNetListeners.Store("stat_http", listener) }