怎么在 Java 中限制用户访问频率

怎么在 Java 中限制用户访问频率

我们当然要限制用户访问频率,因为用户可能生气,并狂点我们的网站或应用。

他也可能很坏,使用一些爬虫试图拖垮我们的服务器。

所以怎么实现呢?

本文使用 springboot,并将用户的信息和访问频率记录到 redis 中,如果你没有使用 redis,也不影响,你可以参考着自己实现,比如存储到内存或数据库中。

想想这个需求,从第一性原理出发

用户可能没有登录,或者已经登录了。

如果用户登录了,我们就根据用户名来限制,否则,就根据IP或者其它设备码来限制,本文假设使用IP。

我们希望它足够简单,可以在多个方法上使用,而不需要编写额外的代码。所以我们要使用接口切面。

接口

 1public @interface RequestRateLimit {
 2
 3	/**
 4	 * 限流的key,比如限制用户注册,限制用户发送邮件,等等,一般是方法名
 5	 * @return
 6	 */
 7	String key() default "";
 8
 9	/**
10	 * 限流模式,默认单机
11	 * @return
12	 */
13	RateType type() default RateType.PER_CLIENT;
14
15	/**
16	 * 限流速率,1次/分钟
17	 * @return
18	 */
19	long rate() default 1;
20
21	/**
22	 * 限流速率,每分钟
23	 * @return
24	 */
25	long rateInterval() default 60 * 1000;
26
27	/**
28	 * 限流速率单位
29	 * @return
30	 */
31	RateIntervalUnit timeUnit() default RateIntervalUnit.MILLISECONDS;
32
33}

切面

你可以直接拷贝这些代码并测试。

 1public class RequestRateLimitAspect {
 2
 3	private RedissonClient redisson;
 4	private final UserService userService;
 5
 6	/**
 7	 * 根据自定义注解获取切点
 8	 *
 9	 * @param RequestRateLimit 注解接口
10	 */
11	@Pointcut("@annotation(RequestRateLimit)")
12	public void findAnnotationPointCut(RequestRateLimit RequestRateLimit) {
13	}
14
15	@Around(value = "findAnnotationPointCut(requestRateLimit)", argNames = "joinPoint,requestRateLimit")
16	public Object around(ProceedingJoinPoint joinPoint, RequestRateLimit requestRateLimit) throws Throwable {
17		UserEntity user = userService.getCurrentRequestUser(); // 只是封装了 SecurityContextHolder.getContext().getAuthentication().getPrincipal();
18		String realIp = "";
19		if (user == null) {
20			RequestAttributes ra = RequestContextHolder.getRequestAttributes();
21			ServletRequestAttributes sra = (ServletRequestAttributes) ra;
22			if (null != sra) {
23				HttpServletRequest request = sra.getRequest();
24				realIp = request.getHeader("His-Real-IP");
25				if (notValidIp(realIp)) {
26					realIp = request.getHeader("His-Real-IP2");
27					if (notValidIp(realIp)) {
28						realIp = request.getRemoteAddr();
29					}
30				}
31			}
32		}
33		if (user == null && notValidIp(realIp)) {
34			return R.failed(EMPTY_USER, "未找到您的任何登录信息");
35		}
36		// 限流拦截器
37		String key = user == null || StrUtil.isBlank(user.getUserName()) ? realIp : user.getUserName();
38		key = key + "::" + joinPoint.getSignature().getName();
39		RRateLimiter limiter = getRateLimiter(requestRateLimit, key);
40		if (limiter.tryAcquire(1)) {
41			return joinPoint.proceed();
42		} else {
43			log.info("rate-limit: {} {} {}", user == null ? "" : user.getUserName(), realIp, joinPoint.getSignature());
44			return R.failed(REACH_REQUEST_LIMIT, String.format("请求过于频繁,请于以下时间后重试:%s %s", requestRateLimit.rateInterval(), requestRateLimit.timeUnit().name().toLowerCase()));
45		}
46	}
47
48	private boolean notValidIp(String ip) {
49		return StrUtil.isBlank(ip) || ip.startsWith("172.1"); // docker bridge ip
50	}
51
52	/**
53	 * 获取限流拦截器
54	 *
55	 * @param limit  在要限流的方法上的配置
56	 * @param defaultKey 在redis中的存储的key
57	 * @return 限流器
58	 */
59	private RRateLimiter getRateLimiter(RequestRateLimit limit, String defaultKey) {
60		RRateLimiter rRateLimiter = redisson.getRateLimiter(StrUtil.isBlank(limit.key()) ? RATE_LIMITER + "::" + defaultKey : limit.key()); // RATE_LIMITER 随意起名,比如可以使用你的项目名称,只是为了在redis中好区分
61		// 设置限流
62		if (rRateLimiter.isExists()) {
63			RateLimiterConfig existed = rRateLimiter.getConfig();
64			// 判断配置是否更新,如果更新,重新加载限流器配置
65			if (!Objects.equals(limit.rate(), existed.getRate())
66					|| !Objects.equals(limit.timeUnit().toMillis(limit.rateInterval()), existed.getRateInterval())
67					|| !Objects.equals(limit.type(), existed.getRateType())) {
68				rRateLimiter.delete();
69				rRateLimiter.trySetRate(limit.type(), limit.rate(), limit.rateInterval(), limit.timeUnit());
70				expireByConfig(rRateLimiter, limit);
71			}
72		} else {
73			rRateLimiter.trySetRate(limit.type(), limit.rate(), limit.rateInterval(), limit.timeUnit());
74			expireByConfig(rRateLimiter, limit);
75		}
76
77		return rRateLimiter;
78	}
79
80	private void expireByConfig(RRateLimiter rRateLimiter, RequestRateLimit limit) {
81		// ttl 设置为 rateLimit 配置时间 + 5s
82		long limitDuration = limit.timeUnit().toMillis(limit.rateInterval()) + 5000;
83		// 设置过期时间,从现在算起 + 以上计算的时间。超时时间到后会删除一下几个key
84		// 1) "{rr_limiter::username}:value:***********"
85		// 2) "{rr_limiter::username}:permits:***********"
86		// 3) "rr_limiter::username"
87		rRateLimiter.expire(Instant.now().plusMillis(limitDuration));
88	}
89}

使用

1    @GetMapping("/info")
2	@RequestRateLimit(rate = 2, rateInterval = 1, timeUnit = RateIntervalUnit.MINUTES) // 1 分钟允许请求 2 次
3	public R getInfo() {
4        // ...
5    }

当用户请求 /info 接口的时候,redis 中就会存储一个 RATE_LIMITER::his_user_name::com.package.getInfo 这样的 key。当该用户在1分钟内请求该接口超过2次,那么他将会收到报错,并且 getInfo 方法并不会执行。

注意该注解无法作用于 @Cacheable 注释的方法上。

更多

我们可以实现一个自定义的频率限制,可以限制任意的方法,比如发送给运维人员的紧急邮件,如果同一主题发送过了,在5分钟内不要再次发送。

  1public @interface CustomRateLimit {
  2    /**
  3     * key 的前缀,用于一组相同功能限流的标记
  4     * @return
  5     */
  6    String prefix() default "";
  7
  8    /**
  9     * 限流的 key,要求不为空,支持从参数中读取
 10     * @return
 11     */
 12    String key() default "#key";
 13
 14    /**
 15     * 限流模式,默认单机
 16     * @return
 17     */
 18    RateType type() default RateType.PER_CLIENT;
 19
 20    /**
 21     * 限流速率,1次/分钟
 22     * @return
 23     */
 24    long rate() default 1;
 25
 26    /**
 27     * 限流速率,每分钟
 28     * @return
 29     */
 30    long rateInterval() default 60 * 1000;
 31
 32    /**
 33     * 限流速率单位
 34     * @return
 35     */
 36    RateIntervalUnit timeUnit() default RateIntervalUnit.MILLISECONDS;
 37
 38}
 39
 40public class CustomRateLimitAspect {
 41
 42	private final RedissonClient redisson;
 43	/**
 44	 * 根据自定义注解获取切点
 45	 *
 46	 * @param CustomRateLimit 注解接口
 47	 */
 48	@Pointcut("@annotation(CustomRateLimit)")
 49	public void findAnnotationPointCut(CustomRateLimit CustomRateLimit) {
 50	}
 51
 52	@Around(value = "findAnnotationPointCut(customRateLimit)", argNames = "joinPoint,customRateLimit")
 53	public Object around(ProceedingJoinPoint joinPoint, CustomRateLimit customRateLimit) throws Throwable {
 54		// 限流拦截器
 55		String key = getKey(joinPoint, customRateLimit);
 56		RRateLimiter limiter = getRateLimiter(customRateLimit, key);
 57		if (limiter.tryAcquire(1)) {
 58			return joinPoint.proceed();
 59		} else {
 60			log.info("skip method cause violate rate limit, key is {}", key);
 61			return R.failed(REACH_REQUEST_LIMIT, String.format("请求过于频繁,请于以下时间后重试:%s %s", customRateLimit.rateInterval(), customRateLimit.timeUnit().name().toLowerCase()));
 62		}
 63	}
 64
 65	/**
 66	 * 获取限流拦截器
 67	 *
 68	 * @param limit  在要限流的方法上的配置
 69	 * @return 限流器
 70	 */
 71	private RRateLimiter getRateLimiter(CustomRateLimit limit, String key) {
 72		RRateLimiter rRateLimiter = redisson.getRateLimiter(CUSTOM_RATE_LIMITER_PREFIX + "::" + limit.prefix() + "::" + key);
 73		// 设置限流
 74		if (rRateLimiter.isExists()) {
 75			RateLimiterConfig existed = rRateLimiter.getConfig();
 76			// 判断配置是否更新,如果更新,重新加载限流器配置
 77			if (!Objects.equals(limit.rate(), existed.getRate())
 78					|| !Objects.equals(limit.timeUnit().toMillis(limit.rateInterval()), existed.getRateInterval())
 79					|| !Objects.equals(limit.type(), existed.getRateType())) {
 80				rRateLimiter.delete();
 81				rRateLimiter.trySetRate(limit.type(), limit.rate(), limit.rateInterval(), limit.timeUnit());
 82				expireByConfig(rRateLimiter, limit);
 83			}
 84		} else {
 85			rRateLimiter.trySetRate(limit.type(), limit.rate(), limit.rateInterval(), limit.timeUnit());
 86			expireByConfig(rRateLimiter, limit);
 87		}
 88
 89		return rRateLimiter;
 90	}
 91
 92	private void expireByConfig(RRateLimiter rRateLimiter, CustomRateLimit limit) {
 93		long limitDuration = limit.timeUnit().toMillis(limit.rateInterval()) + 5000;
 94		rRateLimiter.expire(Instant.now().plusMillis(limitDuration));
 95	}
 96
 97	// el表达式支持
 98	private String getKey(JoinPoint joinPoint, CustomRateLimit customRateLimit) {
 99		ExpressionParser expressionParser = new SpelExpressionParser();
100		Expression expression = expressionParser.parseExpression(customRateLimit.key());
101		CodeSignature methodSignature = (CodeSignature) joinPoint.getSignature();
102		String[] sigParamNames = methodSignature.getParameterNames();
103		EvaluationContext context = new StandardEvaluationContext();
104		Object[] args = joinPoint.getArgs();
105		for (int i = 0; i < sigParamNames.length; i++) {
106			context.setVariable(sigParamNames[i], args[i]);
107		}
108		return (String) expression.getValue(context);
109	}
110}

使用

1	@Override
2	@CustomRateLimit(prefix = Constants.Cache.EMAIL_RATE_LIMITER, rateInterval = 5, timeUnit = RateIntervalUnit.MINUTES) // 5分钟最多一次
3	public void sendToMaintainersWithFrequencyLimit(String key, String subject, String... content) {
4		sendToMaintainers("[紧急通知]",  subject, content);
5	}