Skip to content

Commit 5998140

Browse files
author
liufangzhou.aaa
committed
RedisRateLimiter support sharding
1 parent ec07cb8 commit 5998140

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java

+55-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.concurrent.atomic.AtomicBoolean;
25+
import java.util.concurrent.atomic.AtomicInteger;
2526

2627
import jakarta.validation.constraints.Min;
2728
import org.apache.commons.logging.Log;
@@ -49,6 +50,7 @@
4950
* @author Ronny Bräunlich
5051
* @author Denis Cutic
5152
* @author Andrey Muchnik
53+
* @author Fangzhou Liu
5254
*/
5355
@ConfigurationProperties("spring.cloud.gateway.redis-rate-limiter")
5456
public class RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Config> implements ApplicationContextAware {
@@ -100,6 +102,11 @@ public class RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Confi
100102
*/
101103
private boolean includeHeaders = true;
102104

105+
/**
106+
* A Round-Robin like index to select a virtual shard.
107+
*/
108+
private final AtomicInteger shardIndex = new AtomicInteger(0);
109+
103110
/**
104111
* The name of the header that returns number of remaining requests during the current
105112
* second.
@@ -146,12 +153,18 @@ public RedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity, int
146153
this.defaultConfig.setRequestedTokens(defaultRequestedTokens);
147154
}
148155

149-
static List<String> getKeys(String id, String routeId) {
156+
static List<String> getKeys(String id, String routeId, String shardId) {
150157
// use `{}` around keys to use Redis Key hash tags
151158
// this allows for using redis cluster
152159

153160
// Make a unique key per user and route.
154-
String prefix = "request_rate_limiter.{" + routeId + "." + id + "}.";
161+
String prefix;
162+
if (shardId != null) {
163+
prefix = "request_rate_limiter.{" + routeId + "." + id + "." + shardId + "}.";
164+
}
165+
else {
166+
prefix = "request_rate_limiter.{" + routeId + "." + id + "}.";
167+
}
155168

156169
// You need two Redis keys for Token Bucket.
157170
String tokenKey = prefix + "tokens";
@@ -237,16 +250,16 @@ public Mono<Response> isAllowed(String routeId, String id) {
237250
Config routeConfig = loadConfiguration(routeId);
238251

239252
// How many requests per second do you want a user to be allowed to do?
240-
int replenishRate = routeConfig.getReplenishRate();
253+
int replenishRate = getShardedReplenishRate(routeConfig);
241254

242255
// How much bursting do you want to allow?
243-
int burstCapacity = routeConfig.getBurstCapacity();
256+
int burstCapacity = getShardedBurstCapacity(routeConfig);
244257

245258
// How many tokens are requested per request?
246259
int requestedTokens = routeConfig.getRequestedTokens();
247260

248261
try {
249-
List<String> keys = getKeys(id, routeId);
262+
List<String> keys = getKeys(id, routeId, getShard(routeConfig.getShards()));
250263

251264
// The arguments to the LUA script. time() returns unixtime in seconds.
252265
List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "", "", requestedTokens + "");
@@ -306,6 +319,30 @@ public Map<String, String> getHeaders(Config config, Long tokensLeft) {
306319
return headers;
307320
}
308321

322+
String getShard(int shards) {
323+
if (shards > 0) {
324+
// Ignore signature bit of shardIndex to make sure always positive value.
325+
return String.valueOf((shardIndex.getAndIncrement() & Integer.MAX_VALUE) % shards);
326+
}
327+
return null;
328+
}
329+
330+
int getShardedReplenishRate(Config config) {
331+
int replenishRate = config.getReplenishRate();
332+
if (config.getShards() > 0) {
333+
replenishRate = replenishRate / config.getShards();
334+
}
335+
return replenishRate;
336+
}
337+
338+
int getShardedBurstCapacity(Config config) {
339+
int burstCapacity = config.getBurstCapacity();
340+
if (config.getShards() > 0) {
341+
burstCapacity = burstCapacity / config.getShards();
342+
}
343+
return burstCapacity;
344+
}
345+
309346
@Validated
310347
public static class Config {
311348

@@ -318,6 +355,9 @@ public static class Config {
318355
@Min(1)
319356
private int requestedTokens = 1;
320357

358+
@Min(0)
359+
private int shards = 0;
360+
321361
public int getReplenishRate() {
322362
return replenishRate;
323363
}
@@ -347,11 +387,21 @@ public Config setRequestedTokens(int requestedTokens) {
347387
return this;
348388
}
349389

390+
public int getShards() {
391+
return shards;
392+
}
393+
394+
public Config setShards(int shards) {
395+
this.shards = shards;
396+
return this;
397+
}
398+
350399
@Override
351400
public String toString() {
352401
return new ToStringCreator(this).append("replenishRate", replenishRate)
353402
.append("burstCapacity", burstCapacity)
354403
.append("requestedTokens", requestedTokens)
404+
.append("shards", shards)
355405
.toString();
356406

357407
}

spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiterTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
* @author Ronny Bräunlich
5454
* @author Denis Cutic
5555
* @author Andrey Muchnik
56+
* @author Fangzhou Liu
5657
*/
5758
@SpringBootTest(webEnvironment = RANDOM_PORT, properties = { "spring.cloud.gateway.function.enabled=false" })
5859
@DirtiesContext
@@ -165,8 +166,8 @@ public void redisRateLimiterWorksForZeroBurstCapacity() throws Exception {
165166

166167
@Test
167168
public void keysUseRedisKeyHashTags() {
168-
assertThat(RedisRateLimiter.getKeys("1", "routeId")).containsExactly("request_rate_limiter.{routeId.1}.tokens",
169-
"request_rate_limiter.{routeId.1}.timestamp");
169+
assertThat(RedisRateLimiter.getKeys("1", "routeId", null))
170+
.containsExactly("request_rate_limiter.{routeId.1}.tokens", "request_rate_limiter.{routeId.1}.timestamp");
170171
}
171172

172173
@Test

0 commit comments

Comments
 (0)