Skip to content

Commit

Permalink
Support keyboard interrupt (#81)
Browse files Browse the repository at this point in the history
Co-authored-by: hiroebe <hiroebe41@gmail.com>
  • Loading branch information
hiroebe and hiroebe authored Oct 28, 2022
1 parent 211e8c9 commit f82212c
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 197 deletions.
60 changes: 31 additions & 29 deletions conn/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package conn

import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
Expand All @@ -11,77 +12,78 @@ import (

type Connection struct{}

func (conn *Connection) CheckRepos(hostname string, repoNames []string) error {
func (conn *Connection) CheckRepos(ctx context.Context, hostname string, repoNames []string) error {
for _, name := range repoNames {
args := []string{
"api",
"--hostname", hostname,
"repos/" + name,
"--silent",
}
if _, err := run("gh", args); err != nil {
if _, err := run(ctx, "gh", args); err != nil {
return err
}
}
return nil
}

func (conn *Connection) GetRemoteNames() (string, error) {
func (conn *Connection) GetRemoteNames(ctx context.Context) (string, error) {
args := []string{
"remote", "-v",
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetSshConfig(name string) (string, error) {
func (conn *Connection) GetSshConfig(ctx context.Context, name string) (string, error) {
args := []string{
"-T", "-G", name,
}
return run("ssh", args)
return run(ctx, "ssh", args)
}

func (conn *Connection) GetRepoNames(hostname string, repoName string) (string, error) {
func (conn *Connection) GetRepoNames(ctx context.Context, hostname string, repoName string) (string, error) {
args := []string{
"repo", "view", hostname + "/" + repoName,
"--json", "owner",
"--json", "name",
"--json", "parent",
"--json", "defaultBranchRef",
}
return run("gh", args)
return run(ctx, "gh", args)
}

func (conn *Connection) GetBranchNames() (string, error) {
func (conn *Connection) GetBranchNames(ctx context.Context) (string, error) {
args := []string{
"branch", "-v", "--no-abbrev",
"--format=%(HEAD):%(refname:lstrip=2):%(objectname)",
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetMergedBranchNames(remoteName string, branchName string) (string, error) {
func (conn *Connection) GetMergedBranchNames(ctx context.Context, remoteName string, branchName string) (string, error) {
args := []string{
"branch", "--merged", fmt.Sprintf("%s/%s", remoteName, branchName),
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetLog(branchName string) (string, error) {
func (conn *Connection) GetLog(ctx context.Context, branchName string) (string, error) {
args := []string{
"log", "--first-parent", "--max-count=30", "--format=%H", branchName, "--",
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetAssociatedRefNames(oid string) (string, error) {
func (conn *Connection) GetAssociatedRefNames(ctx context.Context, oid string) (string, error) {
args := []string{
"branch", "--all", "--format=%(refname)",
"--contains", oid,
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetPullRequests(
ctx context.Context,
hostname string, repoNames []string, queryHashes string) (string, error) {
args := []string{
"api", "graphql",
Expand Down Expand Up @@ -114,43 +116,43 @@ func (conn *Connection) GetPullRequests(
queryHashes,
),
}
return run("gh", args)
return run(ctx, "gh", args)
}

func (conn *Connection) GetUncommittedChanges() (string, error) {
func (conn *Connection) GetUncommittedChanges(ctx context.Context) (string, error) {
args := []string{
"status", "--short",
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) GetConfig(key string) (string, error) {
func (conn *Connection) GetConfig(ctx context.Context, key string) (string, error) {
args := []string{
"config", "--get", key,
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) CheckoutBranch(branchName string) (string, error) {
func (conn *Connection) CheckoutBranch(ctx context.Context, branchName string) (string, error) {
args := []string{
"checkout", "--quiet", branchName,
}
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) DeleteBranches(branchNames []string) (string, error) {
func (conn *Connection) DeleteBranches(ctx context.Context, branchNames []string) (string, error) {
args := append([]string{
"branch", "-D"},
branchNames...,
)
return run("git", args)
return run(ctx, "git", args)
}

func (conn *Connection) PruneRemoteBranches(remoteName string) (string, error) {
func (conn *Connection) PruneRemoteBranches(ctx context.Context, remoteName string) (string, error) {
args := []string{
"remote", "prune", remoteName,
}
return run("git", args)
return run(ctx, "git", args)
}

func getQueryRepos(repoNames []string) string {
Expand All @@ -161,14 +163,14 @@ func getQueryRepos(repoNames []string) string {
return repos.String()
}

func run(name string, args []string) (string, error) {
func run(ctx context.Context, name string, args []string) (string, error) {
cmdPath, err := safeexec.LookPath(name)
if err != nil {
return "", err
}

var stdout bytes.Buffer
cmd := exec.Command(cmdPath, args...)
cmd := exec.CommandContext(ctx, cmdPath, args...)
cmd.Stdout = &stdout

err = cmd.Run()
Expand Down
17 changes: 9 additions & 8 deletions conn/command_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package conn

import (
"context"
"os"
"path/filepath"
"testing"
Expand All @@ -17,15 +18,15 @@ func Test_RepoBasic(t *testing.T) {
stub := &Stub{nil, t}

t.Run("GetRemoteNames", func(t *testing.T) {
actual, _ := conn.GetRemoteNames()
actual, _ := conn.GetRemoteNames(context.Background())
assert.Equal(t,
stub.readFile("git", "remote", "origin"),
actual,
)
})

t.Run("GetBranchNames", func(t *testing.T) {
actual, _ := conn.GetBranchNames()
actual, _ := conn.GetBranchNames(context.Background())
assert.Equal(t,
stub.readFile("git", "branch", "@main_issue1"),
actual,
Expand All @@ -35,15 +36,15 @@ func Test_RepoBasic(t *testing.T) {
t.Run("GetLog", func(t *testing.T) {

t.Run("main", func(t *testing.T) {
actual, _ := conn.GetLog("main")
actual, _ := conn.GetLog(context.Background(), "main")
assert.Equal(t,
stub.readFile("git", "log", "main"),
actual,
)
})

t.Run("issue1", func(t *testing.T) {
actual, _ := conn.GetLog("issue1")
actual, _ := conn.GetLog(context.Background(), "issue1")
assert.Equal(t,
stub.readFile("git", "log", "issue1"),
actual,
Expand All @@ -54,15 +55,15 @@ func Test_RepoBasic(t *testing.T) {
t.Run("GetAssociatedRefNames", func(t *testing.T) {

t.Run("issue1", func(t *testing.T) {
actual, _ := conn.GetAssociatedRefNames("a97e9630426df5d34ca9ee77ae1159bdfd5ff8f0")
actual, _ := conn.GetAssociatedRefNames(context.Background(), "a97e9630426df5d34ca9ee77ae1159bdfd5ff8f0")
assert.Equal(t,
stub.readFile("git", "abranch", "issue1"),
actual,
)
})

t.Run("main_issue1", func(t *testing.T) {
actual, _ := conn.GetAssociatedRefNames("6ebe3d30d23531af56bd23b5a098d3ccae2a534a")
actual, _ := conn.GetAssociatedRefNames(context.Background(), "6ebe3d30d23531af56bd23b5a098d3ccae2a534a")
assert.Equal(t,
stub.readFile("git", "abranch", "main_issue1"),
actual,
Expand All @@ -71,12 +72,12 @@ func Test_RepoBasic(t *testing.T) {
})

t.Run("GetUncommittedChanges", func(t *testing.T) {
actual, _ := conn.GetUncommittedChanges()
actual, _ := conn.GetUncommittedChanges(context.Background())
assert.Equal(t, "A README.md\n", actual)
})

t.Run("GetConfig", func(t *testing.T) {
actual, _ := conn.GetConfig("branch.main.merge")
actual, _ := conn.GetConfig(context.Background(), "branch.main.merge")
assert.Equal(t,
stub.readFile("git", "configMerge", "main"),
actual,
Expand Down
26 changes: 13 additions & 13 deletions conn/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Stub) CheckRepos(err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
CheckRepos(gomock.Any(), gomock.Any()).
CheckRepos(gomock.Any(), gomock.Any(), gomock.Any()).
Return(err),
conf,
)
Expand All @@ -71,7 +71,7 @@ func (s *Stub) GetRemoteNames(filename string, err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
GetRemoteNames().
GetRemoteNames(gomock.Any()).
Return(s.readFile("git", "remote", filename), err),
conf,
)
Expand All @@ -83,7 +83,7 @@ func (s *Stub) GetSshConfig(filename string, err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
GetSshConfig(gomock.Any()).
GetSshConfig(gomock.Any(), gomock.Any()).
Return(s.readFile("ssh", "config", filename), err),
conf,
)
Expand All @@ -95,7 +95,7 @@ func (s *Stub) GetRepoNames(filename string, err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
GetRepoNames(gomock.Any(), gomock.Any()).
GetRepoNames(gomock.Any(), gomock.Any(), gomock.Any()).
Return(s.readFile("gh", "repo", filename), err),
conf,
)
Expand All @@ -106,7 +106,7 @@ func (s *Stub) GetBranchNames(filename string, err error, conf *Conf) *Stub {
s.t.Helper()
configure(
s.Conn.EXPECT().
GetBranchNames().
GetBranchNames(gomock.Any()).
Return(s.readFile("git", "branch", filename), err),
conf,
)
Expand All @@ -117,7 +117,7 @@ func (s *Stub) GetMergedBranchNames(filename string, err error, conf *Conf) *Stu
s.t.Helper()
configure(
s.Conn.EXPECT().
GetMergedBranchNames("origin", "main").
GetMergedBranchNames(gomock.Any(), "origin", "main").
Return(s.readFile("git", "branchMerged", filename), err),
conf,
)
Expand All @@ -129,7 +129,7 @@ func (s *Stub) GetAssociatedRefNames(stubs []AssociatedBranchNamesStub, err erro
for _, stub := range stubs {
configure(
s.Conn.EXPECT().
GetAssociatedRefNames(stub.Oid).
GetAssociatedRefNames(gomock.Any(), stub.Oid).
Return(s.readFile("git", "abranch", stub.Filename), err),
conf,
)
Expand All @@ -142,7 +142,7 @@ func (s *Stub) GetLog(stubs []LogStub, err error, conf *Conf) *Stub {
for _, stub := range stubs {
configure(
s.Conn.EXPECT().
GetLog(stub.BranchName).
GetLog(gomock.Any(), stub.BranchName).
Return(s.readFile("git", "log", stub.Filename), err),
conf,
)
Expand All @@ -155,7 +155,7 @@ func (s *Stub) GetPullRequests(filename string, err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
GetPullRequests(gomock.Any(), gomock.Any(), gomock.Any()).
GetPullRequests(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(s.readFile("gh", "pr", filename), err),
conf,
)
Expand All @@ -167,7 +167,7 @@ func (s *Stub) GetUncommittedChanges(uncommittedChanges string, err error, conf
configure(
s.Conn.
EXPECT().
GetUncommittedChanges().
GetUncommittedChanges(gomock.Any()).
Return(uncommittedChanges, err),
conf,
)
Expand All @@ -180,7 +180,7 @@ func (s *Stub) GetConfig(stubs []ConfigStub, err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
GetConfig(stub.BranchName).
GetConfig(gomock.Any(), stub.BranchName).
Return(s.readFile("git", "configMerge", stub.Filename), err),
conf,
)
Expand All @@ -193,7 +193,7 @@ func (s *Stub) CheckoutBranch(err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
CheckoutBranch(gomock.Any()).
CheckoutBranch(gomock.Any(), gomock.Any()).
Return("", err),
conf,
)
Expand All @@ -205,7 +205,7 @@ func (s *Stub) DeleteBranches(err error, conf *Conf) *Stub {
configure(
s.Conn.
EXPECT().
DeleteBranches(gomock.Any()).
DeleteBranches(gomock.Any(), gomock.Any()).
Return("", err),
conf,
)
Expand Down
Loading

0 comments on commit f82212c

Please sign in to comment.