From b6a0c308527980bef335a5b8621ba4474329a9a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marta=20G=C3=B3mez=20Mac=C3=ADas?= Date: Fri, 10 Mar 2023 17:03:25 +0100 Subject: [PATCH] feat(scan): Add a flag to wait for analysis completion (#71) --- cmd/scan.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 110 insertions(+), 14 deletions(-) diff --git a/cmd/scan.go b/cmd/scan.go index d7b187a..805e737 100644 --- a/cmd/scan.go +++ b/cmd/scan.go @@ -14,8 +14,11 @@ package cmd import ( + "context" "fmt" "os" + "strings" + "time" "github.com/VirusTotal/vt-cli/utils" vt "github.com/VirusTotal/vt-go" @@ -24,9 +27,56 @@ import ( "github.com/spf13/viper" ) +const ( + // PollFrequency defines the interval in which requests are sent to the + // VT API to check if the analysis is completed. + PollFrequency = 10 * time.Second + // TimeoutLimit defines the maximum amount of minutes to wait for an + // analysis' results. + TimeoutLimit = 10 * time.Minute +) + +// waitForAnalysisResults calls every PollFrequency seconds to the VT API and +// checks whether an analysis is completed or not. When the analysis is completed +// it is returned. +func waitForAnalysisResults(cli *utils.APIClient, analysisId string, ds *utils.DoerState) (*vt.Object, error) { + ds.Progress = "Waiting for analysis completion..." + ticker := time.NewTicker(PollFrequency) + defer ticker.Stop() + ctx, cancel := context.WithTimeout(context.Background(), TimeoutLimit) + defer cancel() + i := 1 + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + ds.Progress = fmt.Sprintf("Waiting for analysis completion...%s", strings.Repeat(".", i)) + i++ + if obj, err := cli.GetObject(vt.URL(fmt.Sprintf("analyses/%s", analysisId))); err != nil { + // If the API returned an error 503 (transient error) retry; otherwise just return + // the error to the user. + if e, ok := err.(*vt.Error); !ok || e.Code != "TransientError" { + ds.Progress = "" + return nil, fmt.Errorf("error retrieving analysis result: %v", err) + } + } else if status, _ := obj.Get("status"); status == "completed" { + ds.Progress = "" + // Request the full object report and return it instead of just + // the analysis results. + return cli.GetObject(vt.URL(fmt.Sprintf("analyses/%s/item", analysisId))) + } + } + } +} + type fileScanner struct { - scanner *vt.FileScanner - showInVT bool + scanner *vt.FileScanner + cli *utils.APIClient + printer *utils.Printer + showInVT bool + waitForCompletion bool } func (s *fileScanner) Do(path interface{}, ds *utils.DoerState) string { @@ -46,22 +96,31 @@ func (s *fileScanner) Do(path interface{}, ds *utils.DoerState) string { f, err := os.Open(path.(string)) if err != nil { - return fmt.Sprintf("%s", err) + return err.Error() } defer f.Close() analysis, err := s.scanner.ScanFile(f, progressCh) if err != nil { - return fmt.Sprintf("%s", err) + return err.Error() } if s.showInVT { - // Return the analysis URL in VT so users can visit it + // Return the analysis URL in VT so users can visit it. return fmt.Sprintf( "%s https://www.virustotal.com/gui/file-analysis/%s", path.(string), analysis.ID()) } + if s.waitForCompletion { + analysisResult, err := waitForAnalysisResults(s.cli, analysis.ID(), ds) + if err != nil { + return err.Error() + } + s.printer.PrintObject(analysisResult) + return "" + } + return fmt.Sprintf("%s %s", path.(string), analysis.ID()) } @@ -70,7 +129,8 @@ var scanFileCmdHelp = `Scan one or more files. This command receives one or more file paths and uploads them to VirusTotal for scanning. It returns the file paths followed by their corresponding analysis IDs. You can use the "vt analysis" command for retrieving information about the -analyses. +analyses or you can use the --wait flag to see the results when the +analysis is completed. If the command receives a single hypen (-) the file paths are read from the standard input, one per line. @@ -105,9 +165,16 @@ func NewScanFileCmd() *cobra.Command { if err != nil { return err } + p, err := NewPrinter(cmd) + if err != nil { + return err + } s := &fileScanner{ - scanner: client.NewFileScanner(), - showInVT: viper.GetBool("open")} + scanner: client.NewFileScanner(), + showInVT: viper.GetBool("open"), + waitForCompletion: viper.GetBool("wait"), + printer: p, + cli: client} c.DoWithStringsFromReader(s, argReader) return nil }, @@ -115,20 +182,25 @@ func NewScanFileCmd() *cobra.Command { addThreadsFlag(cmd.Flags()) addOpenInVTFlag(cmd.Flags()) + addWaitForCompletionFlag(cmd.Flags()) + addIncludeExcludeFlags(cmd.Flags()) cmd.MarkZshCompPositionalArgumentFile(1) return cmd } type urlScanner struct { - scanner *vt.URLScanner - showInVT bool + scanner *vt.URLScanner + cli *utils.APIClient + printer *utils.Printer + showInVT bool + waitForCompletion bool } func (s *urlScanner) Do(url interface{}, ds *utils.DoerState) string { analysis, err := s.scanner.Scan(url.(string)) if err != nil { - return fmt.Sprintf("%s", err) + return err.Error() } if s.showInVT { @@ -136,6 +208,15 @@ func (s *urlScanner) Do(url interface{}, ds *utils.DoerState) string { "%s https://www.virustotal.com/gui/url-analysis/%s", url, analysis.ID()) } + if s.waitForCompletion { + analysisResult, err := waitForAnalysisResults(s.cli, analysis.ID(), ds) + if err != nil { + return err.Error() + } + s.printer.PrintObject(analysisResult) + return "" + } + return fmt.Sprintf("%s %s", url, analysis.ID()) } @@ -143,7 +224,8 @@ var scanURLCmdHelp = `Scan one or more URLs. This command receives one or more URLs and scan them. It returns the URLs followed by their corresponding analysis IDs. You can use the "vt analysis" command for -retrieving information about the analyses. +retrieving information about the analyses or you can use the --wait +flag to see the results when the analysis is completed. If the command receives a single hypen (-) the URLs are read from the standard input, one per line.` @@ -174,9 +256,16 @@ func NewScanURLCmd() *cobra.Command { if err != nil { return err } + p, err := NewPrinter(cmd) + if err != nil { + return err + } s := &urlScanner{ - scanner: client.NewURLScanner(), - showInVT: viper.GetBool("open")} + scanner: client.NewURLScanner(), + showInVT: viper.GetBool("open"), + waitForCompletion: viper.GetBool("wait"), + printer: p, + cli: client} c.DoWithStringsFromReader(s, argReader) return nil }, @@ -184,6 +273,7 @@ func NewScanURLCmd() *cobra.Command { addThreadsFlag(cmd.Flags()) addOpenInVTFlag(cmd.Flags()) + addWaitForCompletionFlag(cmd.Flags()) return cmd } @@ -212,3 +302,9 @@ func addOpenInVTFlag(flags *pflag.FlagSet) { "open", "o", false, "Return an URL to see the analysis report at the VirusTotal web GUI") } + +func addWaitForCompletionFlag(flags *pflag.FlagSet) { + flags.BoolP( + "wait", "w", false, + "Wait until the analysis is completed and show the analysis results") +}