diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/config.go | 3 | ||||
-rw-r--r-- | utils/file_backend_s3.go | 18 | ||||
-rw-r--r-- | utils/file_backend_s3_test.go | 32 | ||||
-rw-r--r-- | utils/i18n.go | 5 | ||||
-rw-r--r-- | utils/log.go | 33 | ||||
-rw-r--r-- | utils/logger/log4go_json_writer.go | 30 | ||||
-rw-r--r-- | utils/logger/logger.go | 222 | ||||
-rw-r--r-- | utils/lru.go | 113 | ||||
-rw-r--r-- | utils/lru_test.go | 33 | ||||
-rw-r--r-- | utils/mail.go | 64 | ||||
-rw-r--r-- | utils/mail_test.go | 64 |
11 files changed, 211 insertions, 406 deletions
diff --git a/utils/config.go b/utils/config.go index c4d3d0d96..93a870743 100644 --- a/utils/config.go +++ b/utils/config.go @@ -26,9 +26,6 @@ import ( ) const ( - MODE_DEV = "dev" - MODE_BETA = "beta" - MODE_PROD = "prod" LOG_ROTATE_SIZE = 10000 LOG_FILENAME = "mattermost.log" ) diff --git a/utils/file_backend_s3.go b/utils/file_backend_s3.go index 8e72272a1..75282897f 100644 --- a/utils/file_backend_s3.go +++ b/utils/file_backend_s3.go @@ -37,7 +37,10 @@ type S3FileBackend struct { // disables automatic region lookup. func (b *S3FileBackend) s3New() (*s3.Client, error) { var creds *credentials.Credentials - if b.signV2 { + + if b.accessKey == "" && b.secretKey == "" { + creds = credentials.NewIAM("") + } else if b.signV2 { creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV2) } else { creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV4) @@ -244,3 +247,16 @@ func s3CopyMetadata(encrypt bool) map[string]string { metaData["x-amz-server-side-encryption"] = "AES256" return metaData } + +func CheckMandatoryS3Fields(settings *model.FileSettings) *model.AppError { + if len(settings.AmazonS3Bucket) == 0 { + return model.NewAppError("S3File", "api.admin.test_s3.missing_s3_bucket", nil, "", http.StatusBadRequest) + } + + // if S3 endpoint is not set call the set defaults to set that + if len(settings.AmazonS3Endpoint) == 0 { + settings.SetDefaults() + } + + return nil +} diff --git a/utils/file_backend_s3_test.go b/utils/file_backend_s3_test.go new file mode 100644 index 000000000..a8834f226 --- /dev/null +++ b/utils/file_backend_s3_test.go @@ -0,0 +1,32 @@ +// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package utils + +import ( + "testing" + + "github.com/mattermost/mattermost-server/model" +) + +func TestCheckMandatoryS3Fields(t *testing.T) { + cfg := model.FileSettings{} + + err := CheckMandatoryS3Fields(&cfg) + if err == nil || err.Message != "api.admin.test_s3.missing_s3_bucket" { + t.Fatal("should've failed with missing s3 bucket") + } + + cfg.AmazonS3Bucket = "test-mm" + err = CheckMandatoryS3Fields(&cfg) + if err != nil { + t.Fatal("should've not failed") + } + + cfg.AmazonS3Endpoint = "" + err = CheckMandatoryS3Fields(&cfg) + if err != nil || cfg.AmazonS3Endpoint != "s3.amazonaws.com" { + t.Fatal("should've not failed because it should set the endpoint to the default") + } + +} diff --git a/utils/i18n.go b/utils/i18n.go index 71e1aaee1..8ed82d19f 100644 --- a/utils/i18n.go +++ b/utils/i18n.go @@ -91,11 +91,6 @@ func GetUserTranslations(locale string) i18n.TranslateFunc { return translations } -func SetTranslations(locale string) i18n.TranslateFunc { - translations := TfuncWithFallback(locale) - return translations -} - func GetTranslationsAndLocale(w http.ResponseWriter, r *http.Request) (i18n.TranslateFunc, string) { // This is for checking against locales like pt_BR or zn_CN headerLocaleFull := strings.Split(r.Header.Get("Accept-Language"), ",")[0] diff --git a/utils/log.go b/utils/log.go deleted file mode 100644 index c1f579e9d..000000000 --- a/utils/log.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See License.txt for license information. - -package utils - -import ( - "bytes" - "io" - "io/ioutil" - - l4g "github.com/alecthomas/log4go" -) - -// InfoReader logs the content of the io.Reader and returns a new io.Reader -// with the same content as the received io.Reader. -// If you pass reader by reference, it won't be re-created unless the loglevel -// includes Debug. -// If an error is returned, the reader is consumed an cannot be read again. -func InfoReader(reader io.Reader, message string) (io.Reader, error) { - var err error - l4g.Info(func() string { - content, err := ioutil.ReadAll(reader) - if err != nil { - return "" - } - - reader = bytes.NewReader(content) - - return message + string(content) - }) - - return reader, err -} diff --git a/utils/logger/log4go_json_writer.go b/utils/logger/log4go_json_writer.go deleted file mode 100644 index ede541b2b..000000000 --- a/utils/logger/log4go_json_writer.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See License.txt for license information. - -// glue functions that allow logger.go to leverage log4Go to write JSON-formatted log records to a file - -package logger - -import ( - l4g "github.com/alecthomas/log4go" - "github.com/mattermost/mattermost-server/utils" -) - -// newJSONLogWriter is a utility method for creating a FileLogWriter set up to -// output JSON record log messages instead of line-based ones. -func newJSONLogWriter(fname string, rotate bool) *l4g.FileLogWriter { - return l4g.NewFileLogWriter(fname, rotate).SetFormat( - `{"level": "%L", - "timestamp": "%D %T", - "source": "%S", - "message": %M - }`).SetRotateLines(utils.LOG_ROTATE_SIZE) -} - -// NewJSONFileLogger - Create a new logger with a "file" filter configured to send JSON-formatted log messages at -// or above lvl to a file with the specified filename. -func NewJSONFileLogger(lvl l4g.Level, filename string) l4g.Logger { - return l4g.Logger{ - "file": &l4g.Filter{Level: lvl, LogWriter: newJSONLogWriter(filename, false)}, - } -} diff --git a/utils/logger/logger.go b/utils/logger/logger.go deleted file mode 100644 index 558f3fe47..000000000 --- a/utils/logger/logger.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See License.txt for license information. - -// this is a new logger interface for mattermost - -package logger - -import ( - "context" - "encoding/json" - "fmt" - "path/filepath" - "runtime" - - l4g "github.com/alecthomas/log4go" - - "strings" - - "github.com/mattermost/mattermost-server/model" - "github.com/mattermost/mattermost-server/utils" - "github.com/pkg/errors" -) - -// this pattern allows us to "mock" the underlying l4g code when unit testing -var logger l4g.Logger -var debugLog = l4g.Debug -var infoLog = l4g.Info -var errorLog = l4g.Error - -// assumes that ../config.go::configureLog has already been called, and has in turn called l4g.close() to clean up -// any old filters that we might have previously created -func initL4g(logSettings model.LogSettings) { - // TODO: add support for newConfig.LogSettings.EnableConsole. Right now, ../config.go sets it up in its configureLog - // method. If we also set it up here, messages will be written to the console twice. Eventually, when all instances - // of l4g have been replaced by this logger, we can move that code to here - if logSettings.EnableFile { - level := l4g.DEBUG - if logSettings.FileLevel == "INFO" { - level = l4g.INFO - } else if logSettings.FileLevel == "WARN" { - level = l4g.WARNING - } else if logSettings.FileLevel == "ERROR" { - level = l4g.ERROR - } - - // create a logger that writes JSON objects to a file, and override our log methods to use it - if logger != nil { - logger.Close() - } - logger = NewJSONFileLogger(level, utils.GetLogFileLocation(logSettings.FileLocation)+".jsonl") - debugLog = logger.Debug - infoLog = logger.Info - errorLog = logger.Error - } -} - -// contextKey lets us add contextual information to log messages -type contextKey string - -func (c contextKey) String() string { - return string(c) -} - -const contextKeyUserID contextKey = contextKey("user_id") -const contextKeyRequestID contextKey = contextKey("request_id") - -// any contextKeys added to this array will be serialized in every log message -var contextKeys = [2]contextKey{contextKeyUserID, contextKeyRequestID} - -// WithUserId adds a user id to the specified context. If the returned Context is subsequently passed to a logging -// method, the user id will automatically be included in the logged message -func WithUserId(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, contextKeyUserID, userID) -} - -// WithRequestId adds a request id to the specified context. If the returned Context is subsequently passed to a logging -// method, the request id will automatically be included in the logged message -func WithRequestId(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, contextKeyRequestID, requestID) -} - -// extracts known contextKey values from the specified Context and assembles them into the returned map -func serializeContext(ctx context.Context) map[string]string { - serialized := make(map[string]string) - for _, key := range contextKeys { - value, ok := ctx.Value(key).(string) - if ok { - serialized[string(key)] = value - } - } - return serialized -} - -// Returns the path to the next file up the callstack that has a different name than this file -// in other words, finds the path to the file that is doing the logging. -// Removes machine-specific prefix, so returned path starts with /mattermost-server. -// Looks a maximum of 10 frames up the call stack to find a file that has a different name than this one. -func getCallerFilename() (string, error) { - _, currentFilename, _, ok := runtime.Caller(0) - if !ok { - return "", errors.New("Failed to traverse stack frame") - } - - platformDirectory := currentFilename - for filepath.Base(platformDirectory) != "platform" { - platformDirectory = filepath.Dir(platformDirectory) - if platformDirectory == "." || platformDirectory == string(filepath.Separator) { - break - } - } - - for i := 1; i < 10; i++ { - _, parentFilename, _, ok := runtime.Caller(i) - if !ok { - return "", errors.New("Failed to traverse stack frame") - } else if parentFilename != currentFilename && strings.Contains(parentFilename, platformDirectory) { - // trim parentFilename such that we return the path to parentFilename, relative to platformDirectory - return parentFilename[strings.LastIndex(parentFilename, platformDirectory)+len(platformDirectory)+1:], nil - } - } - return "", errors.New("Failed to traverse stack frame") -} - -// creates a JSON representation of a log message -func serializeLogMessage(ctx context.Context, message string) string { - callerFilename, err := getCallerFilename() - if err != nil { - callerFilename = "Unknown" - } - - bytes, err := json.Marshal(&struct { - Context map[string]string `json:"context"` - File string `json:"file"` - Message string `json:"message"` - }{ - serializeContext(ctx), - callerFilename, - message, - }) - if err != nil { - errorLog("Failed to serialize log message %v", message) - } - return string(bytes) -} - -func formatMessage(args ...interface{}) string { - msg, ok := args[0].(string) - if !ok { - panic("Second argument is not of type string") - } - if len(args) > 1 { - variables := args[1:] - msg = fmt.Sprintf(msg, variables...) - } - return msg -} - -// Debugc logs a debugLog level message, including context information that is stored in the first parameter. -// If two parameters are supplied, the second must be a message string, and will be logged directly. -// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Debugc(ctx context.Context, args ...interface{}) { - debugLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(ctx, msg) - }) -} - -// Debugf logs a debugLog level message. -// If one parameter is supplied, it must be a message string, and will be logged directly. -// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Debugf(args ...interface{}) { - debugLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(context.Background(), msg) - }) -} - -// Infoc logs an infoLog level message, including context information that is stored in the first parameter. -// If two parameters are supplied, the second must be a message string, and will be logged directly. -// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Infoc(ctx context.Context, args ...interface{}) { - infoLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(ctx, msg) - }) -} - -// Infof logs an infoLog level message. -// If one parameter is supplied, it must be a message string, and will be logged directly. -// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Infof(args ...interface{}) { - infoLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(context.Background(), msg) - }) -} - -// Errorc logs an error level message, including context information that is stored in the first parameter. -// If two parameters are supplied, the second must be a message string, and will be logged directly. -// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Errorc(ctx context.Context, args ...interface{}) { - errorLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(ctx, msg) - }) -} - -// Errorf logs an error level message. -// If one parameter is supplied, it must be a message string, and will be logged directly. -// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters -// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function. -func Errorf(args ...interface{}) { - errorLog(func() string { - msg := formatMessage(args...) - return serializeLogMessage(context.Background(), msg) - }) -} diff --git a/utils/lru.go b/utils/lru.go index 576331563..8e896a6dc 100644 --- a/utils/lru.go +++ b/utils/lru.go @@ -9,15 +9,14 @@ package utils import ( "container/list" - "errors" "sync" "time" ) // Caching Interface type ObjectCache interface { - AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool - AddWithDefaultExpires(key, value interface{}) bool + AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) + AddWithDefaultExpires(key, value interface{}) Purge() Get(key interface{}) (value interface{}, ok bool) Remove(key interface{}) @@ -32,10 +31,11 @@ type Cache struct { evictList *list.List items map[interface{}]*list.Element lock sync.RWMutex - onEvicted func(key interface{}, value interface{}) name string defaultExpiry int64 invalidateClusterEvent string + currentGeneration int64 + len int } // entry is used to hold a value in the evictList @@ -43,25 +43,16 @@ type entry struct { key interface{} value interface{} expireAtSecs int64 + generation int64 } // New creates an LRU of the given size func NewLru(size int) *Cache { - cache, _ := NewLruWithEvict(size, nil) - return cache -} - -func NewLruWithEvict(size int, onEvicted func(key interface{}, value interface{})) (*Cache, error) { - if size <= 0 { - return nil, errors.New(T("utils.iru.with_evict")) - } - c := &Cache{ + return &Cache{ size: size, evictList: list.New(), items: make(map[interface{}]*list.Element, size), - onEvicted: onEvicted, } - return c, nil } func NewLruWithParams(size int, name string, defaultExpiry int64, invalidateClusterEvent string) *Cache { @@ -77,26 +68,19 @@ func (c *Cache) Purge() { c.lock.Lock() defer c.lock.Unlock() - if c.onEvicted != nil { - for k, v := range c.items { - c.onEvicted(k, v.Value) - } - } - - c.evictList = list.New() - c.items = make(map[interface{}]*list.Element, c.size) + c.len = 0 + c.currentGeneration++ } -func (c *Cache) Add(key, value interface{}) bool { - return c.AddWithExpiresInSecs(key, value, 0) +func (c *Cache) Add(key, value interface{}) { + c.AddWithExpiresInSecs(key, value, 0) } -func (c *Cache) AddWithDefaultExpires(key, value interface{}) bool { - return c.AddWithExpiresInSecs(key, value, c.defaultExpiry) +func (c *Cache) AddWithDefaultExpires(key, value interface{}) { + c.AddWithExpiresInSecs(key, value, c.defaultExpiry) } -// Add adds a value to the cache. Returns true if an eviction occurred. -func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool { +func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) { c.lock.Lock() defer c.lock.Unlock() @@ -107,45 +91,46 @@ func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) // Check for existing item if ent, ok := c.items[key]; ok { c.evictList.MoveToFront(ent) - ent.Value.(*entry).value = value - ent.Value.(*entry).expireAtSecs = expireAtSecs - return false + e := ent.Value.(*entry) + e.value = value + e.expireAtSecs = expireAtSecs + if e.generation != c.currentGeneration { + e.generation = c.currentGeneration + c.len++ + } + return } // Add new item - ent := &entry{key, value, expireAtSecs} + ent := &entry{key, value, expireAtSecs, c.currentGeneration} entry := c.evictList.PushFront(ent) c.items[key] = entry + c.len++ - evict := c.evictList.Len() > c.size - // Verify size not exceeded - if evict { - c.removeOldest() + if c.evictList.Len() > c.size { + c.removeElement(c.evictList.Back()) } - return evict } -// Get looks up a key's value from the cache. func (c *Cache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() if ent, ok := c.items[key]; ok { + e := ent.Value.(*entry) - if ent.Value.(*entry).expireAtSecs > 0 { - if (time.Now().UnixNano() / int64(time.Second)) > ent.Value.(*entry).expireAtSecs { - c.removeElement(ent) - return nil, false - } + if e.generation != c.currentGeneration || (e.expireAtSecs > 0 && (time.Now().UnixNano()/int64(time.Second)) > e.expireAtSecs) { + c.removeElement(ent) + return nil, false } c.evictList.MoveToFront(ent) return ent.Value.(*entry).value, true } - return + + return nil, false } -// Remove removes the provided key from the cache. func (c *Cache) Remove(key interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -155,25 +140,19 @@ func (c *Cache) Remove(key interface{}) { } } -// RemoveOldest removes the oldest item from the cache. -func (c *Cache) RemoveOldest() { - c.lock.Lock() - defer c.lock.Unlock() - c.removeOldest() -} - // Keys returns a slice of the keys in the cache, from oldest to newest. func (c *Cache) Keys() []interface{} { c.lock.RLock() defer c.lock.RUnlock() - keys := make([]interface{}, len(c.items)) - ent := c.evictList.Back() + keys := make([]interface{}, c.len) i := 0 - for ent != nil { - keys[i] = ent.Value.(*entry).key - ent = ent.Prev() - i++ + for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() { + e := ent.Value.(*entry) + if e.generation == c.currentGeneration { + keys[i] = e.key + i++ + } } return keys @@ -183,7 +162,7 @@ func (c *Cache) Keys() []interface{} { func (c *Cache) Len() int { c.lock.RLock() defer c.lock.RUnlock() - return c.evictList.Len() + return c.len } func (c *Cache) Name() string { @@ -194,20 +173,12 @@ func (c *Cache) GetInvalidateClusterEvent() string { return c.invalidateClusterEvent } -// removeOldest removes the oldest item from the cache. -func (c *Cache) removeOldest() { - ent := c.evictList.Back() - if ent != nil { - c.removeElement(ent) - } -} - // removeElement is used to remove a given list element from the cache func (c *Cache) removeElement(e *list.Element) { c.evictList.Remove(e) kv := e.Value.(*entry) - delete(c.items, kv.key) - if c.onEvicted != nil { - c.onEvicted(kv.key, kv.value) + if kv.generation == c.currentGeneration { + c.len-- } + delete(c.items, kv.key) } diff --git a/utils/lru_test.go b/utils/lru_test.go index 987163cd3..4312515b9 100644 --- a/utils/lru_test.go +++ b/utils/lru_test.go @@ -11,14 +11,7 @@ import "testing" import "time" func TestLRU(t *testing.T) { - evictCounter := 0 - onEvicted := func(k interface{}, v interface{}) { - evictCounter += 1 - } - l, err := NewLruWithEvict(128, onEvicted) - if err != nil { - t.Fatalf("err: %v", err) - } + l := NewLru(128) for i := 0; i < 256; i++ { l.Add(i, i) @@ -27,10 +20,6 @@ func TestLRU(t *testing.T) { t.Fatalf("bad len: %v", l.Len()) } - if evictCounter != 128 { - t.Fatalf("bad evict count: %v", evictCounter) - } - for i, k := range l.Keys() { if v, ok := l.Get(k); !ok || v != k || v != i+128 { t.Fatalf("bad key: %v", k) @@ -73,26 +62,6 @@ func TestLRU(t *testing.T) { } } -// test that Add return true/false if an eviction occurred -func TestLRUAdd(t *testing.T) { - evictCounter := 0 - onEvicted := func(k interface{}, v interface{}) { - evictCounter += 1 - } - - l, err := NewLruWithEvict(1, onEvicted) - if err != nil { - t.Fatalf("err: %v", err) - } - - if l.Add(1, 1) || evictCounter != 0 { - t.Errorf("should not have an eviction") - } - if !l.Add(2, 2) || evictCounter != 1 { - t.Errorf("should have an eviction") - } -} - func TestLRUExpire(t *testing.T) { l := NewLru(128) diff --git a/utils/mail.go b/utils/mail.go index 9023f7090..3b9f4bd9d 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -5,6 +5,8 @@ package utils import ( "crypto/tls" + "errors" + "io" "mime" "net" "net/mail" @@ -15,8 +17,6 @@ import ( "net/http" - "io" - l4g "github.com/alecthomas/log4go" "github.com/mattermost/html2text" "github.com/mattermost/mattermost-server/model" @@ -26,6 +26,56 @@ func encodeRFC2047Word(s string) string { return mime.BEncoding.Encode("utf-8", s) } +type authChooser struct { + smtp.Auth + Config *model.Config +} + +func (a *authChooser) Start(server *smtp.ServerInfo) (string, []byte, error) { + a.Auth = LoginAuth(a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort) + for _, method := range server.Auth { + if method == "PLAIN" { + a.Auth = smtp.PlainAuth("", a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort) + break + } + } + return a.Auth.Start(server) +} + +type loginAuth struct { + username, password, host string +} + +func LoginAuth(username, password, host string) smtp.Auth { + return &loginAuth{username, password, host} +} + +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + return "", nil, errors.New("unencrypted connection") + } + + if server.Name != a.host { + return "", nil, errors.New("wrong host name") + } + + return "LOGIN", []byte{}, nil +} + +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(a.username), nil + case "Password:": + return []byte(a.password), nil + default: + return nil, errors.New("Unkown fromServer") + } + } + return nil, nil +} + func connectToSMTPServer(config *model.Config) (net.Conn, *model.AppError) { var conn net.Conn var err error @@ -75,9 +125,7 @@ func newSMTPClient(conn net.Conn, config *model.Config) (*smtp.Client, *model.Ap } if *config.EmailSettings.EnableSMTPAuth { - auth := smtp.PlainAuth("", config.EmailSettings.SMTPUsername, config.EmailSettings.SMTPPassword, config.EmailSettings.SMTPServer+":"+config.EmailSettings.SMTPPort) - - if err = c.Auth(auth); err != nil { + if err = c.Auth(&authChooser{Config: config}); err != nil { return nil, model.NewAppError("SendMail", "utils.mail.new_client.auth.app_error", nil, err.Error(), http.StatusInternalServerError) } } @@ -138,10 +186,8 @@ func sendMail(mimeTo, smtpTo string, from mail.Address, subject, htmlBody string "Auto-Submitted": {"auto-generated"}, "Precedence": {"bulk"}, } - if mimeHeaders != nil { - for k, v := range mimeHeaders { - headers[k] = []string{encodeRFC2047Word(v)} - } + for k, v := range mimeHeaders { + headers[k] = []string{encodeRFC2047Word(v)} } m := gomail.NewMessage(gomail.SetCharset("UTF-8")) diff --git a/utils/mail_test.go b/utils/mail_test.go index 068c90c60..31a4f8996 100644 --- a/utils/mail_test.go +++ b/utils/mail_test.go @@ -7,6 +7,9 @@ import ( "strings" "testing" + "net/smtp" + + "github.com/mattermost/mattermost-server/model" "github.com/stretchr/testify/require" ) @@ -169,3 +172,64 @@ func TestSendMailUsingConfig(t *testing.T) { } } }*/ + +func TestAuthMethods(t *testing.T) { + config := model.Config{ + EmailSettings: model.EmailSettings{ + EnableSMTPAuth: model.NewBool(false), + SMTPUsername: "test", + SMTPPassword: "fakepass", + SMTPServer: "fakeserver", + SMTPPort: "25", + }, + } + + auth := &authChooser{Config: &config} + tests := []struct { + desc string + server *smtp.ServerInfo + err string + }{ + { + desc: "auth PLAIN success", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: true}, + }, + { + desc: "auth PLAIN unencrypted connection fail", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: false}, + err: "unencrypted connection", + }, + { + desc: "auth PLAIN wrong host name", + server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"PLAIN"}, TLS: true}, + err: "wrong host name", + }, + { + desc: "auth LOGIN success", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: true}, + }, + { + desc: "auth LOGIN unencrypted connection fail", + server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"LOGIN"}, TLS: true}, + err: "wrong host name", + }, + { + desc: "auth LOGIN wrong host name", + server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: false}, + err: "unencrypted connection", + }, + } + + for i, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, _, err := auth.Start(test.server) + got := "" + if err != nil { + got = err.Error() + } + if got != test.err { + t.Errorf("%d. got error = %q; want %q", i, got, test.err) + } + }) + } +} |