diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 7166208..96a5077 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -9,7 +9,7 @@ type ( Cache interface { Add(context.Context, string) error Remove(context.Context, string) error - Exists(context.Context, string) (bool, error) + Exists(context.Context, ...string) (bool, error) Size(context.Context) (int64, error) } diff --git a/internal/cache/redis.go b/internal/cache/redis.go index 616578f..afdf09d 100644 --- a/internal/cache/redis.go +++ b/internal/cache/redis.go @@ -40,8 +40,8 @@ func (c *redisCache) Remove(ctx context.Context, key string) error { return c.client.Do(ctx, cmd).Error() } -func (c *redisCache) Exists(ctx context.Context, key string) (bool, error) { - cmd := c.client.B().Exists().Key(key).Build() +func (c *redisCache) Exists(ctx context.Context, key ...string) (bool, error) { + cmd := c.client.B().Exists().Key(key...).Build() res, err := c.client.Do(ctx, cmd).AsBool() if err != nil { return false, err diff --git a/internal/cache/xmap.go b/internal/cache/xmap.go index 49ba5a8..3a10bf2 100644 --- a/internal/cache/xmap.go +++ b/internal/cache/xmap.go @@ -26,9 +26,14 @@ func (c *mapCache) Remove(_ context.Context, key string) error { return nil } -func (c *mapCache) Exists(_ context.Context, key string) (bool, error) { - _, ok := c.xmap.Load(key) - return ok, nil +func (c *mapCache) Exists(_ context.Context, key ...string) (bool, error) { + for _, v := range key { + _, ok := c.xmap.Load(v) + if ok { + return true, nil + } + } + return false, nil } func (c *mapCache) Size(_ context.Context) (int64, error) { diff --git a/internal/handler/token_transfer.go b/internal/handler/token_transfer.go index 182e00f..7ae360a 100644 --- a/internal/handler/token_transfer.go +++ b/internal/handler/token_transfer.go @@ -147,16 +147,10 @@ func (hc *HandlerContainer) checkStables(ctx context.Context, from string, to st return true, nil } - // TODO: Pipeline this check on Redis with a new method - fromExists, err := hc.cache.Exists(ctx, from) + exists, err := hc.cache.Exists(ctx, from, to) if err != nil { return false, err } - toExists, err := hc.cache.Exists(ctx, to) - if err != nil { - return false, err - } - - return fromExists || toExists, nil + return exists, nil }