-
Notifications
You must be signed in to change notification settings - Fork 0
/
smartpty.go
223 lines (192 loc) · 4.76 KB
/
smartpty.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
package smartpty
import (
"github.com/kr/pty"
"github.com/npat-efault/poller"
"golang.org/x/crypto/ssh/terminal"
"os"
"os/exec"
"os/signal"
"regexp"
"sync"
"syscall"
)
// bufferSize is the size in bytes (32 KiB) of the scratch buffers used for
// each read from the pty and from stdin.
const bufferSize = 32 << 10
// ExpressionCallback is the signature of a SmartPTY callback, used to react
// to matches in the command's stdout data.
//
// data is the current chunk of output (as possibly rewritten by earlier
// matching callbacks) that matched the registered expression. tty is the pty
// the command runs on, so the callback can write a response back to it
// (for example to answer a prompt); such writes do not interleave with
// forwarded stdin.
//
// The returned slice replaces data. It is handed to any later matching
// callbacks and is what finally gets written to os.Stdout. Return data
// unchanged to pass the output through, or an empty/nil slice to hide it.
type ExpressionCallback func(data []byte, tty *os.File) []byte
// SmartPTY runs a command inside a pseudo-terminal. It forwards the user's
// stdin to the command and copies the command's output to os.Stdout. Before
// writing, it invokes registered ExpressionCallbacks on output that matches
// their expressions.
type SmartPTY struct {
	// cmd is the command launched by Start via pty.Start.
	cmd *exec.Cmd
	// callbacks holds the registered callbacks; guarded by cbSync.
	callbacks []*cbDescriptor
	// signals receives SIGWINCH notifications; Close closes it to stop
	// processSignals.
	signals chan os.Signal
	// tty is the pty returned by pty.Start. Output is read from it and
	// stdin data (and callback responses) are written to it.
	tty *os.File
	// finished is set when any I/O loop stops and is polled by both loops.
	// NOTE(review): read and written by several goroutines without
	// synchronization.
	finished bool
	// stdinSync serializes writes to tty between stdin forwarding and
	// callbacks.
	stdinSync *sync.Mutex
	// cbSync guards callbacks.
	cbSync *sync.Mutex
	// stdinPFD is the poller wrapper around the stdin fd, created by
	// processStdin and closed by Close.
	stdinPFD *poller.FD
	// stdinState is the terminal state returned by terminal.MakeRaw,
	// restored by Close.
	stdinState *terminal.State
	// stdinBackup is a duplicate of the stdin fd taken in Start; Close
	// dup2's it back onto stdin.
	stdinBackup int
}
// cbDescriptor is one registered callback together with the expression that
// triggers it and its remaining invocation budget.
type cbDescriptor struct {
	// expr is matched against each chunk of stdout data.
	expr *regexp.Regexp
	// cb is invoked when expr matches.
	cb ExpressionCallback
	// count is the number of calls left. A negative value means unlimited
	// (it is never decremented); 0 means the callback is exhausted, is
	// skipped, and gets pruned from the callback list.
	count int
}
// Create returns a SmartPTY that will run cmd. The command is not started
// until Start is called.
func Create(cmd *exec.Cmd) *SmartPTY {
	sp := &SmartPTY{
		cmd:       cmd,
		callbacks: make([]*cbDescriptor, 0),
		signals:   make(chan os.Signal, 1),
		stdinSync: new(sync.Mutex),
		cbSync:    new(sync.Mutex),
	}
	return sp
}
// Always registers cb to run every time expression matches a chunk of the
// command's stdout, with no limit on the number of calls.
func (sp *SmartPTY) Always(expression *regexp.Regexp, cb ExpressionCallback) {
	// Negative counts are never decremented, so the callback stays active.
	const unlimited = -1
	sp.Times(expression, cb, unlimited)
}
// Once registers cb to run on the first chunk of the command's stdout that
// matches expression; after that single call the callback is disabled.
func (sp *SmartPTY) Once(expression *regexp.Regexp, cb ExpressionCallback) {
	const single = 1
	sp.Times(expression, cb, single)
}
// Times registers cb to run on chunks of the command's stdout that match
// expression, at most times times. After the budget is used up the callback
// is disabled and later pruned. A negative times means no limit (see
// Always); times == 0 registers a callback that never runs.
// The callback list is updated under cbSync.
func (sp *SmartPTY) Times(expression *regexp.Regexp, cb ExpressionCallback, times int) {
	sp.cbSync.Lock()
	defer sp.cbSync.Unlock()
	sp.callbacks = append(sp.callbacks, &cbDescriptor{expr: expression, cb: cb, count: times})
}
// Start launches the configured command on a new pseudo-terminal. It then
// starts the goroutines that:
//   - propagate window-size changes,
//   - copy the command's output to os.Stdout (running callbacks on it), and
//   - forward os.Stdin to the command.
//
// A duplicate of the stdin descriptor is taken first so Close can put stdin
// back afterwards. Doing this before pty.Start means a failure here does
// not leave a running child or an open pty behind.
func (sp *SmartPTY) Start() error {
	backup, err := syscall.Dup(int(os.Stdin.Fd()))
	if err != nil {
		return err
	}
	tty, err := pty.Start(sp.cmd)
	if err != nil {
		// Best effort: the pty.Start error is the one worth reporting.
		_ = syscall.Close(backup)
		return err
	}
	sp.stdinBackup = backup
	sp.tty = tty
	go sp.processSignals()
	go sp.processStdout()
	go sp.processStdin()
	return nil
}
// processSignals keeps the pty's window size in sync with the user's
// terminal. It copies the size once at startup and again on every
// SIGWINCH, until Close closes sp.signals.
func (sp *SmartPTY) processSignals() {
	signal.Notify(sp.signals, syscall.SIGWINCH)
	// Stop only unregisters this channel. signal.Reset would drop every
	// signal handler in the process, including ones installed by the caller.
	defer signal.Stop(sp.signals)
	// Apply the initial size directly instead of queueing a fake SIGWINCH.
	// If a real SIGWINCH had already filled the channel's buffer (Create
	// allocates a single slot), that send would block this goroutine forever.
	// It would also panic if Close had already closed the channel.
	_ = pty.InheritSize(os.Stdin, sp.tty)
	for range sp.signals {
		// Best effort: a failed resize must not stop future updates.
		_ = pty.InheritSize(os.Stdin, sp.tty)
	}
}
// processStdout copies the command's output from the pty to os.Stdout.
// Every active callback whose expression matches a chunk runs first. Each
// callback may rewrite the chunk, and the result of the last one is what
// gets written. The loop ends once a read from the pty fails, e.g. at EOF
// or after Close closes the pty.
func (sp *SmartPTY) processStdout() {
	buf := make([]byte, bufferSize)
	for !sp.finished {
		n, err := sp.tty.Read(buf)
		if err != nil {
			// EOF or closed pty: stop after handling any bytes already read.
			// NOTE(review): finished is shared with processStdin without
			// synchronization.
			sp.finished = true
		}
		if n == 0 {
			continue
		}
		// Copy the chunk: buf is reused by the next Read, while callbacks
		// receive a slice they may keep or modify.
		display := make([]byte, n)
		copy(display, buf[:n])
		display, expired := sp.applyCallbacks(display)
		os.Stdout.Write(display)
		if expired {
			sp.dropExpiredCallbacks()
		}
	}
}

// applyCallbacks runs every active callback whose expression matches data.
// Callbacks run in registration order, and each receives the (possibly
// rewritten) buffer from the one before. It returns the final buffer and
// whether any callback is now exhausted (count == 0).
func (sp *SmartPTY) applyCallbacks(data []byte) ([]byte, bool) {
	// Snapshot under the lock so Times can register callbacks concurrently;
	// callbacks added during this pass take effect from the next chunk.
	sp.cbSync.Lock()
	callbacks := make([]*cbDescriptor, len(sp.callbacks))
	copy(callbacks, sp.callbacks)
	sp.cbSync.Unlock()

	expired := false
	for _, cbd := range callbacks {
		if cbd.count == 0 {
			// Exhausted (or registered with a zero budget): never call again.
			expired = true
			continue
		}
		if !cbd.expr.Match(data) {
			continue
		}
		// Hold stdinSync so a callback writing to the tty does not interleave
		// with bytes being forwarded from stdin.
		sp.stdinSync.Lock()
		data = cbd.cb(data, sp.tty)
		sp.stdinSync.Unlock()
		// Negative counts mean "unlimited" and are never decremented.
		if cbd.count > 0 {
			cbd.count--
			if cbd.count == 0 {
				expired = true
			}
		}
	}
	return data, expired
}

// dropExpiredCallbacks removes callbacks whose remaining count reached zero.
func (sp *SmartPTY) dropExpiredCallbacks() {
	sp.cbSync.Lock()
	defer sp.cbSync.Unlock()
	live := make([]*cbDescriptor, 0, len(sp.callbacks))
	for _, cbd := range sp.callbacks {
		if cbd.count != 0 {
			live = append(live, cbd)
		}
	}
	sp.callbacks = live
}
// processStdin switches the user's terminal to raw mode and forwards every
// byte read from stdin to the pty. It holds stdinSync for each write so the
// bytes do not interleave with callback writes. It marks the session
// finished and stops when any of these happens:
//   - raw mode or the stdin poller cannot be set up,
//   - a read from stdin fails, or
//   - a write to the pty fails or is short.
func (sp *SmartPTY) processStdin() {
	state, err := terminal.MakeRaw(int(os.Stdin.Fd()))
	if err != nil {
		sp.finished = true
		return
	}
	sp.stdinState = state
	if sp.stdinPFD, err = poller.NewFD(int(os.Stdin.Fd())); err != nil {
		sp.finished = true
		return
	}
	chunk := make([]byte, bufferSize)
	for !sp.finished {
		got, readErr := sp.stdinPFD.Read(chunk)
		if got > 0 {
			sp.stdinSync.Lock()
			put, writeErr := sp.tty.Write(chunk[:got])
			sp.stdinSync.Unlock()
			// Either a write error or a short write ends forwarding.
			if writeErr != nil || put != got {
				sp.finished = true
			}
		}
		if readErr != nil {
			sp.finished = true
		}
	}
}
// Close tears down the session and restores the user's terminal:
//   - closes the pty, so processStdout's next read fails;
//   - stops SIGWINCH delivery and closes the signal channel, ending
//     processSignals;
//   - closes the stdin poller;
//   - puts the saved stdin descriptor back in place and releases it;
//   - restores the terminal mode saved before stdin was made raw.
//
// Each step is skipped when the resource it undoes was never set up, so a
// failed Start or stdin setup does not cause a nil-pointer panic here.
// Errors are ignored because this is best-effort cleanup with no error to
// return. Close must be called at most once: the signal channel cannot be
// closed twice.
func (sp *SmartPTY) Close() {
	if sp.tty != nil {
		_ = sp.tty.Close()
	}
	// Unregister before closing so package signal never sends on a closed
	// channel (that would panic).
	signal.Stop(sp.signals)
	close(sp.signals)
	if sp.stdinPFD != nil {
		_ = sp.stdinPFD.Close()
	}
	// A valid dup of fd 0 is always > 0; 0 means Start never took one.
	// NOTE(review): presumably the poller close also closes the stdin fd,
	// which is why the backup is dup2'd back — confirm.
	if sp.stdinBackup > 0 {
		_ = syscall.Dup2(sp.stdinBackup, int(os.Stdin.Fd()))
		_ = syscall.Close(sp.stdinBackup)
	}
	if sp.stdinState != nil {
		_ = terminal.Restore(int(os.Stdin.Fd()), sp.stdinState)
	}
}