From ccb034382850b7e8ea924a4559e47ef44203155c Mon Sep 17 00:00:00 2001 From: Joram Wilander Date: Fri, 3 Feb 2017 09:30:57 -0500 Subject: Implement POST /users/ids endpoint for APIv4 (#5274) --- api/user.go | 8 ++--- api4/user.go | 20 ++++++++++++ api4/user_test.go | 32 ++++++++++++++++++ app/user.go | 10 ++++-- model/client4.go | 10 ++++++ store/sql_user_store.go | 12 +++---- store/sql_user_store_test.go | 78 ++++++++++++++++++++++++++++++++++---------- 7 files changed, 139 insertions(+), 31 deletions(-) diff --git a/api/user.go b/api/user.go index bfe2db14e..6f40388b2 100644 --- a/api/user.go +++ b/api/user.go @@ -1571,15 +1571,15 @@ func getProfilesByIds(c *Context, w http.ResponseWriter, r *http.Request) { return } - if profiles, err := app.GetUsersByIds(userIds); err != nil { + if profiles, err := app.GetUsersByIds(userIds, c.IsSystemAdmin()); err != nil { c.Err = err return } else { + profileMap := map[string]*model.User{} for _, p := range profiles { - sanitizeProfile(c, p) + profileMap[p.Id] = p } - - w.Write([]byte(model.UserMapToJson(profiles))) + w.Write([]byte(model.UserMapToJson(profileMap))) } } diff --git a/api4/user.go b/api4/user.go index f68d01d33..19d3446fb 100644 --- a/api4/user.go +++ b/api4/user.go @@ -16,6 +16,8 @@ func InitUser() { l4g.Debug(utils.T("api.user.init.debug")) BaseRoutes.Users.Handle("", ApiHandler(createUser)).Methods("POST") + BaseRoutes.Users.Handle("/ids", ApiSessionRequired(getUsersByIds)).Methods("POST") + BaseRoutes.User.Handle("", ApiSessionRequired(getUser)).Methods("GET") BaseRoutes.User.Handle("", ApiSessionRequired(updateUser)).Methods("PUT") BaseRoutes.User.Handle("/roles", ApiSessionRequired(updateUserRoles)).Methods("PUT") @@ -84,6 +86,24 @@ func getUser(c *Context, w http.ResponseWriter, r *http.Request) { } } +func getUsersByIds(c *Context, w http.ResponseWriter, r *http.Request) { + userIds := model.ArrayFromJson(r.Body) + + if len(userIds) == 0 { + c.SetInvalidParam("user_ids") + return + } + + // No permission check required + + if users, err := app.GetUsersByIds(userIds, c.IsSystemAdmin()); err != nil { + c.Err = err + return + } else { + w.Write([]byte(model.UserListToJson(users))) + } +} + func updateUser(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireUserId() if c.Err != nil { diff --git a/api4/user_test.go b/api4/user_test.go index 501bb38e3..54aae4e49 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -131,6 +131,38 @@ func TestGetUser(t *testing.T) { } } +func TestGetUsersByIds(t *testing.T) { + th := Setup().InitBasic() + Client := th.Client + + users, resp := Client.GetUsersByIds([]string{th.BasicUser.Id}) + CheckNoError(t, resp) + + if users[0].Id != th.BasicUser.Id { + t.Fatal("returned wrong user") + } + CheckUserSanitization(t, users[0]) + + _, resp = Client.GetUsersByIds([]string{}) + CheckBadRequestStatus(t, resp) + + users, resp = Client.GetUsersByIds([]string{"junk"}) + CheckNoError(t, resp) + if len(users) > 0 { + t.Fatal("no users should be returned") + } + + users, resp = Client.GetUsersByIds([]string{"junk", th.BasicUser.Id}) + CheckNoError(t, resp) + if len(users) != 1 { + t.Fatal("1 user should be returned") + } + + Client.Logout() + _, resp = Client.GetUsersByIds([]string{th.BasicUser.Id}) + CheckUnauthorizedStatus(t, resp) +} + func TestUpdateUser(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() defer TearDown() diff --git a/app/user.go b/app/user.go index 5422d0b67..f9137b1e9 100644 --- a/app/user.go +++ b/app/user.go @@ -417,11 +417,17 @@ func GetUsersNotInChannel(teamId string, channelId string, offset int, limit int } } -func GetUsersByIds(userIds []string) (map[string]*model.User, *model.AppError) { +func GetUsersByIds(userIds []string, asAdmin bool) ([]*model.User, *model.AppError) { if result := <-Srv.Store.User().GetProfileByIds(userIds, true); result.Err != nil { return nil, result.Err } else { - return result.Data.(map[string]*model.User), nil + users := result.Data.([]*model.User) + + for _, u := range users { + SanitizeProfile(u, asAdmin) + } + + return users, nil } } diff --git a/model/client4.go b/model/client4.go index c82f5ce0e..fb314e26d 100644 --- a/model/client4.go +++ b/model/client4.go @@ -210,6 +210,16 @@ func (c *Client4) GetUser(userId, etag string) (*User, *Response) { } } +// GetUsersByIds returns a list of users based on the provided user ids. +func (c *Client4) GetUsersByIds(userIds []string) ([]*User, *Response) { + if r, err := c.DoApiPost(c.GetUsersRoute()+"/ids", ArrayToJson(userIds)); err != nil { + return nil, &Response{StatusCode: r.StatusCode, Error: err} + } else { + defer closeBody(r) + return UserListFromJson(r.Body), BuildResponse(r) + } +} + // UpdateUser updates a user in the system based on the provided user struct. func (c *Client4) UpdateUser(user *User) (*User, *Response) { if r, err := c.DoApiPut(c.GetUserRoute(user.Id), user.ToJson()); err != nil { diff --git a/store/sql_user_store.go b/store/sql_user_store.go index 02cbb3fbf..827c5a064 100644 --- a/store/sql_user_store.go +++ b/store/sql_user_store.go @@ -808,8 +808,7 @@ func (us SqlUserStore) GetProfileByIds(userIds []string, allowFromCache bool) St result := StoreResult{} metrics := einterfaces.GetMetricsInterface() - var users []*model.User - userMap := make(map[string]*model.User) + users := []*model.User{} props := make(map[string]interface{}) idQuery := "" remainingUserIds := make([]string, 0) @@ -818,13 +817,13 @@ func (us SqlUserStore) GetProfileByIds(userIds []string, allowFromCache bool) St for _, userId := range userIds { if cacheItem, ok := profileByIdsCache.Get(userId); ok { u := cacheItem.(*model.User) - userMap[u.Id] = u + users = append(users, u) } else { remainingUserIds = append(remainingUserIds, userId) } } if metrics != nil { - metrics.AddMemCacheHitCounter("Profile By Ids", float64(len(userMap))) + metrics.AddMemCacheHitCounter("Profile By Ids", float64(len(users))) metrics.AddMemCacheMissCounter("Profile By Ids", float64(len(remainingUserIds))) } } else { @@ -836,7 +835,7 @@ func (us SqlUserStore) GetProfileByIds(userIds []string, allowFromCache bool) St // If everything came from the cache then just return if len(remainingUserIds) == 0 { - result.Data = userMap + result.Data = users storeChannel <- result close(storeChannel) return @@ -859,11 +858,10 @@ func (us SqlUserStore) GetProfileByIds(userIds []string, allowFromCache bool) St u.Password = "" u.AuthData = new(string) *u.AuthData = "" - userMap[u.Id] = u profileByIdsCache.AddWithExpiresInSecs(u.Id, u, PROFILE_BY_IDS_CACHE_SEC) } - result.Data = userMap + result.Data = users } storeChannel <- result diff --git a/store/sql_user_store_test.go b/store/sql_user_store_test.go index fb04e95c9..449c6aa52 100644 --- a/store/sql_user_store_test.go +++ b/store/sql_user_store_test.go @@ -456,78 +456,120 @@ func TestUserStoreGetProfilesByIds(t *testing.T) { if r1 := <-store.User().GetProfileByIds([]string{u1.Id}, false); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 1 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } if r1 := <-store.User().GetProfileByIds([]string{u1.Id}, true); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 1 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } if r1 := <-store.User().GetProfileByIds([]string{u1.Id, u2.Id}, true); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 2 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } if r1 := <-store.User().GetProfileByIds([]string{u1.Id, u2.Id}, true); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 2 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } if r1 := <-store.User().GetProfileByIds([]string{u1.Id, u2.Id}, false); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 2 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } if r1 := <-store.User().GetProfileByIds([]string{u1.Id}, false); r1.Err != nil { t.Fatal(r1.Err) } else { - users := r1.Data.(map[string]*model.User) + users := r1.Data.([]*model.User) if len(users) != 1 { t.Fatal("invalid returned users") } - if users[u1.Id].Id != u1.Id { - t.Fatal("invalid returned user") + found := false + for _, u := range users { + if u.Id == u1.Id { + found = true + } + } + + if !found { + t.Fatal("missing user") } } -- cgit v1.2.3-1-g7c22