diff options
Diffstat (limited to 'app')
-rw-r--r-- | app/ratelimit.go | 11 | ||||
-rw-r--r-- | app/ratelimit_test.go | 17 | ||||
-rw-r--r-- | app/server.go | 9 |
3 files changed, 28 insertions, 9 deletions
diff --git a/app/ratelimit.go b/app/ratelimit.go index 460088598..13508f36f 100644 --- a/app/ratelimit.go +++ b/app/ratelimit.go @@ -12,6 +12,7 @@ import ( l4g "github.com/alecthomas/log4go" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/utils" + "github.com/pkg/errors" throttled "gopkg.in/throttled/throttled.v2" "gopkg.in/throttled/throttled.v2/store/memstore" ) @@ -23,11 +24,10 @@ type RateLimiter struct { header string } -func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter { +func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) { store, err := memstore.New(*settings.MemoryStoreSize) if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store")) - return nil + return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store")) } quota := throttled.RateQuota{ @@ -37,8 +37,7 @@ func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter { throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota) if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter")) - return nil + return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter")) } return &RateLimiter{ @@ -46,7 +45,7 @@ func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter { useAuth: *settings.VaryByUser, useIP: *settings.VaryByRemoteAddr, header: settings.VaryByHeader, - } + }, nil } func (rl *RateLimiter) GenerateKey(r *http.Request) string { diff --git a/app/ratelimit_test.go b/app/ratelimit_test.go index ddaa25710..fb157b2b0 100644 --- a/app/ratelimit_test.go +++ b/app/ratelimit_test.go @@ -25,6 +25,21 @@ func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSe } } +func TestNewRateLimiterSuccess(t *testing.T) { + settings := genRateLimitSettings(false, false, "") + rateLimiter, err := NewRateLimiter(settings) + require.NotNil(t, rateLimiter) + require.NoError(t, err) +} + +func TestNewRateLimiterFailure(t *testing.T) { + invalidSettings := genRateLimitSettings(false, false, "") + invalidSettings.MaxBurst = model.NewInt(-100) + rateLimiter, err := NewRateLimiter(invalidSettings) + require.Nil(t, rateLimiter) + require.Error(t, err) +} + func TestGenerateKey(t *testing.T) { cases := []struct { useAuth bool @@ -58,7 +73,7 @@ func TestGenerateKey(t *testing.T) { req.Header.Set(tc.header, tc.headerResult) } - rateLimiter := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header)) + rateLimiter, _ := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header)) key := rateLimiter.GenerateKey(req) diff --git a/app/server.go b/app/server.go index 2a94bf2c7..1659908b6 100644 --- a/app/server.go +++ b/app/server.go @@ -124,9 +124,14 @@ func (a *App) StartServer() { if *a.Config().RateLimitSettings.Enable { l4g.Info(utils.T("api.server.start_server.rate.info")) - a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings) + rateLimiter, err := NewRateLimiter(&a.Config().RateLimitSettings) + if err != nil { + l4g.Critical(err.Error()) + return + } - handler = a.Srv.RateLimiter.RateLimitHandler(handler) + a.Srv.RateLimiter = rateLimiter + handler = rateLimiter.RateLimitHandler(handler) } a.Srv.Server = &http.Server{ |