22
22
import java .util .List ;
23
23
import java .util .Map ;
24
24
import java .util .concurrent .atomic .AtomicBoolean ;
25
+ import java .util .concurrent .atomic .AtomicInteger ;
25
26
26
27
import jakarta .validation .constraints .Min ;
27
28
import org .apache .commons .logging .Log ;
49
50
* @author Ronny Bräunlich
50
51
* @author Denis Cutic
51
52
* @author Andrey Muchnik
53
+ * @author Fangzhou Liu
52
54
*/
53
55
@ ConfigurationProperties ("spring.cloud.gateway.redis-rate-limiter" )
54
56
public class RedisRateLimiter extends AbstractRateLimiter <RedisRateLimiter .Config > implements ApplicationContextAware {
@@ -100,6 +102,11 @@ public class RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Confi
100
102
*/
101
103
private boolean includeHeaders = true ;
102
104
105
+ /**
106
+ * A Round-Robin like index to select a virtual shard.
107
+ */
108
+ private final AtomicInteger shardIndex = new AtomicInteger (0 );
109
+
103
110
/**
104
111
* The name of the header that returns number of remaining requests during the current
105
112
* second.
@@ -146,12 +153,18 @@ public RedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity, int
146
153
this .defaultConfig .setRequestedTokens (defaultRequestedTokens );
147
154
}
148
155
149
- static List <String > getKeys (String id , String routeId ) {
156
+ static List <String > getKeys (String id , String routeId , String shardId ) {
150
157
// use `{}` around keys to use Redis Key hash tags
151
158
// this allows for using redis cluster
152
159
153
160
// Make a unique key per user and route.
154
- String prefix = "request_rate_limiter.{" + routeId + "." + id + "}." ;
161
+ String prefix ;
162
+ if (shardId != null ) {
163
+ prefix = "request_rate_limiter.{" + routeId + "." + id + "." + shardId + "}." ;
164
+ }
165
+ else {
166
+ prefix = "request_rate_limiter.{" + routeId + "." + id + "}." ;
167
+ }
155
168
156
169
// You need two Redis keys for Token Bucket.
157
170
String tokenKey = prefix + "tokens" ;
@@ -237,16 +250,16 @@ public Mono<Response> isAllowed(String routeId, String id) {
237
250
Config routeConfig = loadConfiguration (routeId );
238
251
239
252
// How many requests per second do you want a user to be allowed to do?
240
- int replenishRate = routeConfig . getReplenishRate ( );
253
+ int replenishRate = getShardedReplenishRate ( routeConfig );
241
254
242
255
// How much bursting do you want to allow?
243
- int burstCapacity = routeConfig . getBurstCapacity ( );
256
+ int burstCapacity = getShardedBurstCapacity ( routeConfig );
244
257
245
258
// How many tokens are requested per request?
246
259
int requestedTokens = routeConfig .getRequestedTokens ();
247
260
248
261
try {
249
- List <String > keys = getKeys (id , routeId );
262
+ List <String > keys = getKeys (id , routeId , getShard ( routeConfig . getShards ()) );
250
263
251
264
// The arguments to the LUA script. time() returns unixtime in seconds.
252
265
List <String > scriptArgs = Arrays .asList (replenishRate + "" , burstCapacity + "" , "" , requestedTokens + "" );
@@ -306,6 +319,30 @@ public Map<String, String> getHeaders(Config config, Long tokensLeft) {
306
319
return headers ;
307
320
}
308
321
322
+ String getShard (int shards ) {
323
+ if (shards > 0 ) {
324
+ // Ignore signature bit of shardIndex to make sure always positive value.
325
+ return String .valueOf ((shardIndex .getAndIncrement () & Integer .MAX_VALUE ) % shards );
326
+ }
327
+ return null ;
328
+ }
329
+
330
+ int getShardedReplenishRate (Config config ) {
331
+ int replenishRate = config .getReplenishRate ();
332
+ if (config .getShards () > 0 ) {
333
+ replenishRate = replenishRate / config .getShards ();
334
+ }
335
+ return replenishRate ;
336
+ }
337
+
338
+ int getShardedBurstCapacity (Config config ) {
339
+ int burstCapacity = config .getBurstCapacity ();
340
+ if (config .getShards () > 0 ) {
341
+ burstCapacity = burstCapacity / config .getShards ();
342
+ }
343
+ return burstCapacity ;
344
+ }
345
+
309
346
@ Validated
310
347
public static class Config {
311
348
@@ -318,6 +355,9 @@ public static class Config {
318
355
@ Min (1 )
319
356
private int requestedTokens = 1 ;
320
357
358
+ @ Min (0 )
359
+ private int shards = 0 ;
360
+
321
361
public int getReplenishRate () {
322
362
return replenishRate ;
323
363
}
@@ -347,11 +387,21 @@ public Config setRequestedTokens(int requestedTokens) {
347
387
return this ;
348
388
}
349
389
390
+ public int getShards () {
391
+ return shards ;
392
+ }
393
+
394
+ public Config setShards (int shards ) {
395
+ this .shards = shards ;
396
+ return this ;
397
+ }
398
+
350
399
@ Override
351
400
public String toString () {
352
401
return new ToStringCreator (this ).append ("replenishRate" , replenishRate )
353
402
.append ("burstCapacity" , burstCapacity )
354
403
.append ("requestedTokens" , requestedTokens )
404
+ .append ("shards" , shards )
355
405
.toString ();
356
406
357
407
}
0 commit comments