diff --git a/easyssh.go b/easyssh.go index 6fd908c..15e0740 100644 --- a/easyssh.go +++ b/easyssh.go @@ -6,6 +6,7 @@ package easyssh import ( "bufio" + "context" "errors" "fmt" "io" @@ -357,7 +358,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<- if len(timeout) > 0 { executeTimeout = timeout[0] } - timeoutChan := time.After(executeTimeout) + ctxTimeout, cancel := context.WithTimeout(context.Background(), executeTimeout) + defer cancel() res := make(chan struct{}, 1) var resWg sync.WaitGroup resWg.Add(2) @@ -398,8 +400,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<- case <-res: errChan <- session.Wait() doneChan <- true - case <-timeoutChan: - errChan <- fmt.Errorf("Run Command Timeout") + case <-ctxTimeout.Done(): + errChan <- fmt.Errorf("Run Command Timeout: %v", ctxTimeout.Err()) doneChan <- false } }(stdoutScanner, stderrScanner, stdoutChan, stderrChan, doneChan, errChan) diff --git a/easyssh_test.go b/easyssh_test.go index 18ddb1e..5381d93 100644 --- a/easyssh_test.go +++ b/easyssh_test.go @@ -1,6 +1,7 @@ package easyssh import ( + "context" "os" "os/user" "path" @@ -20,7 +21,6 @@ func getHostPublicKeyFile(keypath string) (ssh.PublicKey, error) { } pubkey, _, _, _, err = ssh.ParseAuthorizedKey(buf) - if err != nil { return nil, err } @@ -169,7 +169,7 @@ func TestRunCommand(t *testing.T) { assert.Equal(t, "", errStr) assert.False(t, isTimeout) assert.Error(t, err) - assert.Equal(t, "Run Command Timeout", err.Error()) + assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error()) // test exit code outStr, errStr, isTimeout, err = ssh.Run("exit 1") @@ -496,3 +496,19 @@ func TestSudoCommand(t *testing.T) { assert.True(t, isTimeout) assert.NoError(t, err) } + +func TestCommandTimeout(t *testing.T) { + ssh := &MakeConfig{ + Server: "localhost", + User: "root", + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + } + + outStr, errStr, isTimeout, err := ssh.Run("whoami; sleep 2", 1*time.Second) + assert.Equal(t, "root\n", outStr) + assert.Equal(t, "", errStr) + assert.False(t, isTimeout) + assert.NotNil(t, err) + assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error()) +}