diff options
Diffstat (limited to 'store/sqlstore/store_test.go')
-rw-r--r-- | store/sqlstore/store_test.go | 104 |
1 files changed, 96 insertions, 8 deletions
diff --git a/store/sqlstore/store_test.go b/store/sqlstore/store_test.go index 605c73b6a..d99c7e441 100644 --- a/store/sqlstore/store_test.go +++ b/store/sqlstore/store_test.go @@ -4,22 +4,110 @@ package sqlstore import ( + "flag" + "os" + "sync" "testing" + "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" + "github.com/mattermost/mattermost-server/store/storetest" "github.com/mattermost/mattermost-server/utils" ) -var sqlStore store.Store +var storeTypes = []*struct { + Name string + Func func() (*storetest.RunningContainer, *model.SqlSettings, error) + Container *storetest.RunningContainer + Store store.Store +}{ + { + Name: "MySQL", + Func: storetest.NewMySQLContainer, + }, + { + Name: "PostgreSQL", + Func: storetest.NewPostgreSQLContainer, + }, +} func StoreTest(t *testing.T, f func(*testing.T, store.Store)) { - if sqlStore == nil { - utils.TranslationsPreInit() - utils.LoadConfig("config.json") - utils.InitTranslations(utils.Cfg.LocalizationSettings) - sqlStore = store.NewLayeredStore(NewSqlSupplier(nil), nil, nil) + defer func() { + if err := recover(); err != nil { + tearDownStores() + panic(err) + } + }() + for _, st := range storeTypes { + st := st + t.Run(st.Name, func(t *testing.T) { f(t, st.Store) }) + } +} - sqlStore.MarkSystemRanUnitTests() +func initStores() { + defer func() { + if err := recover(); err != nil { + tearDownStores() + panic(err) + } + }() + var wg sync.WaitGroup + errCh := make(chan error, len(storeTypes)) + wg.Add(len(storeTypes)) + for _, st := range storeTypes { + st := st + go func() { + defer wg.Done() + container, settings, err := st.Func() + if err != nil { + errCh <- err + return + } + st.Container = container + st.Store = store.NewLayeredStore(NewSqlSupplier(*settings, nil), nil, nil) + st.Store.MarkSystemRanUnitTests() + }() + } + wg.Wait() + select { + case err := <-errCh: + panic(err) + default: } - f(t, sqlStore) +} + +var tearDownStoresOnce sync.Once + +func tearDownStores() { + tearDownStoresOnce.Do(func() { + var wg sync.WaitGroup + wg.Add(len(storeTypes)) + for _, st := range storeTypes { + st := st + go func() { + st.Store.Close() + st.Container.Stop() + wg.Done() + }() + } + wg.Wait() + }) +} + +func TestMain(m *testing.M) { + flag.Parse() + + utils.TranslationsPreInit() + utils.LoadConfig("config.json") + utils.InitTranslations(utils.Cfg.LocalizationSettings) + + status := 0 + + initStores() + defer func() { + tearDownStores() + os.Exit(status) + }() + + status = m.Run() } |