From 25627da86fb7b2e2d0f40ab7da3488943936cb39 Mon Sep 17 00:00:00 2001 From: Iwasaki Yudai Date: Fri, 21 Aug 2015 18:22:08 +0900 Subject: [PATCH] Restructure handler function --- app/app.go | 247 +++++++++++++++++++++++++++++------------------------ 1 file changed, 137 insertions(+), 110 deletions(-) diff --git a/app/app.go b/app/app.go index 11f7056..fe62ec7 100644 --- a/app/app.go +++ b/app/app.go @@ -7,6 +7,7 @@ import ( "log" "math/big" "net/http" + "os" "os/exec" "strconv" "strings" @@ -21,6 +22,8 @@ import ( type App struct { options Options + + upgrader *websocket.Upgrader } type Options struct { @@ -35,6 +38,12 @@ type Options struct { func New(options Options) *App { return &App{ options: options, + + upgrader: &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + Subprotocols: []string{"gotty"}, + }, } } @@ -77,7 +86,7 @@ func (app *App) Run() error { fs := http.StripPrefix(path, http.FileServer(&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "bindata"})) http.Handle(path, fs) - http.HandleFunc(path+"ws", app.generateHandler()) + http.HandleFunc(path+"ws", app.handler) endpoint := app.options.Address + ":" + app.options.Port log.Printf("Server is running at %s, command: %s", endpoint+path, strings.Join(app.options.Command, " ")) @@ -94,130 +103,131 @@ func (app *App) Run() error { return nil } -func (app *App) generateHandler() func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - log.Printf("New client connected: %s", r.RemoteAddr) +func (app *App) handler(w http.ResponseWriter, r *http.Request) { + log.Printf("New client connected: %s", r.RemoteAddr) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - Subprotocols: []string{"gotty"}, - } + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + return + } - if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) - return - } + conn, err := app.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("Failed to upgrade connection") + return + } - conn, err := upgrader.Upgrade(w, r, nil) + cmd := exec.Command(app.options.Command[0], app.options.Command[1:]...) + ptyIo, err := pty.Start(cmd) + if err != nil { + log.Print("Failed to execute command") + return + } + log.Printf("Command is running for client %s with PID %d", r.RemoteAddr, cmd.Process.Pid) + + context := &clientContext{ + request: r, + connection: conn, + command: cmd, + pty: ptyIo, + } + + app.goHandleConnection(context) +} + +func (app *App) goHandleConnection(context *clientContext) { + exit := make(chan bool, 2) + + go func() { + defer func() { exit <- true }() + + app.processSend(context) + }() + + go func() { + defer func() { exit <- true }() + + app.processReceive(context) + }() + + go func() { + <-exit + context.command.Wait() + context.connection.Close() + log.Printf("Connection closed: %s", context.request.RemoteAddr) + }() +} + +func (app *App) processSend(context *clientContext) { + buf := make([]byte, 1024) + utf8f := utf8reader.New(context.pty) + + for { + size, err := utf8f.Read(buf) if err != nil { - log.Print("Failed to upgrade connection") + log.Printf("Command exited for: %s", context.request.RemoteAddr) return } - cmd := exec.Command(app.options.Command[0], app.options.Command[1:]...) - fio, err := pty.Start(cmd) - log.Printf("Command is running for client %s with PID %d", r.RemoteAddr, cmd.Process.Pid) + writer, err := context.connection.NextWriter(websocket.TextMessage) if err != nil { - log.Print("Failed to execute command") return } - exit := make(chan bool, 2) - - go func() { - defer func() { exit <- true }() - - buf := make([]byte, 1024) - utf8f := utf8reader.New(fio) - - for { - size, err := utf8f.Read(buf) - if err != nil { - log.Printf("Command exited for: %s", r.RemoteAddr) - return - } - - writer, err := conn.NextWriter(websocket.TextMessage) - if err != nil { - return - } - - writer.Write(buf[:size]) - writer.Close() - } - }() - - go func() { - defer func() { exit <- true }() - - for { - _, data, err := conn.ReadMessage() - if err != nil { - return - } - - switch data[0] { - case Input: - if !app.options.PermitWrite { - break - } - - _, err := fio.Write(data[1:]) - if err != nil { - return - } - - case ResizeTerminal: - var args argResizeTerminal - err = json.Unmarshal(data[1:], &args) - if err != nil { - log.Print("Malformed remote command") - return - } - - window := struct { - row uint16 - col uint16 - x uint16 - y uint16 - }{ - uint16(args.Rows), - uint16(args.Columns), - 0, - 0, - } - syscall.Syscall( - syscall.SYS_IOCTL, - fio.Fd(), - syscall.TIOCSWINSZ, - uintptr(unsafe.Pointer(&window)), - ) - - default: - log.Print("Unknown message type") - return - } - } - }() - - go func() { - <-exit - cmd.Wait() - conn.Close() - log.Printf("Connection closed: %s", r.RemoteAddr) - }() + writer.Write(buf[:size]) + writer.Close() } } -const ( - Input = '0' - ResizeTerminal = '1' -) +func (app *App) processReceive(context *clientContext) { + for { + _, data, err := context.connection.ReadMessage() + if err != nil { + return + } -type argResizeTerminal struct { - Columns float64 - Rows float64 + switch data[0] { + case Input: + if !app.options.PermitWrite { + break + } + + _, err := context.pty.Write(data[1:]) + if err != nil { + return + } + + case ResizeTerminal: + var args argResizeTerminal + err = json.Unmarshal(data[1:], &args) + if err != nil { + log.Print("Malformed remote command") + return + } + + window := struct { + row uint16 + col uint16 + x uint16 + y uint16 + }{ + uint16(args.Rows), + uint16(args.Columns), + 0, + 0, + } + syscall.Syscall( + syscall.SYS_IOCTL, + context.pty.Fd(), + syscall.TIOCSWINSZ, + uintptr(unsafe.Pointer(&window)), + ) + + default: + log.Print("Unknown message type") + return + } + } } func generateRandomString(length int) string { @@ -230,3 +240,20 @@ func generateRandomString(length int) string { } return string(n) } + +const ( + Input = '0' + ResizeTerminal = '1' +) + +type argResizeTerminal struct { + Columns float64 + Rows float64 +} + +type clientContext struct { + request *http.Request + connection *websocket.Conn + command *exec.Cmd + pty *os.File +}