diff options
-rw-r--r-- | api/status_test.go | 7 | ||||
-rw-r--r-- | api/user_test.go | 5 | ||||
-rw-r--r-- | api/web_conn.go | 38 | ||||
-rw-r--r-- | api/web_hub.go | 4 | ||||
-rw-r--r-- | api/websocket.go | 2 | ||||
-rw-r--r-- | api/websocket_router.go | 31 | ||||
-rw-r--r-- | api/websocket_test.go | 112 | ||||
-rw-r--r-- | model/websocket_client.go | 24 | ||||
-rw-r--r-- | model/websocket_message.go | 1 | ||||
-rw-r--r-- | webapp/client/websocket_client.jsx | 12 | ||||
-rw-r--r-- | webapp/tests/client_websocket.test.jsx | 49 | ||||
-rw-r--r-- | webapp/tests/test_helper.jsx | 20 |
12 files changed, 279 insertions, 26 deletions
diff --git a/api/status_test.go b/api/status_test.go index ffc946817..2aa866a47 100644 --- a/api/status_test.go +++ b/api/status_test.go @@ -22,6 +22,11 @@ func TestStatuses(t *testing.T) { defer WebSocketClient.Close() WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} rteam, _ := Client.CreateTeam(&team) @@ -75,7 +80,7 @@ func TestStatuses(t *testing.T) { } if status, ok := resp.Data[th.BasicUser2.Id]; !ok { - t.Log(len(resp.Data)) + t.Log(resp.Data) t.Fatal("should have had user status") } else if status != model.STATUS_ONLINE { t.Log(status) diff --git a/api/user_test.go b/api/user_test.go index 57f4729da..20c555931 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -1794,6 +1794,11 @@ func TestUserTyping(t *testing.T) { defer WebSocketClient.Close() WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + WebSocketClient.UserTyping("", "") time.Sleep(300 * time.Millisecond) if resp := <-WebSocketClient.ResponseChannel; resp.Error.Id != "api.websocket_handler.invalid_param.app_error" { diff --git a/api/web_conn.go b/api/web_conn.go index 7f3c1f875..52b5ba9de 100644 --- a/api/web_conn.go +++ b/api/web_conn.go @@ -15,9 +15,10 @@ import ( ) const ( - WRITE_WAIT = 30 * time.Second - PONG_WAIT = 100 * time.Second - PING_PERIOD = (PONG_WAIT * 6) / 10 + WRITE_WAIT = 30 * time.Second + PONG_WAIT = 100 * time.Second + PING_PERIOD = (PONG_WAIT * 6) / 10 + AUTH_TIMEOUT = 5 * time.Second ) type WebConn struct { @@ -32,7 +33,9 @@ type WebConn struct { } func NewWebConn(c *Context, ws *websocket.Conn) *WebConn { - go SetStatusOnline(c.Session.UserId, c.Session.Id, false) + if len(c.Session.UserId) > 0 { + go SetStatusOnline(c.Session.UserId, c.Session.Id, false) + } return &WebConn{ Send: make(chan model.WebSocketMessage, 256), @@ -53,7 +56,9 @@ func (c *WebConn) readPump() { c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT)) c.WebSocket.SetPongHandler(func(string) error { c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT)) - go SetStatusAwayIfNeeded(c.UserId, false) + if c.isAuthenticated() { + go SetStatusAwayIfNeeded(c.UserId, false) + } return nil }) @@ -64,7 +69,7 @@ func (c *WebConn) readPump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.read: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.read: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.read: closing websocket for userId=%v error=%v", c.UserId, err.Error())) } return @@ -76,9 +81,11 @@ func (c *WebConn) readPump() { func (c *WebConn) writePump() { ticker := time.NewTicker(PING_PERIOD) + authTicker := time.NewTicker(AUTH_TIMEOUT) defer func() { ticker.Stop() + authTicker.Stop() c.WebSocket.Close() }() @@ -97,7 +104,7 @@ func (c *WebConn) writePump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.send: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.send: cannot send, closing websocket for userId=%v, error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.send: closing websocket for userId=%v, error=%v", c.UserId, err.Error())) } return @@ -110,11 +117,18 @@ func (c *WebConn) writePump() { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { l4g.Debug(fmt.Sprintf("websocket.ticker: client side closed socket userId=%v", c.UserId)) } else { - l4g.Debug(fmt.Sprintf("websocket.ticker: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error())) + l4g.Debug(fmt.Sprintf("websocket.ticker: closing websocket for userId=%v error=%v", c.UserId, err.Error())) } return } + + case <-authTicker.C: + if c.SessionToken == "" { + l4g.Debug(fmt.Sprintf("websocket.authTicker: did not authenticate ip=%v", c.WebSocket.RemoteAddr())) + return + } + authTicker.Stop() } } } @@ -122,10 +136,18 @@ func (c *WebConn) writePump() { func (webCon *WebConn) InvalidateCache() { webCon.AllChannelMembers = nil webCon.LastAllChannelMembersTime = 0 +} +func (webCon *WebConn) isAuthenticated() bool { + return webCon.SessionToken != "" } func (webCon *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool { + // IMPORTANT: Do not send event if WebConn does not have a session + if !webCon.isAuthenticated() { + return false + } + // If the event is destined to a specific user if len(msg.Broadcast.UserId) > 0 && webCon.UserId != msg.Broadcast.UserId { return false diff --git a/api/web_hub.go b/api/web_hub.go index 23c01eb1b..dfbdf3838 100644 --- a/api/web_hub.go +++ b/api/web_hub.go @@ -156,6 +156,10 @@ func (h *Hub) Start() { close(webCon.Send) } + if len(userId) == 0 { + continue + } + found := false for webCon := range h.connections { if userId == webCon.UserId { diff --git a/api/websocket.go b/api/websocket.go index 34d95f705..1c3277497 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -17,7 +17,7 @@ const ( func InitWebSocket() { l4g.Debug(utils.T("api.web_socket.init.debug")) - BaseRoutes.Users.Handle("/websocket", ApiUserRequiredTrustRequester(connect)).Methods("GET") + BaseRoutes.Users.Handle("/websocket", ApiAppHandlerTrustRequester(connect)).Methods("GET") HubStart() } diff --git a/api/websocket_router.go b/api/websocket_router.go index 34b576464..bdbd9f4d9 100644 --- a/api/websocket_router.go +++ b/api/websocket_router.go @@ -37,6 +37,37 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque return } + if r.Action == model.WEBSOCKET_AUTHENTICATION_CHALLENGE { + token, ok := r.Data["token"].(string) + if !ok { + conn.WebSocket.Close() + return + } + + session := GetSession(token) + + if session == nil || session.IsExpired() { + conn.WebSocket.Close() + } else { + go SetStatusOnline(session.UserId, session.Id, false) + + conn.SessionToken = session.Token + conn.UserId = session.UserId + + resp := model.NewWebSocketResponse(model.STATUS_OK, r.Seq, nil) + resp.DoPreComputeJson() + conn.Send <- resp + } + + return + } + + if conn.SessionToken == "" { + err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.not_authenticated.app_error", nil, "") + wr.ReturnWebSocketError(conn, r, err) + return + } + var handler *webSocketHandler if h, ok := wr.handlers[r.Action]; !ok { err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.bad_action.app_error", nil, "") diff --git a/api/websocket_test.go b/api/websocket_test.go index b7ca4b691..144c1a39b 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -4,12 +4,116 @@ package api import ( + "encoding/json" + "net/http" "testing" "time" + "github.com/gorilla/websocket" "github.com/mattermost/platform/model" ) +func TestWebSocketAuthentication(t *testing.T) { + th := Setup().InitBasic() + WebSocketClient, err := th.CreateWebSocketClient() + if err != nil { + t.Fatal(err) + } + WebSocketClient.Listen() + + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + + WebSocketClient.SendMessage("ping", nil) + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Data["text"].(string) != "pong" { + t.Fatal("wrong response") + } + + WebSocketClient.Close() + + authToken := WebSocketClient.AuthToken + WebSocketClient.AuthToken = "junk" + if err := WebSocketClient.Connect(); err != nil { + t.Fatal(err) + } + WebSocketClient.Listen() + + if resp := <-WebSocketClient.ResponseChannel; resp != nil { + t.Fatal("should have closed") + } + + WebSocketClient.Close() + + if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", nil); err != nil { + t.Fatal("should have connected") + } else { + req := &model.WebSocketRequest{} + req.Seq = 1 + req.Action = "ping" + conn.WriteJSON(req) + + closedAutomatically := false + hitNotAuthedError := false + + go func() { + time.Sleep(10 * time.Second) + conn.Close() + + if !closedAutomatically { + t.Fatal("should have closed automatically in 5 seconds") + } + }() + + for { + if _, rawMsg, err := conn.ReadMessage(); err != nil { + closedAutomatically = true + conn.Close() + break + } else { + var response model.WebSocketResponse + if err := json.Unmarshal(rawMsg, &response); err != nil && !response.IsValid() { + t.Fatal("should not have failed") + } else { + if response.Error == nil || response.Error.Id != "api.web_socket_router.not_authenticated.app_error" { + t.Log(response.Error.Id) + t.Fatal("wrong error") + continue + } + + hitNotAuthedError = true + } + } + } + + if !hitNotAuthedError { + t.Fatal("should have received a not authenticated response") + } + } + + header := http.Header{} + header.Set(model.HEADER_AUTH, "BEARER "+authToken) + if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", header); err != nil { + t.Fatal("should have connected") + } else { + if _, rawMsg, err := conn.ReadMessage(); err != nil { + t.Fatal("should not have closed automatically") + } else { + var event model.WebSocketEvent + if err := json.Unmarshal(rawMsg, &event); err != nil && !event.IsValid() { + t.Fatal("should not have failed") + } else if event.Event != model.WEBSOCKET_EVENT_HELLO { + t.Log(event.ToJson()) + t.Fatal("should have helloed") + } + } + + conn.Close() + } +} + func TestWebSocket(t *testing.T) { th := Setup().InitBasic() WebSocketClient, err := th.CreateWebSocketClient() @@ -29,6 +133,9 @@ func TestWebSocket(t *testing.T) { WebSocketClient.Listen() time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } WebSocketClient.SendMessage("ping", nil) time.Sleep(300 * time.Millisecond) @@ -78,6 +185,11 @@ func TestWebSocketEvent(t *testing.T) { WebSocketClient.Listen() + time.Sleep(300 * time.Millisecond) + if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK { + t.Fatal("should have responded OK to authentication challenge") + } + omitUser := make(map[string]bool, 1) omitUser["somerandomid"] = true evt1 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", th.BasicChannel.Id, "", omitUser) diff --git a/model/websocket_client.go b/model/websocket_client.go index a4983e385..453ae49b7 100644 --- a/model/websocket_client.go +++ b/model/websocket_client.go @@ -6,7 +6,6 @@ package model import ( "encoding/json" "github.com/gorilla/websocket" - "net/http" ) type WebSocketClient struct { @@ -23,14 +22,12 @@ type WebSocketClient struct { // NewWebSocketClient constructs a new WebSocket client with convienence // methods for talking to the server. func NewWebSocketClient(url, authToken string) (*WebSocketClient, *AppError) { - header := http.Header{} - header.Set(HEADER_AUTH, "BEARER "+authToken) - conn, _, err := websocket.DefaultDialer.Dial(url+API_URL_SUFFIX+"/users/websocket", header) + conn, _, err := websocket.DefaultDialer.Dial(url+API_URL_SUFFIX+"/users/websocket", nil) if err != nil { return nil, NewLocAppError("NewWebSocketClient", "model.websocket_client.connect_fail.app_error", nil, err.Error()) } - return &WebSocketClient{ + client := &WebSocketClient{ url, url + API_URL_SUFFIX, conn, @@ -39,19 +36,25 @@ func NewWebSocketClient(url, authToken string) (*WebSocketClient, *AppError) { make(chan *WebSocketEvent, 100), make(chan *WebSocketResponse, 100), nil, - }, nil + } + + client.SendMessage(WEBSOCKET_AUTHENTICATION_CHALLENGE, map[string]interface{}{"token": authToken}) + + return client, nil } func (wsc *WebSocketClient) Connect() *AppError { - header := http.Header{} - header.Set(HEADER_AUTH, "BEARER "+wsc.AuthToken) - var err error - wsc.Conn, _, err = websocket.DefaultDialer.Dial(wsc.ApiUrl+"/users/websocket", header) + wsc.Conn, _, err = websocket.DefaultDialer.Dial(wsc.ApiUrl+"/users/websocket", nil) if err != nil { return NewLocAppError("NewWebSocketClient", "model.websocket_client.connect_fail.app_error", nil, err.Error()) } + wsc.EventChannel = make(chan *WebSocketEvent, 100) + wsc.ResponseChannel = make(chan *WebSocketResponse, 100) + + wsc.SendMessage(WEBSOCKET_AUTHENTICATION_CHALLENGE, map[string]interface{}{"token": wsc.AuthToken}) + return nil } @@ -89,6 +92,7 @@ func (wsc *WebSocketClient) Listen() { wsc.ResponseChannel <- &response continue } + } }() } diff --git a/model/websocket_message.go b/model/websocket_message.go index df5cf3b81..5eb02642e 100644 --- a/model/websocket_message.go +++ b/model/websocket_message.go @@ -26,6 +26,7 @@ const ( WEBSOCKET_EVENT_STATUS_CHANGE = "status_change" WEBSOCKET_EVENT_HELLO = "hello" WEBSOCKET_EVENT_WEBRTC = "webrtc" + WEBSOCKET_AUTHENTICATION_CHALLENGE = "authentication_challenge" ) type WebSocketMessage interface { diff --git a/webapp/client/websocket_client.jsx b/webapp/client/websocket_client.jsx index 035e30be5..760c62b59 100644 --- a/webapp/client/websocket_client.jsx +++ b/webapp/client/websocket_client.jsx @@ -18,7 +18,7 @@ export default class WebSocketClient { this.closeCallback = null; } - initialize(connectionUrl) { + initialize(connectionUrl, token) { if (this.conn) { return; } @@ -30,6 +30,10 @@ export default class WebSocketClient { this.conn = new WebSocket(connectionUrl); this.conn.onopen = () => { + if (token) { + this.sendMessage('authentication_challenge', {token}); + } + if (this.connectFailCount > 0) { console.log('websocket re-established connection'); //eslint-disable-line no-console if (this.reconnectCallback) { @@ -68,7 +72,7 @@ export default class WebSocketClient { setTimeout( () => { - this.initialize(connectionUrl); + this.initialize(connectionUrl, token); }, retryTime ); @@ -152,12 +156,12 @@ export default class WebSocketClient { } } - userTyping(channelId, parentId) { + userTyping(channelId, parentId, callback) { const data = {}; data.channel_id = channelId; data.parent_id = parentId; - this.sendMessage('user_typing', data); + this.sendMessage('user_typing', data, callback); } getStatuses(callback) { diff --git a/webapp/tests/client_websocket.test.jsx b/webapp/tests/client_websocket.test.jsx new file mode 100644 index 000000000..6535610e3 --- /dev/null +++ b/webapp/tests/client_websocket.test.jsx @@ -0,0 +1,49 @@ +// Copyright (c) 2016 Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. +/* +var assert = require('assert'); +import TestHelper from './test_helper.jsx'; + +describe('Client.WebSocket', function() { + this.timeout(10000); + + it('WebSocket.getStatusesByIds', function(done) { + TestHelper.initBasic(() => { + TestHelper.basicWebSocketClient().getStatusesByIds( + [TestHelper.basicUser().id], + function(resp) { + TestHelper.basicWebSocketClient().close(); + assert.equal(resp.data[TestHelper.basicUser().id], 'online'); + done(); + } + ); + }, true); + }); + + it('WebSocket.getStatuses', function(done) { + TestHelper.initBasic(() => { + TestHelper.basicWebSocketClient().getStatuses( + function(resp) { + TestHelper.basicWebSocketClient().close(); + assert.equal(resp.data != null, true); + done(); + } + ); + }, true); + }); + + it('WebSocket.userTyping', function(done) { + TestHelper.initBasic(() => { + TestHelper.basicWebSocketClient().userTyping( + TestHelper.basicChannel().id, + '', + function(resp) { + TestHelper.basicWebSocketClient().close(); + assert.equal(resp.status, 'OK'); + done(); + } + ); + }, true); + }); +});*/ + diff --git a/webapp/tests/test_helper.jsx b/webapp/tests/test_helper.jsx index 41d0c15ba..310714e30 100644 --- a/webapp/tests/test_helper.jsx +++ b/webapp/tests/test_helper.jsx @@ -2,13 +2,20 @@ // See License.txt for license information. import Client from 'client/client.jsx'; +import WebSocketClient from 'client/websocket_client.jsx'; import jqd from 'jquery-deferred'; +var HEADER_TOKEN = 'token'; + class TestHelperClass { basicClient = () => { return this.basicc; } + basicWebSocketClient = () => { + return this.basicwsc; + } + basicTeam = () => { return this.basict; } @@ -53,6 +60,12 @@ class TestHelperClass { return c; } + createWebSocketClient(token) { + var ws = new WebSocketClient(); + ws.initialize('http://localhost:8065/api/v3/users/websocket', token); + return ws; + } + fakeEmail = () => { return 'success' + this.generateId() + '@simulator.amazonses.com'; } @@ -90,7 +103,7 @@ class TestHelperClass { return post; } - initBasic = (callback) => { + initBasic = (callback, connectWS) => { this.basicc = this.createClient(); var d1 = jqd.Deferred(); @@ -122,7 +135,10 @@ class TestHelperClass { rteamSignup.user.email, password, null, - function() { + function(data, res) { + if (connectWS) { + outer.basicwsc = outer.createWebSocketClient(res.header[HEADER_TOKEN]); + } outer.basicClient().useHeaderToken(); var channel = outer.fakeChannel(); channel.team_id = outer.basicTeam().id; |