From 47e6a33a4505e13ba4edf37ff1f8fbdadb279ee3 Mon Sep 17 00:00:00 2001 From: JoramWilander Date: Wed, 16 Sep 2015 15:49:12 -0400 Subject: Implement OAuth2 service provider functionality. --- api/api.go | 1 + api/api_test.go | 2 +- api/channel_test.go | 4 +- api/command.go | 8 +-- api/context.go | 77 ++++++++++++++---------- api/oauth.go | 165 ++++++++++++++++++++++++++++++++++++++++++++++++++++ api/oauth_test.go | 157 +++++++++++++++++++++++++++++++++++++++++++++++++ api/post_test.go | 4 +- api/team_test.go | 2 +- api/user.go | 42 ++++++------- api/user_test.go | 17 ++---- 11 files changed, 408 insertions(+), 71 deletions(-) create mode 100644 api/oauth.go create mode 100644 api/oauth_test.go (limited to 'api') diff --git a/api/api.go b/api/api.go index 8203b07a6..c8f97c5af 100644 --- a/api/api.go +++ b/api/api.go @@ -43,6 +43,7 @@ func InitApi() { InitFile(r) InitCommand(r) InitAdmin(r) + InitOAuth(r) templatesDir := utils.FindDir("api/templates") l4g.Debug("Parsing server templates at %v", templatesDir) diff --git a/api/api_test.go b/api/api_test.go index 0c2e57891..642db581e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -17,7 +17,7 @@ func Setup() { NewServer() StartServer() InitApi() - Client = model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port + "/api/v1") + Client = model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port) } } diff --git a/api/channel_test.go b/api/channel_test.go index d65aff66c..7e9267192 100644 --- a/api/channel_test.go +++ b/api/channel_test.go @@ -62,7 +62,7 @@ func TestCreateChannel(t *testing.T) { } } - if _, err := Client.DoPost("/channels/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/channels/create", "garbage"); err == nil { t.Fatal("should have been an error") } @@ -627,7 +627,7 @@ func TestGetChannelExtraInfo(t *testing.T) { currentEtag = cache_result.Etag } - Client2 := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port + "/api/v1") + Client2 := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port) user2 := &model.User{TeamId: team.Id, Email: model.NewId() + "tester2@test.com", Nickname: "Tester 2", Password: "pwd"} user2 = Client2.Must(Client2.CreateUser(user2, "")).Data.(*model.User) diff --git a/api/command.go b/api/command.go index 2919e93a0..be1d3229b 100644 --- a/api/command.go +++ b/api/command.go @@ -315,7 +315,7 @@ func loadTestSetupCommand(c *Context, command *model.Command) bool { numPosts, _ = strconv.Atoi(tokens[numArgs+2]) } } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) if doTeams { if err := CreateBasicUser(client); err != nil { @@ -375,7 +375,7 @@ func loadTestUsersCommand(c *Context, command *model.Command) bool { if err == false { usersr = utils.Range{10, 15} } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) userCreator := NewAutoUserCreator(client, c.Session.TeamId) userCreator.Fuzzy = doFuzz userCreator.CreateTestUsers(usersr) @@ -405,7 +405,7 @@ func loadTestChannelsCommand(c *Context, command *model.Command) bool { if err == false { channelsr = utils.Range{20, 30} } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) client.MockSession(c.Session.Id) channelCreator := NewAutoChannelCreator(client, c.Session.TeamId) channelCreator.Fuzzy = doFuzz @@ -457,7 +457,7 @@ func loadTestPostsCommand(c *Context, command *model.Command) bool { } } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) client.MockSession(c.Session.Id) testPoster := NewAutoPostCreator(client, command.ChannelId) testPoster.Fuzzy = doFuzz diff --git a/api/context.go b/api/context.go index 5dcdfaf96..b1b4d2d10 100644 --- a/api/context.go +++ b/api/context.go @@ -80,9 +80,36 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.RequestId = model.NewId() c.IpAddress = GetIpAddress(r) + token := "" + isTokenFromQueryString := false + + // Attempt to parse token out of the header + authHeader := r.Header.Get(model.HEADER_AUTH) + if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER { + // Default session token + token = authHeader[7:] + + } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN { + // OAuth token + token = authHeader[6:] + } + + // Attempt to parse the token from the cookie + if len(token) == 0 { + if cookie, err := r.Cookie(model.SESSION_TOKEN); err == nil { + token = cookie.Value + } + } + + // Attempt to parse token out of the query string + if len(token) == 0 { + token = r.URL.Query().Get("access_token") + isTokenFromQueryString = true + } + protocol := "http" - // if the request came from the ELB then assume this is produciton + // If the request came from the ELB then assume this is produciton // and redirect all http requests to https if utils.Cfg.ServiceSettings.UseSSL { forwardProto := r.Header.Get(model.HEADER_FORWARDED_PROTO) @@ -105,36 +132,19 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("Content-Security-Policy", "frame-ancestors none") } else { - // All api response bodies will be JSON formatted + // All api response bodies will be JSON formatted by default w.Header().Set("Content-Type", "application/json") } - sessionId := "" - - // attempt to parse the session token from the header - if ah := r.Header.Get(model.HEADER_AUTH); ah != "" { - if len(ah) > 6 && strings.ToUpper(ah[0:6]) == "BEARER" { - sessionId = ah[7:] - } - } - - // attempt to parse the session token from the cookie - if sessionId == "" { - if cookie, err := r.Cookie(model.SESSION_TOKEN); err == nil { - sessionId = cookie.Value - } - } - - if sessionId != "" { - + if len(token) != 0 { var session *model.Session - if ts, ok := sessionCache.Get(sessionId); ok { + if ts, ok := sessionCache.Get(token); ok { session = ts.(*model.Session) } if session == nil { - if sessionResult := <-Srv.Store.Session().Get(sessionId); sessionResult.Err != nil { - c.LogError(model.NewAppError("ServeHTTP", "Invalid session", "id="+sessionId+", err="+sessionResult.Err.DetailedError)) + if sessionResult := <-Srv.Store.Session().Get(token); sessionResult.Err != nil { + c.LogError(model.NewAppError("ServeHTTP", "Invalid session", "token="+token+", err="+sessionResult.Err.DetailedError)) } else { session = sessionResult.Data.(*model.Session) } @@ -142,7 +152,10 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if session == nil || session.IsExpired() { c.RemoveSessionCookie(w) - c.Err = model.NewAppError("ServeHTTP", "Invalid or expired session, please login again.", "id="+sessionId) + c.Err = model.NewAppError("ServeHTTP", "Invalid or expired session, please login again.", "token="+token) + c.Err.StatusCode = http.StatusUnauthorized + } else if !session.IsOAuth && isTokenFromQueryString { + c.Err = model.NewAppError("ServeHTTP", "Session is not OAuth but token was provided in the query string", "token="+token) c.Err.StatusCode = http.StatusUnauthorized } else { c.Session = *session @@ -166,10 +179,10 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.SystemAdminRequired() } - if c.Err == nil && h.isUserActivity && sessionId != "" && len(c.Session.UserId) > 0 { + if c.Err == nil && h.isUserActivity && token != "" && len(c.Session.UserId) > 0 { go func() { - if err := (<-Srv.Store.User().UpdateUserAndSessionActivity(c.Session.UserId, sessionId, model.GetMillis())).Err; err != nil { - l4g.Error("Failed to update LastActivityAt for user_id=%v and session_id=%v, err=%v", c.Session.UserId, sessionId, err) + if err := (<-Srv.Store.User().UpdateUserAndSessionActivity(c.Session.UserId, c.Session.Id, model.GetMillis())).Err; err != nil { + l4g.Error("Failed to update LastActivityAt for user_id=%v and session_id=%v, err=%v", c.Session.UserId, c.Session.Id, err) } }() } @@ -197,7 +210,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (c *Context) LogAudit(extraInfo string) { - audit := &model.Audit{UserId: c.Session.UserId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.AltId} + audit := &model.Audit{UserId: c.Session.UserId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.Id} if r := <-Srv.Store.Audit().Save(audit); r.Err != nil { c.LogError(r.Err) } @@ -209,7 +222,7 @@ func (c *Context) LogAuditWithUserId(userId, extraInfo string) { extraInfo = strings.TrimSpace(extraInfo + " session_user=" + c.Session.UserId) } - audit := &model.Audit{UserId: userId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.AltId} + audit := &model.Audit{UserId: userId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.Id} if r := <-Srv.Store.Audit().Save(audit); r.Err != nil { c.LogError(r.Err) } @@ -315,7 +328,7 @@ func (c *Context) IsTeamAdmin(userId string) bool { func (c *Context) RemoveSessionCookie(w http.ResponseWriter) { - sessionCache.Remove(c.Session.Id) + sessionCache.Remove(c.Session.Token) cookie := &http.Cookie{ Name: model.SESSION_TOKEN, @@ -471,3 +484,7 @@ func Handle404(w http.ResponseWriter, r *http.Request) { l4g.Error("%v: code=404 ip=%v", r.URL.Path, GetIpAddress(r)) RenderWebError(err, w, r) } + +func AddSessionToCache(session *model.Session) { + sessionCache.Add(session.Token, session) +} diff --git a/api/oauth.go b/api/oauth.go new file mode 100644 index 000000000..26c3c5da8 --- /dev/null +++ b/api/oauth.go @@ -0,0 +1,165 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package api + +import ( + l4g "code.google.com/p/log4go" + "fmt" + "github.com/gorilla/mux" + "github.com/mattermost/platform/model" + "github.com/mattermost/platform/utils" + "net/http" + "net/url" +) + +func InitOAuth(r *mux.Router) { + l4g.Debug("Initializing oauth api routes") + + sr := r.PathPrefix("/oauth").Subrouter() + + sr.Handle("/register", ApiUserRequired(registerOAuthApp)).Methods("POST") + sr.Handle("/allow", ApiUserRequired(allowOAuth)).Methods("GET") +} + +func registerOAuthApp(c *Context, w http.ResponseWriter, r *http.Request) { + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + c.Err = model.NewAppError("registerOAuthApp", "The system admin has turned off OAuth service providing.", "") + c.Err.StatusCode = http.StatusNotImplemented + return + } + + app := model.OAuthAppFromJson(r.Body) + + if app == nil { + c.SetInvalidParam("registerOAuthApp", "app") + return + } + + secret := model.NewId() + + app.ClientSecret = secret + app.CreatorId = c.Session.UserId + + if result := <-Srv.Store.OAuth().SaveApp(app); result.Err != nil { + c.Err = result.Err + return + } else { + app = result.Data.(*model.OAuthApp) + app.ClientSecret = secret + + c.LogAudit("client_id=" + app.Id) + + w.Write([]byte(app.ToJson())) + return + } + +} + +func allowOAuth(c *Context, w http.ResponseWriter, r *http.Request) { + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + c.Err = model.NewAppError("allowOAuth", "The system admin has turned off OAuth service providing.", "") + c.Err.StatusCode = http.StatusNotImplemented + return + } + + c.LogAudit("attempt") + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + responseData := map[string]string{} + + responseType := r.URL.Query().Get("response_type") + if len(responseType) == 0 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Bad response_type", "") + return + } + + clientId := r.URL.Query().Get("client_id") + if len(clientId) != 26 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Bad client_id", "") + return + } + + redirectUri := r.URL.Query().Get("redirect_uri") + if len(redirectUri) == 0 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Missing or bad redirect_uri", "") + return + } + + scope := r.URL.Query().Get("scope") + state := r.URL.Query().Get("state") + + var app *model.OAuthApp + if result := <-Srv.Store.OAuth().GetApp(clientId); result.Err != nil { + c.Err = model.NewAppError("allowOAuth", "server_error: Error accessing the database", "") + return + } else { + app = result.Data.(*model.OAuthApp) + } + + if !app.IsValidRedirectURL(redirectUri) { + c.LogAudit("fail - redirect_uri did not match registered callback") + c.Err = model.NewAppError("allowOAuth", "invalid_request: Supplied redirect_uri did not match registered callback_url", "") + return + } + + if responseType != model.AUTHCODE_RESPONSE_TYPE { + responseData["redirect"] = redirectUri + "?error=unsupported_response_type&state=" + state + w.Write([]byte(model.MapToJson(responseData))) + return + } + + authData := &model.AuthData{UserId: c.Session.UserId, ClientId: clientId, CreateAt: model.GetMillis(), RedirectUri: redirectUri, State: state, Scope: scope} + authData.Code = model.HashPassword(fmt.Sprintf("%v:%v:%v:%v", clientId, redirectUri, authData.CreateAt, c.Session.UserId)) + + if result := <-Srv.Store.OAuth().SaveAuthData(authData); result.Err != nil { + responseData["redirect"] = redirectUri + "?error=server_error&state=" + state + w.Write([]byte(model.MapToJson(responseData))) + return + } + + c.LogAudit("success") + + responseData["redirect"] = redirectUri + "?code=" + url.QueryEscape(authData.Code) + "&state=" + url.QueryEscape(authData.State) + + w.Write([]byte(model.MapToJson(responseData))) +} + +func RevokeAccessToken(token string) *model.AppError { + + schan := Srv.Store.Session().Remove(token) + sessionCache.Remove(token) + + var accessData *model.AccessData + if result := <-Srv.Store.OAuth().GetAccessData(token); result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error getting access token from DB before deletion", "") + } else { + accessData = result.Data.(*model.AccessData) + } + + tchan := Srv.Store.OAuth().RemoveAccessData(token) + cchan := Srv.Store.OAuth().RemoveAuthData(accessData.AuthCode) + + if result := <-tchan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting access token from DB", "") + } + + if result := <-cchan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting authorization code from DB", "") + } + + if result := <-schan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting session from DB", "") + } + + return nil +} + +func GetAuthData(code string) *model.AuthData { + if result := <-Srv.Store.OAuth().GetAuthData(code); result.Err != nil { + l4g.Error("Couldn't find auth code for code=%s", code) + return nil + } else { + return result.Data.(*model.AuthData) + } +} diff --git a/api/oauth_test.go b/api/oauth_test.go new file mode 100644 index 000000000..18db49bc5 --- /dev/null +++ b/api/oauth_test.go @@ -0,0 +1,157 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package api + +import ( + "github.com/mattermost/platform/model" + "github.com/mattermost/platform/store" + "github.com/mattermost/platform/utils" + "net/url" + "strings" + "testing" +) + +func TestRegisterApp(t *testing.T) { + Setup() + + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} + rteam, _ := Client.CreateTeam(&team) + + user := model.User{TeamId: rteam.Data.(*model.Team).Id, Email: strings.ToLower(model.NewId()) + "corey@test.com", Password: "pwd"} + ruser := Client.Must(Client.CreateUser(&user, "")).Data.(*model.User) + store.Must(Srv.Store.User().VerifyEmail(ruser.Id)) + + app := &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("should have failed - oauth providing turned off") + } + + } else { + + Client.Logout() + + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("not logged in - should have failed") + } + + Client.Must(Client.LoginById(ruser.Id, "pwd")) + + if result, err := Client.RegisterApp(app); err != nil { + t.Fatal(err) + } else { + rapp := result.Data.(*model.OAuthApp) + if len(rapp.Id) != 26 { + t.Fatal("clientid didn't return properly") + } + if len(rapp.ClientSecret) != 26 { + t.Fatal("client secret didn't return properly") + } + } + + app = &model.OAuthApp{Name: "", Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing name - should have failed") + } + + app = &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing homepage - should have failed") + } + + app = &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing callback url - should have failed") + } + } +} + +func TestAllowOAuth(t *testing.T) { + Setup() + + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} + rteam, _ := Client.CreateTeam(&team) + + user := model.User{TeamId: rteam.Data.(*model.Team).Id, Email: strings.ToLower(model.NewId()) + "corey@test.com", Password: "pwd"} + ruser := Client.Must(Client.CreateUser(&user, "")).Data.(*model.User) + store.Must(Srv.Store.User().VerifyEmail(ruser.Id)) + + app := &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + + Client.Must(Client.LoginById(ruser.Id, "pwd")) + + state := "123" + + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "12345678901234567890123456", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - oauth service providing turned off") + } + } else { + app = Client.Must(Client.RegisterApp(app)).Data.(*model.OAuthApp) + + if result, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, app.CallbackUrls[0], "all", state); err != nil { + t.Fatal(err) + } else { + redirect := result.Data.(map[string]string)["redirect"] + if len(redirect) == 0 { + t.Fatal("redirect url should be set") + } + + ru, _ := url.Parse(redirect) + if ru == nil { + t.Fatal("redirect url unparseable") + } else { + if len(ru.Query().Get("code")) == 0 { + t.Fatal("authorization code not returned") + } + if ru.Query().Get("state") != state { + t.Fatal("returned state doesn't match") + } + } + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "", "all", state); err == nil { + t.Fatal("should have failed - no redirect_url given") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "", "", state); err == nil { + t.Fatal("should have failed - no redirect_url given") + } + + if result, err := Client.AllowOAuth("junk", app.Id, app.CallbackUrls[0], "all", state); err != nil { + t.Fatal(err) + } else { + redirect := result.Data.(map[string]string)["redirect"] + if len(redirect) == 0 { + t.Fatal("redirect url should be set") + } + + ru, _ := url.Parse(redirect) + if ru == nil { + t.Fatal("redirect url unparseable") + } else { + if ru.Query().Get("error") != "unsupported_response_type" { + t.Fatal("wrong error returned") + } + if ru.Query().Get("state") != state { + t.Fatal("returned state doesn't match") + } + } + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - empty client id") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "junk", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - bad client id") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "https://somewhereelse.com", "all", state); err == nil { + t.Fatal("should have failed - redirect uri host does not match app host") + } + } +} diff --git a/api/post_test.go b/api/post_test.go index 85d92de3a..4cccfd62a 100644 --- a/api/post_test.go +++ b/api/post_test.go @@ -118,7 +118,7 @@ func TestCreatePost(t *testing.T) { t.Fatal("Should have been forbidden") } - if _, err = Client.DoPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { + if _, err = Client.DoApiPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { t.Fatal("should have been an error") } } @@ -203,7 +203,7 @@ func TestCreateValetPost(t *testing.T) { t.Fatal("Should have been forbidden") } - if _, err = Client.DoPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { + if _, err = Client.DoApiPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { t.Fatal("should have been an error") } } else { diff --git a/api/team_test.go b/api/team_test.go index 2723eff57..4f1b9e5f0 100644 --- a/api/team_test.go +++ b/api/team_test.go @@ -103,7 +103,7 @@ func TestCreateTeam(t *testing.T) { } } - if _, err := Client.DoPost("/teams/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/teams/create", "garbage"); err == nil { t.Fatal("should have been an error") } } diff --git a/api/user.go b/api/user.go index cdd9a68be..b42d156ae 100644 --- a/api/user.go +++ b/api/user.go @@ -336,7 +336,7 @@ func Login(c *Context, w http.ResponseWriter, r *http.Request, user *model.User, return } - session := &model.Session{UserId: user.Id, TeamId: user.TeamId, Roles: user.Roles, DeviceId: deviceId} + session := &model.Session{UserId: user.Id, TeamId: user.TeamId, Roles: user.Roles, DeviceId: deviceId, IsOAuth: false} maxAge := model.SESSION_TIME_WEB_IN_SECS @@ -378,13 +378,13 @@ func Login(c *Context, w http.ResponseWriter, r *http.Request, user *model.User, return } else { session = result.Data.(*model.Session) - sessionCache.Add(session.Id, session) + AddSessionToCache(session) } - w.Header().Set(model.HEADER_TOKEN, session.Id) + w.Header().Set(model.HEADER_TOKEN, session.Token) sessionCookie := &http.Cookie{ Name: model.SESSION_TOKEN, - Value: session.Id, + Value: session.Token, Path: "/", MaxAge: maxAge, HttpOnly: true, @@ -430,25 +430,27 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) { func revokeSession(c *Context, w http.ResponseWriter, r *http.Request) { props := model.MapFromJson(r.Body) - altId := props["id"] + id := props["id"] - if result := <-Srv.Store.Session().GetSessions(c.Session.UserId); result.Err != nil { + if result := <-Srv.Store.Session().Get(id); result.Err != nil { c.Err = result.Err return } else { - sessions := result.Data.([]*model.Session) + session := result.Data.(*model.Session) - for _, session := range sessions { - if session.AltId == altId { - c.LogAudit("session_id=" + session.AltId) - sessionCache.Remove(session.Id) - if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { - c.Err = result.Err - return - } else { - w.Write([]byte(model.MapToJson(props))) - return - } + c.LogAudit("session_id=" + session.Id) + + if session.IsOAuth { + RevokeAccessToken(session.Token) + } else { + sessionCache.Remove(session.Token) + + if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { + c.Err = result.Err + return + } else { + w.Write([]byte(model.MapToJson(props))) + return } } } @@ -462,8 +464,8 @@ func RevokeAllSession(c *Context, userId string) { sessions := result.Data.([]*model.Session) for _, session := range sessions { - c.LogAuditWithUserId(userId, "session_id="+session.AltId) - sessionCache.Remove(session.Id) + c.LogAuditWithUserId(userId, "session_id="+session.Id) + sessionCache.Remove(session.Token) if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { c.Err = result.Err return diff --git a/api/user_test.go b/api/user_test.go index fe5a4a27f..986365bd0 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -68,7 +68,7 @@ func TestCreateUser(t *testing.T) { } } - if _, err := Client.DoPost("/users/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/users/create", "garbage"); err == nil { t.Fatal("should have been an error") } } @@ -190,11 +190,11 @@ func TestSessions(t *testing.T) { for _, session := range sessions { if session.DeviceId == deviceId { - otherSession = session.AltId + otherSession = session.Id } - if len(session.Id) != 0 { - t.Fatal("shouldn't return sessions") + if len(session.Token) != 0 { + t.Fatal("shouldn't return session tokens") } } @@ -212,11 +212,6 @@ func TestSessions(t *testing.T) { if len(sessions2) != 1 { t.Fatal("invalid number of sessions") } - - if _, err := Client.RevokeSession(otherSession); err != nil { - t.Fatal(err) - } - } func TestGetUser(t *testing.T) { @@ -355,7 +350,7 @@ func TestUserCreateImage(t *testing.T) { Client.LoginByEmail(team.Name, user.Email, "pwd") - Client.DoGet("/users/"+user.Id+"/image", "", "") + Client.DoApiGet("/users/"+user.Id+"/image", "", "") if utils.IsS3Configured() && !utils.Cfg.ServiceSettings.UseLocalStorage { var auth aws.Auth @@ -453,7 +448,7 @@ func TestUserUploadProfileImage(t *testing.T) { t.Fatal(upErr) } - Client.DoGet("/users/"+user.Id+"/image", "", "") + Client.DoApiGet("/users/"+user.Id+"/image", "", "") if utils.IsS3Configured() && !utils.Cfg.ServiceSettings.UseLocalStorage { var auth aws.Auth -- cgit v1.2.3-1-g7c22