Fix possible race condition on timeout

This commit is contained in:
Iwasaki Yudai 2017-08-13 13:40:00 +09:00
parent 9b8d2d5ed5
commit 2a2a034788
3 changed files with 89 additions and 37 deletions

View file

@ -10,9 +10,8 @@ import (
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
noesctmpl "text/template"
"time"
"github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/websocket"
@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
opt(opts)
}
// wg and connections can be incosistent because they are handled nonatomically
wg := new(sync.WaitGroup) // to wait all connections to be closed
connections := new(int64) // number of active connections
counter := newCounter(time.Duration(server.options.Timeout) * time.Second)
url := server.setupURL()
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
handlers := server.setupHandlers(cctx, cancel, url, counter)
srv, err := server.setupHTTPServer(handlers, url)
if err != nil {
return errors.Wrapf(err, "failed to setup an HTTP server")
@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
err = cctx.Err()
}
conn := atomic.LoadInt64(connections)
conn := counter.count()
if conn > 0 {
log.Printf("Waiting for %d connections to be closed", conn)
}
wg.Wait()
counter.wait()
return err
}
@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL {
return &url.URL{Scheme: scheme, Host: host, Path: path}
}
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, counter *counter) http.Handler {
staticFileHandler := http.FileServer(
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
)
@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
wsMux := http.NewServeMux()
wsMux.Handle("/", siteHandler)
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, counter))
siteHandler = http.Handler(wsMux)
return server.wrapLogger(siteHandler)