Spring Boot 实现限流控制
warning:
这篇文章距离上次修改已过456天,其中的内容可能已经有所变动。
package com.jinw.cms.config;
import com.jinw.cms.aspectj.RateLimiterAspect;
import lombok.RequiredArgsConstructor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Spring configuration that registers the Redis-backed rate-limiting aspect.
 *
 * <p>The {@link RateLimiterAspect} bean is only created when the property
 * {@code jw.rate-limiter.enable} is set to {@code true}.
 */
@RequiredArgsConstructor
@Configuration
public class RateLimiterConfig {

    /**
     * Lua source of the counter script.
     *
     * <p>Inputs: {@code KEYS[1]} is the counter key, {@code ARGV[1]} the request limit and
     * {@code ARGV[2]} the expiry in seconds. The script returns {@code false} without
     * incrementing when one more request would exceed the limit; otherwise it increments
     * the counter, sets the expiry when the counter has just become 1, and returns
     * {@code true}.
     */
    private static final String LIMIT_SCRIPT_TEXT =
            " local key = KEYS[1] --限流KEY\n" +
            " local limit = tonumber(ARGV[1]) --限流大小\n" +
            " local expireTime = tonumber(ARGV[2]) --过期时间 单位/s\n" +
            "\n" +
            " local current = tonumber(redis.call('get', key) or \"0\")\n" +
            " if current + 1 > limit then\n" +
            " return false --当前值超过限流大小阈值\n" +
            " end\n" +
            " current = tonumber(redis.call('incr', key)) --请求数+1\n" +
            " if current == 1 then\n" +
            " redis.call('expire', key, expireTime) --设置过期时间\n" +
            " end\n" +
            " return true;";

    /** Redis client handed to the aspect; injected through the Lombok-generated constructor. */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Creates the rate-limiting aspect.
     *
     * @return an aspect built from the injected Redis template and a fresh limit script
     */
    @Bean
    @ConditionalOnProperty(name = "jw.rate-limiter.enable", havingValue = "true")
    public RateLimiterAspect rateLimitAspect() {
        return new RateLimiterAspect(redisTemplate, limitScript());
    }

    /**
     * Builds the Lua rate-limiting script.
     *
     * @return a new script object whose result type is {@link Boolean}
     */
    public DefaultRedisScript<Boolean> limitScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(LIMIT_SCRIPT_TEXT);
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}
package com.jinw.cms.aspectj;
/**
 * Strategies for choosing the scope of a rate limit.
 *
 * @author ruoyi
 */
public enum LimitType {

    /** Default strategy: one global limit. */
    DEFAULT,

    /** Limit per requester IP address. */
    IP;
}
package com.jinw.cms.aspectj.annotation;
import com.jinw.cms.aspectj.LimitType;
import com.jinw.cms.constants.ExtendConstants;
import java.lang.annotation.*;
/**
 * Marks a method as rate limited.
 *
 * @author ruoyi
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Prefix of the cache key used for the rate-limit counter. */
    String prefix() default ExtendConstants.RATE_LIMIT_KEY;

    /** Length of the limiting window, in seconds. */
    int expire() default 60;

    /** Threshold: the maximum number of requests allowed per window. */
    int limit() default 100;

    /** Strategy that decides the scope of the limit. */
    LimitType limitType() default LimitType.DEFAULT;
}
package com.jinw.cms.config;
import com.jinw.cms.aspectj.RateLimiterAspect;
import lombok.RequiredArgsConstructor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Spring configuration that registers the Redis-backed rate-limiting aspect.
 *
 * <p>The {@link RateLimiterAspect} bean is only created when the property
 * {@code jw.rate-limiter.enable} is set to {@code true}.
 */
@RequiredArgsConstructor
@Configuration
public class RateLimiterConfig {

    /**
     * Lua source of the counter script.
     *
     * <p>Inputs: {@code KEYS[1]} is the counter key, {@code ARGV[1]} the request limit and
     * {@code ARGV[2]} the expiry in seconds. The script returns {@code false} without
     * incrementing when one more request would exceed the limit; otherwise it increments
     * the counter, sets the expiry when the counter has just become 1, and returns
     * {@code true}.
     */
    private static final String LIMIT_SCRIPT_TEXT =
            " local key = KEYS[1] --限流KEY\n" +
            " local limit = tonumber(ARGV[1]) --限流大小\n" +
            " local expireTime = tonumber(ARGV[2]) --过期时间 单位/s\n" +
            "\n" +
            " local current = tonumber(redis.call('get', key) or \"0\")\n" +
            " if current + 1 > limit then\n" +
            " return false --当前值超过限流大小阈值\n" +
            " end\n" +
            " current = tonumber(redis.call('incr', key)) --请求数+1\n" +
            " if current == 1 then\n" +
            " redis.call('expire', key, expireTime) --设置过期时间\n" +
            " end\n" +
            " return true;";

    /** Redis client handed to the aspect; injected through the Lombok-generated constructor. */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Creates the rate-limiting aspect.
     *
     * @return an aspect built from the injected Redis template and a fresh limit script
     */
    @Bean
    @ConditionalOnProperty(name = "jw.rate-limiter.enable", havingValue = "true")
    public RateLimiterAspect rateLimitAspect() {
        return new RateLimiterAspect(redisTemplate, limitScript());
    }

    /**
     * Builds the Lua rate-limiting script.
     *
     * @return a new script object whose result type is {@link Boolean}
     */
    public DefaultRedisScript<Boolean> limitScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(LIMIT_SCRIPT_TEXT);
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}
package com.ruoyi.common.extend.aspectj;
import java.lang.reflect.Method;
import java.util.List;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import com.ruoyi.common.exception.GlobalException;
import com.ruoyi.common.extend.annotation.RateLimiter;
import com.ruoyi.common.extend.enums.LimitType;
import com.ruoyi.common.extend.exception.RateLimiterErrorCode;
import com.ruoyi.common.utils.ServletUtils;
import lombok.RequiredArgsConstructor;
/**
 * Aspect that enforces the limits declared with {@link RateLimiter}.
 *
 * <p>Before every annotated method runs, the configured Redis script is executed against a
 * key derived from the annotation and the intercepted method; the call is rejected when the
 * script reports that the limit has been reached.
 */
@Aspect
@RequiredArgsConstructor
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisScript<Boolean> limitScript;

    /**
     * Runs the rate-limit script before a method annotated with {@link RateLimiter}.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation instance found on the intercepted method
     * @throws Throwable the {@code RATE_LIMIT} exception when the script returns false, or
     *                   the {@code RATE_LIMIT_ERR} exception when the check itself fails
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        int limit = rateLimiter.limit();
        int expire = rateLimiter.expire();
        try {
            String combineKey = this.getCombineKey(rateLimiter, point);
            List<String> keys = List.of(combineKey);
            Boolean allowed = redisTemplate.execute(this.limitScript, keys, limit, expire);
            if (allowed == null) {
                // Fail with a descriptive message instead of an unboxing NullPointerException;
                // the generic handler below still maps this to RATE_LIMIT_ERR.
                throw new IllegalStateException(
                        "Rate limit script returned no result for key '" + combineKey + "'");
            }
            if (!allowed) {
                log.warn("限制请求'{}',缓存key'{}'", limit, combineKey);
                throw RateLimiterErrorCode.RATE_LIMIT.exception();
            }
        } catch (GlobalException e) {
            // Business exceptions (including the rate-limit rejection above) pass through untouched.
            throw e;
        } catch (Exception e) {
            // Log the root cause before translating it; otherwise the failure is undiagnosable.
            log.error("Rate limit check failed for '{}'", point.getSignature(), e);
            throw RateLimiterErrorCode.RATE_LIMIT_ERR.exception();
        }
    }

    /**
     * Builds the Redis key for the intercepted method.
     *
     * <p>The key is the annotation prefix, followed by the requester IP and a dot when the
     * limit type is {@link LimitType#IP}, followed by the declaring class name, a dot and
     * the method name.
     *
     * @param rateLimiter the annotation supplying the prefix and the limit type
     * @param point       the intercepted join point
     * @return the combined cache key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        // The buffer never leaves this method, so the unsynchronized builder is sufficient.
        StringBuilder key = new StringBuilder(rateLimiter.prefix());
        if (rateLimiter.limitType() == LimitType.IP) {
            key.append(ServletUtils.getIpAddr(ServletUtils.getRequest())).append(".");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        key.append(targetClass.getName()).append(".").append(method.getName());
        return key.toString();
    }
}