Skip to content

Commit

Permalink
zsh reinit fixes (#477)
Browse files Browse the repository at this point in the history
* reset command now initiates and completes async so there is feedback that something is happening when it takes a long time

* switch from standard rpc to rpciter

* checkpoint on reinit -- stream output, stats packet, logging to cmd pty, new endBytes for EOF

* make generic versions of endbytes scanner and channel output funcs

* update bash to use more modern state parsing (tricks learned from zsh)

* verbose mode, fix stats output message

* add a diff when verbose mode is on
  • Loading branch information
sawka authored Mar 19, 2024
1 parent accb74a commit 5616c9a
Show file tree
Hide file tree
Showing 11 changed files with 427 additions and 170 deletions.
11 changes: 6 additions & 5 deletions waveshell/pkg/packet/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,12 @@ func MakeLogPacket(entry wlog.LogEntry) *LogPacketType {
}

type ShellStatePacketType struct {
Type string `json:"type"`
ShellType string `json:"shelltype"`
RespId string `json:"respid,omitempty"`
State *ShellState `json:"state"`
Error string `json:"error,omitempty"`
Type string `json:"type"`
ShellType string `json:"shelltype"`
RespId string `json:"respid,omitempty"`
State *ShellState `json:"state"`
Stats *ShellStateStats `json:"stats"`
Error string `json:"error,omitempty"`
}

func (*ShellStatePacketType) GetType() string {
Expand Down
15 changes: 15 additions & 0 deletions waveshell/pkg/packet/shellstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ import (
const ShellStatePackVersion = 0
const ShellStateDiffPackVersion = 0

type ShellStateStats struct {
Version string `json:"version"`
AliasCount int `json:"aliascount"`
EnvCount int `json:"envcount"`
VarCount int `json:"varcount"`
FuncCount int `json:"funccount"`
HashVal string `json:"hashval"`
OutputSize int64 `json:"outputsize"`
StateSize int64 `json:"statesize"`
}

type ShellState struct {
Version string `json:"version"` // [type] [semver]
Cwd string `json:"cwd,omitempty"`
Expand All @@ -29,6 +40,10 @@ type ShellState struct {
HashVal string `json:"-"`
}

func (state ShellState) ApproximateSize() int64 {
return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error))
}

type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
BaseHash string `json:"basehash"`
Expand Down
29 changes: 27 additions & 2 deletions waveshell/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
appendSlashes(comps)
}
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
return
}

func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := shexec.MakeShellStatePacket(shellType)
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return
Expand All @@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) {
m.Sender.SendPacket(ssPk)
}

func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := make(chan shellapi.ShellStateOutput, 1)
go sapi.GetShellState(rtnCh)
for ssOutput := range rtnCh {
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
}
if ssOutput.ShellState != nil {
rtn := packet.MakeShellStatePacket()
rtn.State = ssOutput.ShellState
rtn.Stats = ssOutput.Stats
return rtn, nil
}
if ssOutput.Output != nil {
dataPk := packet.MakeFileDataPacket(reqId)
dataPk.Data = ssOutput.Output
m.Sender.SendPacket(dataPk)
}
}
return nil, nil
}

func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
dirName := filepath.Dir(path)
baseName := filepath.Base(path)
Expand Down
90 changes: 60 additions & 30 deletions waveshell/pkg/shellapi/bashapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)

const BaseBashOpts = `set +m; set +H; shopt -s extglob`
Expand Down Expand Up @@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string {
return packet.ShellType_bash
}

func (b bashShellApi) MakeExitTrap(fdNum int) string {
func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeBashExitTrap(fdNum)
}

Expand Down Expand Up @@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
}

func (b bashShellApi) GetShellState() chan ShellStateOutput {
ch := make(chan ShellStateOutput, 1)
defer close(ch)
ssPk, err := GetBashShellState()
if err != nil {
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
Error: err.Error(),
}
return ch
}
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
ShellState: ssPk,
}
return ch
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
GetBashShellState(outCh)
}

func (b bashShellApi) GetBaseShellOpts() string {
return BaseBashOpts
}

func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
return parseBashShellStateOutput(output)
}

Expand Down Expand Up @@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
return rcBuf.String()
}

func GetBashShellStateCmd() string {
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
func GetBashShellStateCmd(fdNum int) (string, []byte) {
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
endBytes = append(endBytes, '\n')
cmdStr := strings.TrimSpace(`
exec 2> /dev/null;
exec > [%OUTPUTFD%];
printf "\x00\x00";
[%BASHVERSIONCMD%];
printf "\x00\x00";
pwd;
printf "\x00\x00";
declare -p $(compgen -A variable);
printf "\x00\x00";
alias -p;
printf "\x00\x00";
declare -f;
printf "\x00\x00";
[%GITBRANCHCMD%];
printf "\x00\x00";
printf "[%ENDBYTES%]";
`)
cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
return cmdStr, endBytes
}

func execGetLocalBashShellVersion() string {
Expand All @@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string {
return localBashMajorVersion
}

func GetBashShellState() (*packet.ShellState, error) {
func GetBashShellState(outCh chan ShellStateOutput) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
defer close(outCh)
stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum)
cmdStr := BaseBashOpts + "; " + stateCmd
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
outputBytes, err := RunSimpleCmdInPty(ecmd)
outputCh := make(chan []byte, 10)
var outputWg sync.WaitGroup
outputWg.Add(1)
go func() {
defer outputWg.Done()
for outputBytes := range outputCh {
outCh <- ShellStateOutput{Output: outputBytes}
}
}()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputWg.Wait()
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
rtn, stats, err := parseBashShellStateOutput(outputBytes)
if err != nil {
return nil, err
outCh <- ShellStateOutput{Error: err.Error()}
return
}
return parseBashShellStateOutput(outputBytes)
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
}

func GetLocalBashPath() string {
Expand All @@ -190,19 +219,20 @@ func GetLocalZshPath() string {
return "zsh"
}

func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) {
cmdStr, endBytes := GetBashShellStateCmd(outputFdNum)
return cmdStr, endBytes
}

func MakeBashExitTrap(fdNum int) string {
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
func MakeBashExitTrap(fdNum int) (string, []byte) {
stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum)
fmtStr := `
_waveshell_exittrap () {
%s
}
trap _waveshell_exittrap EXIT
`
return fmt.Sprintf(fmtStr, stateCmd)
return fmt.Sprintf(fmtStr, stateCmd), endBytes
}

func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
Expand Down
44 changes: 28 additions & 16 deletions waveshell/pkg/shellapi/bashparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ import (
"mvdan.cc/sh/v3/syntax"
)

const (
BashSection_Ignored = iota
BashSection_Version
BashSection_Cwd
BashSection_Vars
BashSection_Aliases
BashSection_Funcs
BashSection_PVars
BashSection_EndBytes

BashSection_Count // must be last
)

type DeclareDeclType = shellenv.DeclareDeclType

func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
Expand Down Expand Up @@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
return nil
}

func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_bash, outputBytes)
}
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 7 {
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
sections := bytes.Split(outputBytes, []byte{0, 0})
if len(sections) != BashSection_Count {
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[1]))
rtn.Version = strings.TrimSpace(string(sections[BashSection_Version]))
if rtn.GetShellType() != packet.ShellType_bash {
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
}
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
}
cwdStr := string(fields[2])
cwdStr := string(sections[BashSection_Cwd])
if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") {
cwdStr = cwdStr[0 : len(cwdStr)-1]
} else {
cwdStr = strings.TrimSuffix(cwdStr, "\n")
}
rtn.Cwd = string(cwdStr)
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars])
if err != nil {
return nil, err
return nil, nil, err
}
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n")
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
return rtn, nil
return rtn, nil, nil
}

func bashNormalize(d *DeclareDeclType) error {
Expand Down
Loading

0 comments on commit 5616c9a

Please sign in to comment.