diff options
Diffstat (limited to 'api')
-rw-r--r-- | api/status_test.go | 7 | ||||
-rw-r--r-- | api/user_test.go | 5 | ||||
-rw-r--r-- | api/web_conn.go | 38 | ||||
-rw-r--r-- | api/web_hub.go | 4 | ||||
-rw-r--r-- | api/websocket.go | 2 | ||||
-rw-r--r-- | api/websocket_router.go | 31 | ||||
-rw-r--r-- | api/websocket_test.go | 112 |
7 files changed, 189 insertions, 10 deletions
diff --git a/api/status_test.go b/api/status_test.go index ffc946817..2aa866a47 100644 --- a/api/status_test.go +++ b/api/status_test.go @@ -22,6 +22,11 @@ func TestStatuses(t *testing.T) { defer WebSocketClient.Close() WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} rteam, _ := Client.CreateTeam(&team) @@ -75,7 +80,7 @@ func TestStatuses(t *testing.T) { } if status, ok := resp.Data[th.BasicUser2.Id]; !ok { - t.Log(len(resp.Data)) + t.Log(resp.Data) t.Fatal("should have had user status") } else if status != model.STATUS_ONLINE { t.Log(status) diff --git a/api/user_test.go b/api/user_test.go index 57f4729da..20c555931 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -1794,6 +1794,11 @@ func TestUserTyping(t *testing.T) { defer WebSocketClient.Close() WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + WebSocketClient.UserTyping("", "") time.Sleep(300 * time.Millisecond) if resp := <-WebSocketClient.ResponseChannel; resp.Error.Id != "api.websocket_handler.invalid_param.app_error" { diff --git a/api/web_conn.go b/api/web_conn.go index 7f3c1f875..52b5ba9de 100644 --- a/api/web_conn.go +++ b/api/web_conn.go @@ -15,9 +15,10 @@ import ( ) const ( - WRITE_WAIT = 30 * time.Second - PONG_WAIT = 100 * time.Second - PING_PERIOD = (PONG_WAIT * 6) / 10 + WRITE_WAIT = 30 * time.Second + PONG_WAIT = 100 * time.Second + PING_PERIOD = (PONG_WAIT * 6) / 10 + AUTH_TIMEOUT = 5 * time.Second ) type WebConn struct { @@ -32,7 +33,9 @@ type WebConn struct { } func NewWebConn(c *Context, ws *websocket.Conn) *WebConn { - go SetStatusOnline(c.Session.UserId, c.Session.Id, false) + if len(c.Session.UserId) > 0 { + go SetStatusOnline(c.Session.UserId, c.Session.Id, false) + } return &WebConn{ Send: make(chan model.WebSocketMessage, 256), @@ -53,7 +56,9 @@ func (c *WebConn) readPump() { c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT)) c.WebSocket.SetPongHandler(func(string) error { c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT)) - go SetStatusAwayIfNeeded(c.UserId, false) + if c.isAuthenticated() { + go SetStatusAwayIfNeeded(c.UserId, false) + } return nil }) @@ -64,7 +69,7 @@ func (c *WebConn) readPump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.read: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.read: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.read: closing websocket for userId=%v error=%v", c.UserId, err.Error())) } return @@ -76,9 +81,11 @@ func (c *WebConn) readPump() { func (c *WebConn) writePump() { ticker := time.NewTicker(PING_PERIOD) + authTicker := time.NewTicker(AUTH_TIMEOUT) defer func() { ticker.Stop() + authTicker.Stop() c.WebSocket.Close() }() @@ -97,7 +104,7 @@ func (c *WebConn) writePump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.send: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.send: cannot send, closing websocket for userId=%v, error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.send: closing websocket for userId=%v, error=%v", c.UserId, err.Error())) } return @@ -110,11 +117,18 @@ func (c *WebConn) writePump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.ticker: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.ticker: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.ticker: closing websocket for userId=%v error=%v", c.UserId, err.Error())) } return } + + case <-authTicker.C: + if c.SessionToken == "" { + l4g.Debug(fmt.Sprintf("websocket.authTicker: did not authenticate ip=%v", c.WebSocket.RemoteAddr())) + return + } + authTicker.Stop() } } } @@ -122,10 +136,18 @@ func (c *WebConn) writePump() { func (webCon *WebConn) InvalidateCache() { webCon.AllChannelMembers = nil webCon.LastAllChannelMembersTime = 0 +} +func (webCon *WebConn) isAuthenticated() bool { + return webCon.SessionToken != "" } func (webCon *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool { + // IMPORTANT: Do not send event if WebConn does not have a session + if !webCon.isAuthenticated() { + return false + } + // If the event is destined to a specific user if len(msg.Broadcast.UserId) > 0 && webCon.UserId != msg.Broadcast.UserId { return false diff --git a/api/web_hub.go b/api/web_hub.go index 23c01eb1b..dfbdf3838 100644 --- a/api/web_hub.go +++ b/api/web_hub.go @@ -156,6 +156,10 @@ func (h *Hub) Start() { close(webCon.Send) } + if len(userId) == 0 { + continue + } + found := false for webCon := range h.connections { if userId == webCon.UserId { diff --git a/api/websocket.go b/api/websocket.go index 34d95f705..1c3277497 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -17,7 +17,7 @@ const ( func InitWebSocket() { l4g.Debug(utils.T("api.web_socket.init.debug")) - BaseRoutes.Users.Handle("/websocket", ApiUserRequiredTrustRequester(connect)).Methods("GET") + BaseRoutes.Users.Handle("/websocket", ApiAppHandlerTrustRequester(connect)).Methods("GET") HubStart() } diff --git a/api/websocket_router.go b/api/websocket_router.go index 34b576464..bdbd9f4d9 100644 --- a/api/websocket_router.go +++ b/api/websocket_router.go @@ -37,6 +37,37 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque return } + if r.Action == model.WEBSOCKET_AUTHENTICATION_CHALLENGE { + token, ok := r.Data["token"].(string) + if !ok { + conn.WebSocket.Close() + return + } + + session := GetSession(token) + + if session == nil || session.IsExpired() { + conn.WebSocket.Close() + } else { + go SetStatusOnline(session.UserId, session.Id, false) + + conn.SessionToken = session.Token + conn.UserId = session.UserId + + resp := model.NewWebSocketResponse(model.STATUS_OK, r.Seq, nil) + resp.DoPreComputeJson() + conn.Send <- resp + } + + return + } + + if conn.SessionToken == "" { + err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.not_authenticated.app_error", nil, "") + wr.ReturnWebSocketError(conn, r, err) + return + } + var handler *webSocketHandler if h, ok := wr.handlers[r.Action]; !ok { err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.bad_action.app_error", nil, "") diff --git a/api/websocket_test.go b/api/websocket_test.go index b7ca4b691..144c1a39b 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -4,12 +4,116 @@ package api import ( + "encoding/json" + "net/http" "testing" "time" + "github.com/gorilla/websocket" "github.com/mattermost/platform/model" ) +func TestWebSocketAuthentication(t *testing.T) { + th := Setup().InitBasic() + WebSocketClient, err := th.CreateWebSocketClient() + if err != nil { + t.Fatal(err) + } + WebSocketClient.Listen() + + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + + WebSocketClient.SendMessage("ping", nil) + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Data["text"].(string) != "pong" { + t.Fatal("wrong response") + } + + WebSocketClient.Close() + + authToken := WebSocketClient.AuthToken + WebSocketClient.AuthToken = "junk" + if err := WebSocketClient.Connect(); err != nil { + t.Fatal(err) + } + WebSocketClient.Listen() + + if resp := <-WebSocketClient.ResponseChannel; resp != nil { + t.Fatal("should have closed") + } + + WebSocketClient.Close() + + if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", nil); err != nil { + t.Fatal("should have connected") + } else { + req := &model.WebSocketRequest{} + req.Seq = 1 + req.Action = "ping" + conn.WriteJSON(req) + + closedAutomatically := false + hitNotAuthedError := false + + go func() { + time.Sleep(10 * time.Second) + conn.Close() + + if !closedAutomatically { + t.Fatal("should have closed automatically in 5 seconds") + } + }() + + for { + if _, rawMsg, err := conn.ReadMessage(); err != nil { + closedAutomatically = true + conn.Close() + break + } else { + var response model.WebSocketResponse + if err := json.Unmarshal(rawMsg, &response); err != nil && !response.IsValid() { + t.Fatal("should not have failed") + } else { + if response.Error == nil || response.Error.Id != "api.web_socket_router.not_authenticated.app_error" { + t.Log(response.Error.Id) + t.Fatal("wrong error") + continue + } + + hitNotAuthedError = true + } + } + } + + if !hitNotAuthedError { + t.Fatal("should have received a not authenticated response") + } + } + + header := http.Header{} + header.Set(model.HEADER_AUTH, "BEARER "+authToken) + if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", header); err != nil { + t.Fatal("should have connected") + } else { + if _, rawMsg, err := conn.ReadMessage(); err != nil { + t.Fatal("should not have closed automatically") + } else { + var event model.WebSocketEvent + if err := json.Unmarshal(rawMsg, &event); err != nil && !event.IsValid() { + t.Fatal("should not have failed") + } else if event.Event != model.WEBSOCKET_EVENT_HELLO { + t.Log(event.ToJson()) + t.Fatal("should have helloed") + } + } + + conn.Close() + } +} + func TestWebSocket(t *testing.T) { th := Setup().InitBasic() WebSocketClient, err := th.CreateWebSocketClient() @@ -29,6 +133,9 @@ func TestWebSocket(t *testing.T) { WebSocketClient.Listen() time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } WebSocketClient.SendMessage("ping", nil) time.Sleep(300 * time.Millisecond) @@ -78,6 +185,11 @@ func TestWebSocketEvent(t *testing.T) { WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + omitUser := make(map[string]bool, 1) omitUser["somerandomid"] = true evt1 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", th.BasicChannel.Id, "", omitUser) |