diff options
Diffstat (limited to 'vendor/github.com/go-redis/redis/ring.go')
-rw-r--r-- | vendor/github.com/go-redis/redis/ring.go | 370 |
1 files changed, 204 insertions, 166 deletions
diff --git a/vendor/github.com/go-redis/redis/ring.go b/vendor/github.com/go-redis/redis/ring.go index 10f33ed00..6d2877413 100644 --- a/vendor/github.com/go-redis/redis/ring.go +++ b/vendor/github.com/go-redis/redis/ring.go @@ -1,6 +1,7 @@ package redis import ( + "context" "errors" "fmt" "math/rand" @@ -15,6 +16,8 @@ import ( "github.com/go-redis/redis/internal/pool" ) +const nreplicas = 100 + var errRingShardsDown = errors.New("redis: all ring shards are down") // RingOptions are used to configure a ring client and should be @@ -85,6 +88,8 @@ func (opt *RingOptions) clientOptions() *Options { } } +//------------------------------------------------------------------------------ + type ringShard struct { Client *Client down int32 @@ -125,6 +130,150 @@ func (shard *ringShard) Vote(up bool) bool { return shard.IsDown() } +//------------------------------------------------------------------------------ + +type ringShards struct { + mu sync.RWMutex + hash *consistenthash.Map + shards map[string]*ringShard // read only + list []*ringShard // read only + closed bool +} + +func newRingShards() *ringShards { + return &ringShards{ + hash: consistenthash.New(nreplicas, nil), + shards: make(map[string]*ringShard), + } +} + +func (c *ringShards) Add(name string, cl *Client) { + shard := &ringShard{Client: cl} + c.hash.Add(name) + c.shards[name] = shard + c.list = append(c.list, shard) +} + +func (c *ringShards) List() []*ringShard { + c.mu.RLock() + list := c.list + c.mu.RUnlock() + return list +} + +func (c *ringShards) Hash(key string) string { + c.mu.RLock() + hash := c.hash.Get(key) + c.mu.RUnlock() + return hash +} + +func (c *ringShards) GetByKey(key string) (*ringShard, error) { + key = hashtag.Key(key) + + c.mu.RLock() + + if c.closed { + c.mu.RUnlock() + return nil, pool.ErrClosed + } + + hash := c.hash.Get(key) + if hash == "" { + c.mu.RUnlock() + return nil, errRingShardsDown + } + + shard := c.shards[hash] + c.mu.RUnlock() + + return shard, nil +} + +func (c *ringShards) GetByHash(name string) (*ringShard, error) { + if name == "" { + return c.Random() + } + + c.mu.RLock() + shard := c.shards[name] + c.mu.RUnlock() + return shard, nil +} + +func (c *ringShards) Random() (*ringShard, error) { + return c.GetByKey(strconv.Itoa(rand.Int())) +} + +// heartbeat monitors state of each shard in the ring. +func (c *ringShards) Heartbeat(frequency time.Duration) { + ticker := time.NewTicker(frequency) + defer ticker.Stop() + for range ticker.C { + var rebalance bool + + c.mu.RLock() + + if c.closed { + c.mu.RUnlock() + break + } + + shards := c.list + c.mu.RUnlock() + + for _, shard := range shards { + err := shard.Client.Ping().Err() + if shard.Vote(err == nil || err == pool.ErrPoolTimeout) { + internal.Logf("ring shard state changed: %s", shard) + rebalance = true + } + } + + if rebalance { + c.rebalance() + } + } +} + +// rebalance removes dead shards from the Ring. +func (c *ringShards) rebalance() { + hash := consistenthash.New(nreplicas, nil) + for name, shard := range c.shards { + if shard.IsUp() { + hash.Add(name) + } + } + + c.mu.Lock() + c.hash = hash + c.mu.Unlock() +} + +func (c *ringShards) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil + } + c.closed = true + + var firstErr error + for _, shard := range c.shards { + if err := shard.Client.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + c.hash = nil + c.shards = nil + c.list = nil + + return firstErr +} + +//------------------------------------------------------------------------------ + // Ring is a Redis client that uses constistent hashing to distribute // keys across multiple Redis servers (shards). It's safe for // concurrent use by multiple goroutines. @@ -142,33 +291,22 @@ func (shard *ringShard) Vote(up bool) bool { type Ring struct { cmdable - opt *RingOptions - nreplicas int + ctx context.Context - mu sync.RWMutex - hash *consistenthash.Map - shards map[string]*ringShard - shardsList []*ringShard + opt *RingOptions + shards *ringShards + cmdsInfoCache *cmdsInfoCache processPipeline func([]Cmder) error - - cmdsInfoOnce internal.Once - cmdsInfo map[string]*CommandInfo - - closed bool } func NewRing(opt *RingOptions) *Ring { - const nreplicas = 100 - opt.init() ring := &Ring{ - opt: opt, - nreplicas: nreplicas, - - hash: consistenthash.New(nreplicas, nil), - shards: make(map[string]*ringShard), + opt: opt, + shards: newRingShards(), + cmdsInfoCache: newCmdsInfoCache(), } ring.processPipeline = ring.defaultProcessPipeline ring.cmdable.setProcessor(ring.Process) @@ -176,21 +314,33 @@ func NewRing(opt *RingOptions) *Ring { for name, addr := range opt.Addrs { clopt := opt.clientOptions() clopt.Addr = addr - ring.addShard(name, NewClient(clopt)) + ring.shards.Add(name, NewClient(clopt)) } - go ring.heartbeat() + go ring.shards.Heartbeat(opt.HeartbeatFrequency) return ring } -func (c *Ring) addShard(name string, cl *Client) { - shard := &ringShard{Client: cl} - c.mu.Lock() - c.hash.Add(name) - c.shards[name] = shard - c.shardsList = append(c.shardsList, shard) - c.mu.Unlock() +func (c *Ring) Context() context.Context { + if c.ctx != nil { + return c.ctx + } + return context.Background() +} + +func (c *Ring) WithContext(ctx context.Context) *Ring { + if ctx == nil { + panic("nil context") + } + c2 := c.copy() + c2.ctx = ctx + return c2 +} + +func (c *Ring) copy() *Ring { + cp := *c + return &cp } // Options returns read-only Options that were used to create the client. @@ -204,10 +354,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration { // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { - c.mu.RLock() - shards := c.shardsList - c.mu.RUnlock() - + shards := c.shards.List() var acc PoolStats for _, shard := range shards { s := shard.Client.connPool.Stats() @@ -226,7 +373,7 @@ func (c *Ring) Subscribe(channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shardByKey(channels[0]) + shard, err := c.shards.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -240,7 +387,7 @@ func (c *Ring) PSubscribe(channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shardByKey(channels[0]) + shard, err := c.shards.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -251,10 +398,7 @@ func (c *Ring) PSubscribe(channels ...string) *PubSub { // ForEachShard concurrently calls the fn on each live shard in the ring. // It returns the first error if any. func (c *Ring) ForEachShard(fn func(client *Client) error) error { - c.mu.RLock() - shards := c.shardsList - c.mu.RUnlock() - + shards := c.shards.List() var wg sync.WaitGroup errCh := make(chan error, 1) for _, shard := range shards { @@ -285,81 +429,38 @@ func (c *Ring) ForEachShard(fn func(client *Client) error) error { } func (c *Ring) cmdInfo(name string) *CommandInfo { - err := c.cmdsInfoOnce.Do(func() error { - c.mu.RLock() - shards := c.shardsList - c.mu.RUnlock() - - var firstErr error + cmdsInfo, err := c.cmdsInfoCache.Do(func() (map[string]*CommandInfo, error) { + shards := c.shards.List() + firstErr := errRingShardsDown for _, shard := range shards { cmdsInfo, err := shard.Client.Command().Result() if err == nil { - c.cmdsInfo = cmdsInfo - return nil + return cmdsInfo, nil } if firstErr == nil { firstErr = err } } - return firstErr + return nil, firstErr }) if err != nil { return nil } - if c.cmdsInfo == nil { - return nil - } - info := c.cmdsInfo[name] + info := cmdsInfo[name] if info == nil { internal.Logf("info for cmd=%s not found", name) } return info } -func (c *Ring) shardByKey(key string) (*ringShard, error) { - key = hashtag.Key(key) - - c.mu.RLock() - - if c.closed { - c.mu.RUnlock() - return nil, pool.ErrClosed - } - - name := c.hash.Get(key) - if name == "" { - c.mu.RUnlock() - return nil, errRingShardsDown - } - - shard := c.shards[name] - c.mu.RUnlock() - return shard, nil -} - -func (c *Ring) randomShard() (*ringShard, error) { - return c.shardByKey(strconv.Itoa(rand.Int())) -} - -func (c *Ring) shardByName(name string) (*ringShard, error) { - if name == "" { - return c.randomShard() - } - - c.mu.RLock() - shard := c.shards[name] - c.mu.RUnlock() - return shard, nil -} - func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { cmdInfo := c.cmdInfo(cmd.Name()) pos := cmdFirstKeyPos(cmd, cmdInfo) if pos == 0 { - return c.randomShard() + return c.shards.Random() } firstKey := cmd.stringArg(pos) - return c.shardByKey(firstKey) + return c.shards.GetByKey(firstKey) } func (c *Ring) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error) { @@ -378,77 +479,6 @@ func (c *Ring) Process(cmd Cmder) error { return shard.Client.Process(cmd) } -// rebalance removes dead shards from the Ring. -func (c *Ring) rebalance() { - hash := consistenthash.New(c.nreplicas, nil) - for name, shard := range c.shards { - if shard.IsUp() { - hash.Add(name) - } - } - - c.mu.Lock() - c.hash = hash - c.mu.Unlock() -} - -// heartbeat monitors state of each shard in the ring. -func (c *Ring) heartbeat() { - ticker := time.NewTicker(c.opt.HeartbeatFrequency) - defer ticker.Stop() - for range ticker.C { - var rebalance bool - - c.mu.RLock() - - if c.closed { - c.mu.RUnlock() - break - } - - shards := c.shardsList - c.mu.RUnlock() - - for _, shard := range shards { - err := shard.Client.Ping().Err() - if shard.Vote(err == nil || err == pool.ErrPoolTimeout) { - internal.Logf("ring shard state changed: %s", shard) - rebalance = true - } - } - - if rebalance { - c.rebalance() - } - } -} - -// Close closes the ring client, releasing any open resources. -// -// It is rare to Close a Ring, as the Ring is meant to be long-lived -// and shared between many goroutines. -func (c *Ring) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.closed { - return nil - } - c.closed = true - - var firstErr error - for _, shard := range c.shards { - if err := shard.Client.Close(); err != nil && firstErr == nil { - firstErr = err - } - } - c.hash = nil - c.shards = nil - c.shardsList = nil - - return firstErr -} - func (c *Ring) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.processPipeline, @@ -471,11 +501,11 @@ func (c *Ring) defaultProcessPipeline(cmds []Cmder) error { cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { cmdInfo := c.cmdInfo(cmd.Name()) - name := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) - if name != "" { - name = c.hash.Get(hashtag.Key(name)) + hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) + if hash != "" { + hash = c.shards.Hash(hashtag.Key(hash)) } - cmdsMap[name] = append(cmdsMap[name], cmd) + cmdsMap[hash] = append(cmdsMap[hash], cmd) } for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { @@ -485,8 +515,8 @@ func (c *Ring) defaultProcessPipeline(cmds []Cmder) error { var failedCmdsMap map[string][]Cmder - for name, cmds := range cmdsMap { - shard, err := c.shardByName(name) + for hash, cmds := range cmdsMap { + shard, err := c.shards.GetByHash(hash) if err != nil { setCmdsErr(cmds, err) continue @@ -509,7 +539,7 @@ func (c *Ring) defaultProcessPipeline(cmds []Cmder) error { if failedCmdsMap == nil { failedCmdsMap = make(map[string][]Cmder) } - failedCmdsMap[name] = cmds + failedCmdsMap[hash] = cmds } } @@ -529,3 +559,11 @@ func (c *Ring) TxPipeline() Pipeliner { func (c *Ring) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { panic("not implemented") } + +// Close closes the ring client, releasing any open resources. +// +// It is rare to Close a Ring, as the Ring is meant to be long-lived +// and shared between many goroutines. +func (c *Ring) Close() error { + return c.shards.Close() +} |