diff --git a/Makefile b/Makefile index 32c9999..21330fd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,11 @@ .PHONY: run run: - go run cmd/rterm/main.go + @go run cmd/rterm/main.go + +.PHONY: totp +totp: + @go run cmd/totp/main.go .PHONY: build build: - CGO_ENABLED=0 go build -o rterm cmd/rterm/main.go \ No newline at end of file + @CGO_ENABLED=0 go build -o rterm cmd/rterm/main.go \ No newline at end of file diff --git a/README.md b/README.md index c1f6faa..2b3e425 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ RTERM is a web-based remote control application that allows you to control your Inspired by [GoTTY](https://github.com/yudai/gotty)

- +

@@ -33,12 +33,10 @@ Inspired by [GoTTY](https://github.com/yudai/gotty) rterm.Register( mux, rterm.Command{ - Factory: func() (*command.Command, error) { - return command.New("bash", nil) - }, Name: "bash", Description: "Bash (Unix shell)", Writable: true, + AuthCheck: auth.NewBasic("123456"), }, ) @@ -51,7 +49,9 @@ Inspired by [GoTTY](https://github.com/yudai/gotty) } ``` Please check [example](cmd/rterm/main.go) for more information. - + + + 2. Prebuilt binary. diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..7721daf --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,35 @@ +package auth + +import "github.com/pquerna/otp/totp" + +type AuthCheck interface { + Verify(code string) (bool, error) +} + +func NewBasic(code string) AuthCheck { + return authBasic{ + code: code, + } +} + +type authBasic struct { + code string +} + +func (a authBasic) Verify(code string) (bool, error) { + return code == a.code, nil +} + +func NewTOTP(secret string) AuthCheck { + return authTOTP{ + secret: secret, + } +} + +type authTOTP struct { + secret string +} + +func (a authTOTP) Verify(code string) (bool, error) { + return totp.Validate(code, a.secret), nil +} diff --git a/cmd/rterm/main.go b/cmd/rterm/main.go index 829a18c..d9cee0d 100644 --- a/cmd/rterm/main.go +++ b/cmd/rterm/main.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/dev6699/rterm" - "github.com/dev6699/rterm/command" + "github.com/dev6699/rterm/auth" ) func main() { @@ -24,28 +24,25 @@ func run() error { rterm.Register( mux, rterm.Command{ - Factory: func() (*command.Command, error) { - return command.New("bash", nil) - }, Name: "bash", Description: "Bash (Unix shell)", Writable: true, + AuthCheck: auth.NewTOTP("F4ECH5IH72ECOFFN4INKHXA5AVKTS256"), + }, + rterm.Command{ + Name: "sh", + Description: "Shell", + Writable: true, + AuthCheck: auth.NewBasic("123456"), }, rterm.Command{ - Factory: func() (*command.Command, error) { - return command.New("htop", nil) - }, Name: "htop", Description: "Interactive system monitor process viewer and process manager", - Writable: false, }, rterm.Command{ - Factory: func() (*command.Command, error) { - return command.New("nvidia-smi", strings.Split("--query-gpu=utilization.gpu --format=csv -l 1", " ")) - }, Name: "nvidia-smi", + Args: strings.Split("--query-gpu=utilization.gpu --format=csv -l 1", " "), Description: "Monitors and outputs the GPU utilization percentage every second", - Writable: false, }, ) diff --git a/cmd/totp/main.go b/cmd/totp/main.go new file mode 100644 index 0000000..e900954 --- /dev/null +++ b/cmd/totp/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "image" + "log" + + "github.com/pquerna/otp/totp" +) + +func main() { + key, err := totp.Generate(totp.GenerateOpts{ + AccountName: "dev6699", + Issuer: "rterm", + }) + if err != nil { + log.Fatal(err) + } + + fmt.Println("Secret: ", key.Secret()) + + img, err := key.Image(45, 45) + if err != nil { + log.Fatal(err) + } + fmt.Println("QR Code:") + printQR(img) +} + +func printQR(img image.Image) { + bounds := img.Bounds() + width := bounds.Max.X + height := bounds.Max.Y + + for y := 0; y < height; y++ { + for x := 0; x < width; x++ { + color := img.At(x, y) + r, g, b, _ := color.RGBA() + grayscale := (r + g + b) / 3 + if grayscale > 0x7FFF { + fmt.Print(" ") + } else { + fmt.Print("██") + } + } + fmt.Print("\n") + } +} diff --git a/docs/auth.png b/docs/auth.png new file mode 100644 index 0000000..e1afbc6 Binary files /dev/null and b/docs/auth.png differ diff --git a/screenshot.png b/docs/index.png similarity index 100% rename from screenshot.png rename to docs/index.png diff --git a/rterm.gif b/docs/rterm.gif similarity index 100% rename from rterm.gif rename to docs/rterm.gif diff --git a/go.mod b/go.mod index ab08289..21c31f5 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,8 @@ require ( github.com/gorilla/websocket v1.5.1 ) -require golang.org/x/net v0.17.0 // indirect +require ( + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect + github.com/pquerna/otp v1.4.0 // indirect + golang.org/x/net v0.17.0 // indirect +) diff --git a/go.sum b/go.sum index 7b8f512..ceaf838 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,14 @@ +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.4.0 h1:wZvl1TIVxKRThZIBiwOOHOGP/1+nZyWBil9Y2XNEDzg= +github.com/pquerna/otp v1.4.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= diff --git a/rterm.go b/rterm.go index 55ecb99..996e8de 100644 --- a/rterm.go +++ b/rterm.go @@ -13,7 +13,10 @@ import ( "sort" "strings" + "github.com/dev6699/rterm/auth" + "github.com/dev6699/rterm/command" "github.com/dev6699/rterm/server" + "github.com/dev6699/rterm/tty" "github.com/dev6699/rterm/ui" "github.com/gorilla/websocket" ) @@ -48,9 +51,7 @@ func SetPrefix(prefix string) { } // Check if the prefix ends with "/" - if strings.HasSuffix(prefix, "/") { - prefix = strings.TrimSuffix(prefix, "/") - } + prefix = strings.TrimSuffix(prefix, "/") defaultPrefix = prefix } @@ -61,13 +62,16 @@ func SetWSUpgrader(u websocket.Upgrader) { } type Command struct { - Factory server.CommandFactory // Name of the command, will be used as the url to execute the command Name string + // Args of the the command + Args []string // Description of the command Description string - // Writable indicate whether server should process inputs from clients. + // Writable indicate whether server should process inputs from clients Writable bool + // AuthCheck acts as pre-verification step before starts agent process + AuthCheck auth.AuthCheck } // Register binds all command handlers to the http mux. @@ -106,7 +110,13 @@ func Register(mux *http.ServeMux, commands ...Command) { http.NotFound(w, r) return } - server.HandleWebSocket(&wsUpgrader, cmd.Factory, cmd.Writable)(w, r) + server.HandleWebSocket(&wsUpgrader, server.Command{ + Factory: func() (tty.Agent, error) { + return command.New(cmd.Name, cmd.Args) + }, + Writable: cmd.Writable, + AuthCheck: cmd.AuthCheck, + })(w, r) }) } diff --git a/server/server.go b/server/server.go index 27d4ae6..98ea3f0 100644 --- a/server/server.go +++ b/server/server.go @@ -4,14 +4,18 @@ import ( "log" "net/http" - "github.com/dev6699/rterm/command" + "github.com/dev6699/rterm/auth" "github.com/dev6699/rterm/tty" "github.com/gorilla/websocket" ) -type CommandFactory = func() (*command.Command, error) +type Command struct { + Factory tty.AgentFactory + AuthCheck auth.AuthCheck + Writable bool +} -func HandleWebSocket(wsUpgrader *websocket.Upgrader, cmdFac CommandFactory, writable bool) func(http.ResponseWriter, *http.Request) { +func HandleWebSocket(wsUpgrader *websocket.Upgrader, cmd Command) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { conn, err := wsUpgrader.Upgrade(w, r, nil) @@ -21,13 +25,10 @@ func HandleWebSocket(wsUpgrader *websocket.Upgrader, cmdFac CommandFactory, writ } defer conn.Close() - cmd, err := cmdFac() - if err != nil { - log.Printf("server: failed to start command; err = %v", err) - return - } + t := tty.New(WSController{Conn: conn}, cmd.Factory) + t.WithWrite(cmd.Writable) + t.WithAuthCheck(cmd.AuthCheck) - t := tty.New(WSController{Conn: conn}, cmd, writable) err = t.Run(r.Context()) if err != nil { log.Printf("server: socket connection closed; err = %v", err) diff --git a/tty/message.go b/tty/message.go index 1b84ea9..6cced2b 100644 --- a/tty/message.go +++ b/tty/message.go @@ -6,4 +6,9 @@ const ( Input Message = '0' Output Message = '1' ResizeTerminal Message = '2' + + Auth Message = 'a' + AuthTry Message = 'b' + AuthOK Message = 'c' + AuthFailed Message = 'd' ) diff --git a/tty/tty.go b/tty/tty.go index 6d0dd12..b780dcd 100644 --- a/tty/tty.go +++ b/tty/tty.go @@ -6,25 +6,41 @@ import ( "encoding/json" "fmt" "sync" + + "github.com/dev6699/rterm/auth" ) +type AgentFactory = func() (Agent, error) + type TTY struct { - controller Controller - agent Agent + controller Controller + agentFactory AgentFactory + // agent will be nil unless authCheck has passed + agent Agent + + // mutex to ensure no concurrent write to controller mut sync.Mutex bufferSize int writable bool + authCheck auth.AuthCheck } -func New(controller Controller, agent Agent, writable bool) *TTY { +func New(controller Controller, agentFactory AgentFactory) *TTY { return &TTY{ - controller: controller, - agent: agent, - bufferSize: 1024, - writable: writable, + controller: controller, + agentFactory: agentFactory, + bufferSize: 1024, } } +func (t *TTY) WithWrite(b bool) { + t.writable = b +} + +func (t *TTY) WithAuthCheck(c auth.AuthCheck) { + t.authCheck = c +} + func (t *TTY) Run(ctx context.Context) error { err := t.initialize() if err != nil { @@ -36,6 +52,10 @@ func (t *TTY) Run(ctx context.Context) error { go func() { buf := make([]byte, t.bufferSize) for { + if t.agent == nil { + continue + } + n, err := t.agent.Read(buf) if err != nil { errCh <- err @@ -78,6 +98,28 @@ func (t *TTY) Run(ctx context.Context) error { } func (t *TTY) initialize() error { + if t.authCheck != nil { + return t.controllerWrite(Auth, nil) + } + + err := t.createAgent() + if err != nil { + return err + } + return t.controllerWrite(AuthOK, nil) +} + +func (t *TTY) createAgent() error { + if t.agent != nil { + return nil + } + + var err error + t.agent, err = t.agentFactory() + if err != nil { + return err + } + return nil } @@ -93,6 +135,25 @@ func (t *TTY) handleControllerData(data []byte) error { msg := Message(data[0]) switch msg { + + case AuthTry: + code := data[1:] + pass, err := t.authCheck.Verify(string(code)) + if err != nil { + return err + } + + if pass { + var err error + t.agent, err = t.agentFactory() + if err != nil { + return err + } + t.controllerWrite(AuthOK, nil) + } else { + t.controllerWrite(AuthFailed, nil) + } + case Input: if !t.writable || len(data) <= 1 { return nil @@ -118,7 +179,7 @@ func (t *TTY) handleControllerData(data []byte) error { return t.agent.ResizeTerminal(r.Cols, r.Rows) default: - return fmt.Errorf("tty: unkown message type: %c", msg) + return fmt.Errorf("tty: unknown message type: %c", msg) } return nil diff --git a/tty/tty_test.go b/tty/tty_test.go deleted file mode 100644 index b3081f2..0000000 --- a/tty/tty_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package tty_test - -import ( - "bytes" - "context" - "encoding/base64" - "fmt" - "io" - "sync" - "testing" - - "github.com/dev6699/rterm/tty" -) - -type pipe struct { - *io.PipeReader - *io.PipeWriter -} - -func (p pipe) Read(d []byte) (n int, err error) { - if p.PipeReader != nil { - return p.PipeReader.Read(d) - } - - select {} -} - -func (p pipe) Write(d []byte) (n int, err error) { - if p.PipeWriter != nil { - return p.PipeWriter.Write(d) - } - select {} -} - -func (p pipe) ResizeTerminal(columns int, row int) error { - return nil -} - -func Test_AgentWrite(t *testing.T) { - agentReader, agentWriter := io.Pipe() - agent := pipe{ - PipeReader: agentReader, - } - - controllerReader, controllerWriter := io.Pipe() - controller := pipe{ - PipeWriter: controllerWriter, - } - - dt := tty.New(controller, agent) - - ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - wg.Add(1) - go func(t *testing.T) { - defer wg.Done() - dt.Run(ctx) - }(t) - - message := []byte("foobar") - _, err := agentWriter.Write(message) - if err != nil { - t.Fatalf("agentWriter.Write(); err = %v", err) - } - - buf := make([]byte, 1024) - n, err := controllerReader.Read(buf) - if err != nil { - t.Fatalf("controllerReader.Read(); err = %v", err) - } - - if tty.Message(buf[0]) != tty.Output { - t.Fatalf("got message type = %c; want = %c", tty.Message(buf[0]), tty.Output) - } - - decoded := make([]byte, 1024) - n, err = base64.StdEncoding.Decode(decoded, buf[1:n]) - if err != nil { - t.Fatalf("base64.StdEncoding.Decode(); err = %v", err) - } - if !bytes.Equal(decoded[:n], message) { - t.Fatalf("got message = %s; want = %s", decoded[:n], message) - } - - cancel() - wg.Wait() -} - -func Test_ControllerWrite(t *testing.T) { - agentReader, agentWriter := io.Pipe() - agent := pipe{ - PipeReader: agentReader, - PipeWriter: agentWriter, - } - - controllerReader, controllerWriter := io.Pipe() - controller := pipe{ - PipeReader: controllerReader, - PipeWriter: controllerWriter, - } - - dt := tty.New(controller, agent) - - ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - wg.Add(1) - go func(t *testing.T) { - defer wg.Done() - dt.Run(ctx) - }(t) - - message := []byte(fmt.Sprintf("%chello\n", tty.Input)) - _, err := controllerWriter.Write(message) - if err != nil { - t.Fatalf("controllerWriter.Write(); err = %v", err) - } - - buf := make([]byte, 1024) - n, err := agentReader.Read(buf) - if err != nil { - t.Fatalf("agentReader.Read(); err = %v", err) - } - - if !bytes.Equal(buf[:n], message[1:]) { - t.Fatalf("got message = %s; want = %s", buf[:n], message[1:]) - } - - cancel() - wg.Wait() -} diff --git a/ui/src/index.html b/ui/src/index.html index 6bf5325..b1edd57 100644 --- a/ui/src/index.html +++ b/ui/src/index.html @@ -18,7 +18,8 @@ height: 100vh; } - #terminal { + #terminal, + #auth { color: white; background: black; height: 100%; @@ -27,10 +28,53 @@ .xterm-viewport { overflow-y: auto !important; } + + #auth { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + } + + .digit-input { + width: 30px; + height: 40px; + font-size: 1.5em; + text-align: center; + margin: 5px; + } + + input::-webkit-outer-spin-button, + input::-webkit-inner-spin-button { + -webkit-appearance: none; + margin: 0; + } + + #result { + color: red; + } +
+

Enter 6-Digit Code

+
+ + + + + + +
+
+
diff --git a/ui/src/script.js b/ui/src/script.js index 82d549f..2ac1231 100644 --- a/ui/src/script.js +++ b/ui/src/script.js @@ -2,13 +2,35 @@ const MSG_INPUT = '0' const MSG_OUTPUT = '1' const MSG_RESIZE_TERMINAL = '2' +const MSG_AUTH = 'a' +const MSG_AUTH_TRY = 'b' +const MSG_AUTH_OK = 'c' +const MSG_AUTH_FAILED = 'd' + const terminalElement = document.getElementById('terminal') +terminalElement.style.display = 'none' +const authElement = document.getElementById('auth') +authElement.style.display = 'none' + +let terminal +function showTerminal() { + terminalElement.style.display = 'block' + terminal = new Terminal(); + const fitAddon = new FitAddon.FitAddon(); + terminal.loadAddon(fitAddon); + terminal.open(terminalElement); + fitAddon.fit(); + + socket.send(MSG_RESIZE_TERMINAL + JSON.stringify({ cols: terminal.cols, rows: terminal.rows })) + + terminal.onData((data) => { + socket.send(MSG_INPUT + data) + }) -const terminal = new Terminal(); -const fitAddon = new FitAddon.FitAddon(); -terminal.loadAddon(fitAddon); -terminal.open(terminalElement); -fitAddon.fit(); + terminal.onResize((data) => { + socket.send(MSG_RESIZE_TERMINAL + JSON.stringify(data)) + }) +} const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const wsHost = window.location.hostname; @@ -17,17 +39,30 @@ const wsPath = window.location.pathname + '/ws'; const wsURL = wsProtocol + wsHost + wsPort + wsPath; const socket = new WebSocket(wsURL); -socket.addEventListener("open", () => { - socket.send(MSG_RESIZE_TERMINAL + JSON.stringify({ cols: terminal.cols, rows: terminal.rows })) -}); socket.addEventListener("message", (event) => { const message = event.data.slice(0, 1) - if (message !== MSG_OUTPUT) { - return + + switch (message) { + case MSG_AUTH: + authElement.style.display = 'flex' + document.getElementById('digit1').focus(); + break + + case MSG_AUTH_OK: + authElement.style.display = 'none' + showTerminal() + break + + case MSG_AUTH_FAILED: + clearDigits() + break + + case MSG_OUTPUT: + const data = atob(event.data.slice(1)) + terminal.write(data) + break } - const data = atob(event.data.slice(1)) - terminal.write(data) }); socket.addEventListener("close", () => { @@ -40,17 +75,40 @@ socket.addEventListener("error", () => { terminalElement.innerText = "Connection error" }) -terminal.onData((data) => { - socket.send(MSG_INPUT + data) -}) - -terminal.onResize((data) => { - socket.send(MSG_RESIZE_TERMINAL + JSON.stringify(data)) -}) function resize() { fitAddon.fit() terminal.scrollToBottom() } -window.addEventListener('resize', this.resize) \ No newline at end of file +window.addEventListener('resize', this.resize) + +function moveFocus(currentDigit) { + const currentInput = document.getElementById(`digit${currentDigit}`); + if (currentInput.value.length === 1) { + if (currentDigit < 6) { + document.getElementById('result').textContent = ''; + document.getElementById(`digit${currentDigit + 1}`).focus(); + } else { + submitCode() + } + } +} + +function submitCode() { + const digits = []; + for (let i = 1; i <= 6; i++) { + const digitInput = document.getElementById(`digit${i}`); + digits.push(digitInput.value); + } + const code = digits.join(''); + socket.send(MSG_AUTH_TRY + code) +} + +function clearDigits() { + for (let i = 1; i <= 6; i++) { + document.getElementById(`digit${i}`).value = ''; + } + document.getElementById('digit1').focus(); + document.getElementById('result').textContent = 'Invalid code'; +}