Skip to content

Commit

Permalink
chore: Fix flaky PTY tests (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
irvinlim authored Feb 10, 2025
1 parent e432f51 commit 55a524e
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 39 deletions.
17 changes: 13 additions & 4 deletions pkg/cli/cmd/cmd_list_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,31 +516,40 @@ func TestListJobCommand(t *testing.T) {
},
{
Name: "watch jobs belonging to jobconfig with --for and --finished",
Args: []string{"list", "job", "--for", adhocJobConfig.Name},
Args: []string{"list", "job", "-w", "--for", adhocJobConfig.Name},
Fixtures: []runtime.Object{
jobRunning,
jobFinished,
jobQueued,
adhocJobConfig,
},
Procedure: func(t *testing.T, rc runtimetesting.RunContext) {
// Expect that headers and rows are printed.
rc.Console.ExpectString("NAME")
rc.Console.ExpectString(jobRunning.Name)
rc.Console.ExpectString(string(jobRunning.Status.Phase))

// Update the Job's state.
newJob := jobRunning.DeepCopy()
newJob.Status.Phase = v1alpha1.JobSucceeded
newJob.Status.State = v1alpha1.JobStateFinished
newJob.Status.Condition = v1alpha1.JobCondition{
Finished: &v1alpha1.JobConditionFinished{
FinishTimestamp: testutils.Mkmtime(taskFinishTime),
},
}
_, err := rc.CtrlContext.Clientsets().Furiko().ExecutionV1alpha1().Jobs(newJob.Namespace).Update(rc.Context, newJob, metav1.UpdateOptions{})
assert.NoError(t, err)

// Expect that headers and rows are printed.
rc.Console.ExpectString("NAME")
// Expect that we got the update.
rc.Console.ExpectString(newJob.Name)
rc.Console.ExpectString(string(newJob.Status.Phase))

// Cancel watch.
rc.Cancel()
},
Stdout: runtimetesting.Output{
NumLines: pointer.Int(2),
NumLines: pointer.Int(3),
},
WantActions: runtimetesting.CombinedActions{
Ignore: true,
Expand Down
35 changes: 30 additions & 5 deletions pkg/cli/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,28 @@
package console

import (
"bytes"
"io"
"testing"

"github.com/Netflix/go-expect"
"github.com/creack/pty"
"github.com/hinshun/vt10x"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

type Console struct {
*expect.Console
T *testing.T

// The bytes that were read during Expect.
read bytes.Buffer
}

// NewConsole returns a Console for testing PTY handling.
// All PTY output will be piped to w.
func NewConsole(w io.Writer) (*Console, error) {
func NewConsole(t *testing.T, w io.Writer) (*Console, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, errors.Wrapf(err, "failed to open pty")
Expand All @@ -48,6 +55,7 @@ func NewConsole(w io.Writer) (*Console, error) {
}

c := &Console{
T: t,
Console: console,
}

Expand All @@ -68,10 +76,27 @@ func (c *Console) Run(f func(c *Console)) <-chan struct{} {
return done
}

func (c *Console) ExpectString(s string) {
_, _ = c.Console.ExpectString(s)
// ExpectString advances the PTY buffer until the given string is found.
// If not, a test error will be thrown, and returns false.
func (c *Console) ExpectString(s string) bool {
got, err := c.Console.Expect(expect.String(s))
c.read.WriteString(got)
return assert.NoError(c.T, err, `did not find expected string: "%v", got "%v"`, s, got)
}

// SendLine writes to the PTY buffer.
// Blocks until the line is written.
func (c *Console) SendLine(s string) bool {
_, err := c.Console.SendLine(s)
return assert.NoError(c.T, err, `cannot send line: "%v"`, s)
}

func (c *Console) SendLine(s string) {
_, _ = c.Console.SendLine(s)
// Close the TTY, unblocking all Expect and ExpectEOF calls.
func (c *Console) Close() error {
err := c.Console.Close()
output := c.read.String()
if len(output) > 0 {
c.T.Logf("Console output:\n%s\n\n", output)
}
return err
}
18 changes: 9 additions & 9 deletions pkg/cli/prompt/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestNewBoolPrompt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pt, ps, err := NewPromptTest()
pt, ps, err := NewPromptTest(t)
if err != nil {
t.Fatalf("cannot initialize test: %v", err)
}
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestNewStringPrompt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pt, ps, err := NewPromptTest()
pt, ps, err := NewPromptTest(t)
if err != nil {
t.Fatalf("cannot initialize test: %v", err)
}
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestNewSelectPrompt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pt, ps, err := NewPromptTest()
pt, ps, err := NewPromptTest(t)
if err != nil {
t.Fatalf("cannot initialize test: %v", err)
}
Expand Down Expand Up @@ -418,7 +418,7 @@ func TestNewMultiPrompt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pt, ps, err := NewPromptTest()
pt, ps, err := NewPromptTest(t)
if err != nil {
t.Fatalf("cannot initialize test: %v", err)
}
Expand Down Expand Up @@ -472,7 +472,7 @@ func TestNewDatePrompt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pt, ps, err := NewPromptTest()
pt, ps, err := NewPromptTest(t)
if err != nil {
t.Fatalf("cannot initialize test: %v", err)
}
Expand All @@ -492,9 +492,9 @@ type PromptTest struct {
console *console.Console
}

func NewPromptTest() (*PromptTest, *streams.Streams, error) {
func NewPromptTest(t *testing.T) (*PromptTest, *streams.Streams, error) {
iostreams, _, _, _ := genericclioptions.NewTestIOStreams()
c, err := console.NewConsole(iostreams.Out)
c, err := console.NewConsole(t, iostreams.Out)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to init pseudo TTY")
}
Expand Down Expand Up @@ -546,6 +546,6 @@ func (pt *PromptTest) Run(t *testing.T, p prompt.Prompt, procedure func(c *conso
return resp.retval, resp.err
}

func (pt *PromptTest) Close() {
_ = pt.console.Close()
func (pt *PromptTest) Close() error {
return pt.console.Close()
}
85 changes: 64 additions & 21 deletions pkg/runtime/testing/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
package testing

import (
"bytes"
"context"
"fmt"
"regexp"
"strings"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/cli-runtime/pkg/genericclioptions"
Expand All @@ -39,6 +42,10 @@ import (
"github.com/furiko-io/furiko/pkg/utils/testutils"
)

const (
testTimeout = time.Second * 5
)

// RunCommandTests executes all CommandTest cases.
func RunCommandTests(t *testing.T, cases []CommandTest) {
for _, tt := range cases {
Expand Down Expand Up @@ -101,7 +108,7 @@ type Output struct {
// If specified, expects the output to match the specified regexp.
Matches *regexp.Regexp

// If specified, expects the output to match all of the given regular expressions.
// If specified, expects the output to match all the given regular expressions.
MatchesAll []*regexp.Regexp

// If specified, expects that the output to NOT contain the given string.
Expand Down Expand Up @@ -129,27 +136,32 @@ type RunContext struct {
}

func (c *CommandTest) Run(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

ctrlContext := mock.NewContext()
common.SetCtrlContext(ctrlContext)
iostreams, _, stdout, stderr := genericclioptions.NewTestIOStreams()

done := make(chan struct{})
go func() {
c.run(ctx, t)
c.run(ctx, ctrlContext, t, iostreams)
close(done)
}()

select {
case <-done:
c.verify(t, ctrlContext, stdout, stderr)
return
case <-ctx.Done():
t.Errorf("test did not complete: %v", ctx.Err())
t.Errorf("test did not complete after %v: %v", testTimeout, ctx.Err())
t.Logf("command stdout =\n%s", stdout.String())
t.Logf("command stderr =\n%s", stderr.String())
}
}

func (c *CommandTest) run(ctx context.Context, t *testing.T) {
func (c *CommandTest) run(ctx context.Context, ctrlContext *mock.Context, t *testing.T, iostreams genericclioptions.IOStreams) bool {
// Override the shared context.
ctrlContext := mock.NewContext()
common.SetCtrlContext(ctrlContext)
client := ctrlContext.MockClientsets()
assert.NoError(t, InitFixtures(ctx, client, c.Fixtures))
client.ClearActions()
Expand All @@ -162,10 +174,11 @@ func (c *CommandTest) run(ctx context.Context, t *testing.T) {
ktime.Clock = fakeclock.NewFakeClock(now)

// Run command with I/O.
iostreams, _, stdout, stderr := genericclioptions.NewTestIOStreams()
if c.runCommand(ctx, t, ctrlContext, iostreams) {
return
}
return c.runCommand(ctx, t, ctrlContext, iostreams)
}

func (c *CommandTest) verify(t *testing.T, ctrlContext *mock.Context, stdout, stderr *bytes.Buffer) {
client := ctrlContext.MockClientsets()

// Ensure that output matches.
c.checkOutput(t, "stdout", stdout.String(), c.Stdout)
Expand All @@ -176,27 +189,26 @@ func (c *CommandTest) run(ctx context.Context, t *testing.T) {
}

// runCommand will execute the command, setting up all I/O streams and blocking
// until the streams are done.
// until the streams are done. Returns true if an error was encountered.
//
// Reference:
// https://github.com/AlecAivazis/survey/blob/93657ef69381dd1ffc7a4a9cfe5a2aefff4ca4ad/survey_posix_test.go#L15
func (c *CommandTest) runCommand(ctx context.Context, t *testing.T, ctrlContext *mock.Context, iostreams genericclioptions.IOStreams) bool {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

console, err := console.NewConsole(iostreams.Out)
console, err := console.NewConsole(t, iostreams.Out)
if err != nil {
t.Fatalf("failed to create console: %v", err)
}
defer console.Close()

// Prepare root command.
command := cmd.NewRootCommand(streams.NewTTYStreams(console.Tty()))
command.SetArgs(c.Args)

var done <-chan struct{}

// Run procedure if specified.
// Run procedure in the background.
if c.Procedure != nil {
done = c.runProcedure(t, c.Procedure, RunContext{
Console: console,
Expand All @@ -208,6 +220,41 @@ func (c *CommandTest) runCommand(ctx context.Context, t *testing.T, ctrlContext
done = console.Run(c.Stdin.Procedure)
}

// Only close the PTY once.
closeConsole := sync.OnceFunc(func() {
if err := console.Close(); err != nil {
t.Errorf("cannot close PTY: %v", err)
}
})

// Close the console if context had deadline exceeded.
// Note that the context could possibly be canceled by tests in order to terminate the command execution.
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("Context was canceled: %v", ctx.Err().Error())
closeConsole()
}
case <-done:
// Already closed.
}
}()

defer func() {
// Always make sure to explicitly close the PTY.
closeConsole()

// Wait for the procedure to complete and EOF to be read.
<-done

// Wait for .
wg.Wait()
}()

// Execute command.
if testutils.WantError(t, c.WantError, command.ExecuteContext(ctx), fmt.Sprintf("Run error with args: %v", c.Args)) {
return true
Expand All @@ -216,15 +263,11 @@ func (c *CommandTest) runCommand(ctx context.Context, t *testing.T, ctrlContext
// TODO(irvinlim): We need a sleep here otherwise tests will be flaky
time.Sleep(time.Millisecond * 5)

// Wait for tty to be closed.
if err := console.Tty().Close(); err != nil {
t.Errorf("error closing Tty: %v", err)
}
<-done

return false
}

// runProcedure starts the procedure in the background, and returns a channel
// that will be closed once the PTY output is closed.
func (c *CommandTest) runProcedure(t *testing.T, procedure func(t *testing.T, rc RunContext), rc RunContext) <-chan struct{} {
done := make(chan struct{})
go func() {
Expand Down

0 comments on commit 55a524e

Please sign in to comment.