diff options
author | George Goldberg <george@gberg.me> | 2018-03-02 15:55:03 +0000 |
---|---|---|
committer | George Goldberg <george@gberg.me> | 2018-03-02 15:55:03 +0000 |
commit | 901acc9703ae58b625b44e7abfd02333b9bab951 (patch) | |
tree | 1a8fc17a85544bc7b8064874923e2fe6e3f44354 /utils | |
parent | 21afaf4bedcad578d4f876bb315d1072ccd296e6 (diff) | |
parent | 2b3b6051d265edf131d006b2eb14f55284faf1e5 (diff) | |
download | chat-901acc9703ae58b625b44e7abfd02333b9bab951.tar.gz chat-901acc9703ae58b625b44e7abfd02333b9bab951.tar.bz2 chat-901acc9703ae58b625b44e7abfd02333b9bab951.zip |
Merge branch 'master' into advanced-permissions-phase-1
Diffstat (limited to 'utils')
-rw-r--r-- | utils/api.go | 2 | ||||
-rw-r--r-- | utils/config.go | 11 | ||||
-rw-r--r-- | utils/file_backend_s3.go | 44 | ||||
-rw-r--r-- | utils/file_backend_s3_test.go | 32 | ||||
-rw-r--r-- | utils/file_backend_test.go | 23 | ||||
-rw-r--r-- | utils/lru.go | 113 | ||||
-rw-r--r-- | utils/lru_test.go | 33 | ||||
-rw-r--r-- | utils/mail.go | 58 | ||||
-rw-r--r-- | utils/mail_test.go | 68 |
9 files changed, 253 insertions, 131 deletions
diff --git a/utils/api.go b/utils/api.go index 51524074d..0f2640829 100644 --- a/utils/api.go +++ b/utils/api.go @@ -52,7 +52,7 @@ func RenderWebError(w http.ResponseWriter, r *http.Request, status int, params u http.Error(w, "", http.StatusInternalServerError) return } - destination := strings.TrimRight(GetSiteURL(), "/") + "/error?" + queryString + "&s=" + base64.URLEncoding.EncodeToString(signature) + destination := "/error?" + queryString + "&s=" + base64.URLEncoding.EncodeToString(signature) if status >= 300 && status < 400 { http.Redirect(w, r, destination, status) diff --git a/utils/config.go b/utils/config.go index 10ada1728..b28cf918d 100644 --- a/utils/config.go +++ b/utils/config.go @@ -34,15 +34,6 @@ const ( ) var originalDisableDebugLvl l4g.Level = l4g.DEBUG -var siteURL = "" - -func GetSiteURL() string { - return siteURL -} - -func SetSiteURL(url string) { - siteURL = strings.TrimRight(url, "/") -} // FindConfigFile attempts to find an existing configuration file. fileName can be an absolute or // relative path or name such as "/opt/mattermost/config.json" or simply "config.json". An empty @@ -353,8 +344,10 @@ func GenerateClientConfig(c *model.Config, diagnosticId string, license *model.L props["BuildEnterpriseReady"] = model.BuildEnterpriseReady props["SiteURL"] = strings.TrimRight(*c.ServiceSettings.SiteURL, "/") + props["WebsocketURL"] = strings.TrimRight(*c.ServiceSettings.WebsocketURL, "/") props["SiteName"] = c.TeamSettings.SiteName props["EnableTeamCreation"] = strconv.FormatBool(*c.TeamSettings.EnableTeamCreation) + props["EnableAPIv3"] = strconv.FormatBool(*c.ServiceSettings.EnableAPIv3) props["EnableUserCreation"] = strconv.FormatBool(c.TeamSettings.EnableUserCreation) props["EnableOpenServer"] = strconv.FormatBool(*c.TeamSettings.EnableOpenServer) props["RestrictDirectMessage"] = *c.TeamSettings.RestrictDirectMessage diff --git a/utils/file_backend_s3.go b/utils/file_backend_s3.go index 7ef150851..b0601bc8a 100644 --- a/utils/file_backend_s3.go +++ b/utils/file_backend_s3.go @@ -37,7 +37,10 @@ type S3FileBackend struct { // disables automatic region lookup. func (b *S3FileBackend) s3New() (*s3.Client, error) { var creds *credentials.Credentials - if b.signV2 { + + if b.accessKey == "" && b.secretKey == "" { + creds = credentials.NewIAM("") + } else if b.signV2 { creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV2) } else { creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV4) @@ -138,17 +141,15 @@ func (b *S3FileBackend) WriteFile(f []byte, path string) *model.AppError { return model.NewAppError("WriteFile", "api.file.write_file.s3.app_error", nil, err.Error(), http.StatusInternalServerError) } - options := s3.PutObjectOptions{} - if b.encrypt { - options.UserMetadata["x-amz-server-side-encryption"] = "AES256" - } - + var contentType string if ext := filepath.Ext(path); model.IsFileExtImage(ext) { - options.ContentType = model.GetImageMimeType(ext) + contentType = model.GetImageMimeType(ext) } else { - options.ContentType = "binary/octet-stream" + contentType = "binary/octet-stream" } + options := s3PutOptions(b.encrypt, contentType) + if _, err = s3Clnt.PutObject(b.bucket, path, bytes.NewReader(f), -1, options); err != nil { return model.NewAppError("WriteFile", "api.file.write_file.s3.app_error", nil, err.Error(), http.StatusInternalServerError) } @@ -230,8 +231,35 @@ func (b *S3FileBackend) RemoveDirectory(path string) *model.AppError { return nil } +func s3PutOptions(encrypt bool, contentType string) s3.PutObjectOptions { + options := s3.PutObjectOptions{} + if encrypt { + options.UserMetadata = make(map[string]string) + options.UserMetadata["x-amz-server-side-encryption"] = "AES256" + } + options.ContentType = contentType + + return options +} + func s3CopyMetadata(encrypt bool) map[string]string { metaData := make(map[string]string) metaData["x-amz-server-side-encryption"] = "AES256" return metaData } + +func CheckMandatoryS3Fields(settings *model.FileSettings) *model.AppError { + if len(settings.AmazonS3Bucket) == 0 { + return model.NewAppError("S3File", "api.admin.test_s3.missing_s3_bucket", nil, "", http.StatusBadRequest) + } + + if len(settings.AmazonS3Endpoint) == 0 { + return model.NewAppError("S3File", "api.admin.test_s3.missing_s3_endpoint", nil, "", http.StatusBadRequest) + } + + if len(settings.AmazonS3Region) == 0 { + return model.NewAppError("S3File", "api.admin.test_s3.missing_s3_region", nil, "", http.StatusBadRequest) + } + + return nil +} diff --git a/utils/file_backend_s3_test.go b/utils/file_backend_s3_test.go new file mode 100644 index 000000000..ff42a4d19 --- /dev/null +++ b/utils/file_backend_s3_test.go @@ -0,0 +1,32 @@ +// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package utils + +import ( + "testing" + + "github.com/mattermost/mattermost-server/model" +) + +func TestCheckMandatoryS3Fields(t *testing.T) { + cfg := model.FileSettings{} + + err := CheckMandatoryS3Fields(&cfg) + if err == nil || err.Message != "api.admin.test_s3.missing_s3_bucket" { + t.Fatal("should've failed with missing s3 bucket") + } + + cfg.AmazonS3Bucket = "test-mm" + err = CheckMandatoryS3Fields(&cfg) + if err == nil || err.Message != "api.admin.test_s3.missing_s3_endpoint" { + t.Fatal("should've failed with missing s3 endpoint") + } + + cfg.AmazonS3Endpoint = "s3.newendpoint.com" + err = CheckMandatoryS3Fields(&cfg) + if err == nil || err.Message != "api.admin.test_s3.missing_s3_region" { + t.Fatal("should've failed with missing s3 region") + } + +} diff --git a/utils/file_backend_test.go b/utils/file_backend_test.go index 46f75574e..2b8e2a527 100644 --- a/utils/file_backend_test.go +++ b/utils/file_backend_test.go @@ -36,6 +36,14 @@ func TestLocalFileBackendTestSuite(t *testing.T) { } func TestS3FileBackendTestSuite(t *testing.T) { + runBackendTest(t, false) +} + +func TestS3FileBackendTestSuiteWithEncryption(t *testing.T) { + runBackendTest(t, true) +} + +func runBackendTest(t *testing.T, encrypt bool) { s3Host := os.Getenv("CI_HOST") if s3Host == "" { s3Host = "dockerhost" @@ -56,6 +64,7 @@ func TestS3FileBackendTestSuite(t *testing.T) { AmazonS3Bucket: model.MINIO_BUCKET, AmazonS3Endpoint: s3Endpoint, AmazonS3SSL: model.NewBool(false), + AmazonS3SSE: model.NewBool(encrypt), }, }) } @@ -86,6 +95,20 @@ func (s *FileBackendTestSuite) TestReadWriteFile() { s.EqualValues(readString, "test") } +func (s *FileBackendTestSuite) TestReadWriteFileImage() { + b := []byte("testimage") + path := "tests/" + model.NewId() + ".png" + + s.Nil(s.backend.WriteFile(b, path)) + defer s.backend.RemoveFile(path) + + read, err := s.backend.ReadFile(path) + s.Nil(err) + + readString := string(read) + s.EqualValues(readString, "testimage") +} + func (s *FileBackendTestSuite) TestCopyFile() { b := []byte("test") path1 := "tests/" + model.NewId() diff --git a/utils/lru.go b/utils/lru.go index 576331563..8e896a6dc 100644 --- a/utils/lru.go +++ b/utils/lru.go @@ -9,15 +9,14 @@ package utils import ( "container/list" - "errors" "sync" "time" ) // Caching Interface type ObjectCache interface { - AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool - AddWithDefaultExpires(key, value interface{}) bool + AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) + AddWithDefaultExpires(key, value interface{}) Purge() Get(key interface{}) (value interface{}, ok bool) Remove(key interface{}) @@ -32,10 +31,11 @@ type Cache struct { evictList *list.List items map[interface{}]*list.Element lock sync.RWMutex - onEvicted func(key interface{}, value interface{}) name string defaultExpiry int64 invalidateClusterEvent string + currentGeneration int64 + len int } // entry is used to hold a value in the evictList @@ -43,25 +43,16 @@ type entry struct { key interface{} value interface{} expireAtSecs int64 + generation int64 } // New creates an LRU of the given size func NewLru(size int) *Cache { - cache, _ := NewLruWithEvict(size, nil) - return cache -} - -func NewLruWithEvict(size int, onEvicted func(key interface{}, value interface{})) (*Cache, error) { - if size <= 0 { - return nil, errors.New(T("utils.iru.with_evict")) - } - c := &Cache{ + return &Cache{ size: size, evictList: list.New(), items: make(map[interface{}]*list.Element, size), - onEvicted: onEvicted, } - return c, nil } func NewLruWithParams(size int, name string, defaultExpiry int64, invalidateClusterEvent string) *Cache { @@ -77,26 +68,19 @@ func (c *Cache) Purge() { c.lock.Lock() defer c.lock.Unlock() - if c.onEvicted != nil { - for k, v := range c.items { - c.onEvicted(k, v.Value) - } - } - - c.evictList = list.New() - c.items = make(map[interface{}]*list.Element, c.size) + c.len = 0 + c.currentGeneration++ } -func (c *Cache) Add(key, value interface{}) bool { - return c.AddWithExpiresInSecs(key, value, 0) +func (c *Cache) Add(key, value interface{}) { + c.AddWithExpiresInSecs(key, value, 0) } -func (c *Cache) AddWithDefaultExpires(key, value interface{}) bool { - return c.AddWithExpiresInSecs(key, value, c.defaultExpiry) +func (c *Cache) AddWithDefaultExpires(key, value interface{}) { + c.AddWithExpiresInSecs(key, value, c.defaultExpiry) } -// Add adds a value to the cache. Returns true if an eviction occurred. -func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool { +func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) { c.lock.Lock() defer c.lock.Unlock() @@ -107,45 +91,46 @@ func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) // Check for existing item if ent, ok := c.items[key]; ok { c.evictList.MoveToFront(ent) - ent.Value.(*entry).value = value - ent.Value.(*entry).expireAtSecs = expireAtSecs - return false + e := ent.Value.(*entry) + e.value = value + e.expireAtSecs = expireAtSecs + if e.generation != c.currentGeneration { + e.generation = c.currentGeneration + c.len++ + } + return } // Add new item - ent := &entry{key, value, expireAtSecs} + ent := &entry{key, value, expireAtSecs, c.currentGeneration} entry := c.evictList.PushFront(ent) c.items[key] = entry + c.len++ - evict := c.evictList.Len() > c.size - // Verify size not exceeded - if evict { - c.removeOldest() + if c.evictList.Len() > c.size { + c.removeElement(c.evictList.Back()) } - return evict } -// Get looks up a key's value from the cache. func (c *Cache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() if ent, ok := c.items[key]; ok { + e := ent.Value.(*entry) - if ent.Value.(*entry).expireAtSecs > 0 { - if (time.Now().UnixNano() / int64(time.Second)) > ent.Value.(*entry).expireAtSecs { - c.removeElement(ent) - return nil, false - } + if e.generation != c.currentGeneration || (e.expireAtSecs > 0 && (time.Now().UnixNano()/int64(time.Second)) > e.expireAtSecs) { + c.removeElement(ent) + return nil, false } c.evictList.MoveToFront(ent) return ent.Value.(*entry).value, true } - return + + return nil, false } -// Remove removes the provided key from the cache. func (c *Cache) Remove(key interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -155,25 +140,19 @@ func (c *Cache) Remove(key interface{}) { } } -// RemoveOldest removes the oldest item from the cache. -func (c *Cache) RemoveOldest() { - c.lock.Lock() - defer c.lock.Unlock() - c.removeOldest() -} - // Keys returns a slice of the keys in the cache, from oldest to newest. func (c *Cache) Keys() []interface{} { c.lock.RLock() defer c.lock.RUnlock() - keys := make([]interface{}, len(c.items)) - ent := c.evictList.Back() + keys := make([]interface{}, c.len) i := 0 - for ent != nil { - keys[i] = ent.Value.(*entry).key - ent = ent.Prev() - i++ + for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() { + e := ent.Value.(*entry) + if e.generation == c.currentGeneration { + keys[i] = e.key + i++ + } } return keys @@ -183,7 +162,7 @@ func (c *Cache) Keys() []interface{} { func (c *Cache) Len() int { c.lock.RLock() defer c.lock.RUnlock() - return c.evictList.Len() + return c.len } func (c *Cache) Name() string { @@ -194,20 +173,12 @@ func (c *Cache) GetInvalidateClusterEvent() string { return c.invalidateClusterEvent } -// removeOldest removes the oldest item from the cache. -func (c *Cache) removeOldest() { - ent := c.evictList.Back() - if ent != nil { - c.removeElement(ent) - } -} - // removeElement is used to remove a given list element from the cache func (c *Cache) removeElement(e *list.Element) { c.evictList.Remove(e) kv := e.Value.(*entry) - delete(c.items, kv.key) - if c.onEvicted != nil { - c.onEvicted(kv.key, kv.value) + if kv.generation == c.currentGeneration { + c.len-- } + delete(c.items, kv.key) } diff --git a/utils/lru_test.go b/utils/lru_test.go index 987163cd3..4312515b9 100644 --- a/utils/lru_test.go +++ b/utils/lru_test.go @@ -11,14 +11,7 @@ import "testing" import "time" func TestLRU(t *testing.T) { - evictCounter := 0 - onEvicted := func(k interface{}, v interface{}) { - evictCounter += 1 - } - l, err := NewLruWithEvict(128, onEvicted) - if err != nil { - t.Fatalf("err: %v", err) - } + l := NewLru(128) for i := 0; i < 256; i++ { l.Add(i, i) @@ -27,10 +20,6 @@ func TestLRU(t *testing.T) { t.Fatalf("bad len: %v", l.Len()) } - if evictCounter != 128 { - t.Fatalf("bad evict count: %v", evictCounter) - } - for i, k := range l.Keys() { if v, ok := l.Get(k); !ok || v != k || v != i+128 { t.Fatalf("bad key: %v", k) @@ -73,26 +62,6 @@ func TestLRU(t *testing.T) { } } -// test that Add return true/false if an eviction occurred -func TestLRUAdd(t *testing.T) { - evictCounter := 0 - onEvicted := func(k interface{}, v interface{}) { - evictCounter += 1 - } - - l, err := NewLruWithEvict(1, onEvicted) - if err != nil { - t.Fatalf("err: %v", err) - } - - if l.Add(1, 1) || evictCounter != 0 { - t.Errorf("should not have an eviction") - } - if !l.Add(2, 2) || evictCounter != 1 { - t.Errorf("should have an eviction") - } -} - func TestLRUExpire(t *testing.T) { l := NewLru(128) diff --git a/utils/mail.go b/utils/mail.go index 9023f7090..2bc0ce9e1 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -5,6 +5,8 @@ package utils import ( "crypto/tls" + "errors" + "io" "mime" "net" "net/mail" @@ -15,8 +17,6 @@ import ( "net/http" - "io" - l4g "github.com/alecthomas/log4go" "github.com/mattermost/html2text" "github.com/mattermost/mattermost-server/model" @@ -26,6 +26,56 @@ func encodeRFC2047Word(s string) string { return mime.BEncoding.Encode("utf-8", s) } +type authChooser struct { + smtp.Auth + Config *model.Config +} + +func (a *authChooser) Start(server *smtp.ServerInfo) (string, []byte, error) { + a.Auth = LoginAuth(a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort) + for _, method := range server.Auth { + if method == "PLAIN" { + a.Auth = smtp.PlainAuth("", a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort) + break + } + } + return a.Auth.Start(server) +} + +type loginAuth struct { + username, password, host string +} + +func LoginAuth(username, password, host string) smtp.Auth { + return &loginAuth{username, password, host} +} + +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + return "", nil, errors.New("unencrypted connection") + } + + if server.Name != a.host { + return "", nil, errors.New("wrong host name") + } + + return "LOGIN", []byte{}, nil +} + +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(a.username), nil + case "Password:": + return []byte(a.password), nil + default: + return nil, errors.New("Unkown fromServer") + } + } + return nil, nil +} + func connectToSMTPServer(config *model.Config) (net.Conn, *model.AppError) { var conn net.Conn var err error @@ -75,9 +125,7 @@ func newSMTPClient(conn net.Conn, config *model.Config) (*smtp.Client, *model.Ap } if *config.EmailSettings.EnableSMTPAuth { - auth := smtp.PlainAuth("", config.EmailSettings.SMTPUsername, config.EmailSettings.SMTPPassword, config.EmailSettings.SMTPServer+":"+config.EmailSettings.SMTPPort) - - if err = c.Auth(auth); err != nil { + if err = c.Auth(&authChooser{Config: config}); err != nil { return nil, model.NewAppError("SendMail", "utils.mail.new_client.auth.app_error", nil, err.Error(), http.StatusInternalServerError) } } diff --git a/utils/mail_test.go b/utils/mail_test.go index 67d108d45..31a4f8996 100644 --- a/utils/mail_test.go +++ b/utils/mail_test.go @@ -7,12 +7,9 @@ import ( "strings" "testing" - "net/mail" - - "fmt" + "net/smtp" "github.com/mattermost/mattermost-server/model" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -82,7 +79,7 @@ func TestSendMailUsingConfig(t *testing.T) { } } -func TestSendMailUsingConfigAdvanced(t *testing.T) { +/*func TestSendMailUsingConfigAdvanced(t *testing.T) { cfg, _, err := LoadConfig("config.json") require.Nil(t, err) T = GetUserTranslations("en") @@ -174,4 +171,65 @@ func TestSendMailUsingConfigAdvanced(t *testing.T) { } } } +}*/ + +func TestAuthMethods(t *testing.T) { + config := model.Config{ + EmailSettings: model.EmailSettings{ + EnableSMTPAuth: model.NewBool(false), + SMTPUsername: "test", + SMTPPassword: "fakepass", + SMTPServer: "fakeserver", + SMTPPort: "25", + }, + } + + auth := &authChooser{Config: &config} + tests := []struct { + desc string + server *smtp.ServerInfo + err string + }{ + { + desc: "auth PLAIN success", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: true}, + }, + { + desc: "auth PLAIN unencrypted connection fail", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: false}, + err: "unencrypted connection", + }, + { + desc: "auth PLAIN wrong host name", + server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"PLAIN"}, TLS: true}, + err: "wrong host name", + }, + { + desc: "auth LOGIN success", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: true}, + }, + { + desc: "auth LOGIN unencrypted connection fail", + server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"LOGIN"}, TLS: true}, + err: "wrong host name", + }, + { + desc: "auth LOGIN wrong host name", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: false}, + err: "unencrypted connection", + }, + } + + for i, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, _, err := auth.Start(test.server) + got := "" + if err != nil { + got = err.Error() + } + if got != test.err { + t.Errorf("%d. got error = %q; want %q", i, got, test.err) + } + }) + } } |