懒惰篇,不想写字,直接贴代码
采用拦截器的方式拦截请求并计数。拦截规则是:在 X 时间段内收到超过 Y 次请求,就判定为攻击;请求又分为正常请求和异常请求,两者分别计数。
这里为了简化组件引入成本,采用内存存储的方式,如果有需要可以改为其他存储方式(如 Redis)。请求次数采用本地缓存(Guava)+ 自定义的滑动窗口(看前面的文章)来存储和计算。
@Slf4j
public class AttackInterceptor extends HandlerInterceptorAdapter {
private Limiter limiter;
private ConcurrentHashSet<String> blockIps;
@Value("#{new Boolean('${web.attackInterceptor.allowTimes.enable:false}')}")
private boolean allowTimesAttackInterceptor;
@Value("#{new Boolean('${web.attackInterceptor.errorAllowTimes.enable:true}')}")
private boolean errorAllowTimesAttackInterceptor;
@Value("${web.attackInterceptor.windowIntervalInMs:1000}")
private int windowIntervalInMs;
@Value("${web.attackInterceptor.allowTimes:20}")
private int allowTimes;
@Value("${web.attackInterceptor.errorAllowTimes:10}")
private int errorAllowTimes;
@Value("#{'${web.attackInterceptor.excludeFilterPath:}'.isEmpty() ? null : '${web.attackInterceptor.excludeFilterPath:}'.split(',')}")
private List<String> excludeFilterPath;
@PostConstruct
public void init() {
limiter = new Limiter(windowIntervalInMs);
blockIps = new ConcurrentHashSet<>();
if (excludeFilterPath == null) {
excludeFilterPath = new ArrayList<>();
}
excludeFilterPath.add("/removeAttackBlackIp");
excludeFilterPath.add("/addAttackBlackIp");
}
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
Object handler) {
String ip = HttpIpUtils.getRealIP(request);
try {
if (doLimited(request, response)) {
return false;
}
if (allowTimesAttackInterceptor && !limiter.get(ip).tryAcquire(AttackType.REQ, allowTimes)) {
log.error("[AttackInterceptor]{}每秒请求超过限制次数,将被限制请求!", ip);
blockIps.add(ip);
}
if (doLimited(request, response)) {
return false;
}
} catch (Exception e) {
log.error("AttackInterceptor preHandle [url:{}, ip:{}] error!", request.getRequestURL().toString(), ip, e);
}
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler,
Exception ex) {
String ip = HttpIpUtils.getRealIP(request);
try {
int responseStatus = response.getStatus();
if (!HttpStatus.valueOf(responseStatus).isError()) {
return;
}
if (errorAllowTimesAttackInterceptor && !limiter.get(ip)
.tryAcquire(AttackType.ERROR_REQ, errorAllowTimes - 1)) {
log.error("[AttackInterceptor]{}每秒返回异常的请求超过限制次数,将被限制请求!", ip);
blockIps.add(ip);
}
} catch (Exception e) {
log.error("AttackInterceptor afterCompletion [url:{}, ip:{}] error!", request.getRequestURL().toString(),
ip, e);
}
}
/**
* @param request
* @param response
* @return boolean
* @author
* @date
*/
private boolean doLimited(HttpServletRequest request, HttpServletResponse response) {
if (!allowTimesAttackInterceptor && !errorAllowTimesAttackInterceptor) {
return false;
}
String uri = request.getRequestURI();
if (excludeFilterPath.contains(uri)) {
return false;
}
String ip = HttpIpUtils.getRealIP(request);
if (blockIps.contains(ip)) {
writeResult(response);
log.warn("[url:{}, ip:{}] is blocked by AttackInterceptor!", request.getRequestURL().toString(), ip);
return true;
}
return false;
}
/**
* @param response
* @author
* @date
*/
protected void writeResult(HttpServletResponse response) {
PrintWriter writer = null;
response.setStatus(HttpStatus.FORBIDDEN.value());
response.setCharacterEncoding("UTF-8");
response.setContentType("application/json;charset=UTF-8");
try {
writer = response.getWriter();
writer.print(JSON.toJSONString(ResponseResult.fail("blocked!")));
writer.flush();
} catch (Exception e) {
log.error("AttackInterceptor write blocked error!", e);
} finally {
if (writer != null) {
writer.close();
}
}
}
/**
* @param ip
* @return boolean
* @author
* @date
*/
public boolean removeBlockIp(String ip) {
return blockIps.remove(ip);
}
/**
* @param ip
* @return boolean
* @author
* @date
*/
public boolean addBlockIp(String ip) {
if (blockIps.contains(ip)) {
return false;
}
return blockIps.add(ip);
}
enum AttackType {
REQ,
ERROR_REQ
}
/**
* @author
* @return null
* @date
*/
class Limiter extends AbstractLocalCache<String, SlidingWindow> {
public static final int DEFAULT_CACHE_MAX_SIZE = 1000;
public static final int DEFAULT_CACHE_EXPIRE = 360;
private int windowIntervalInMs;
public Limiter(int windowIntervalInMs) {
super(CacheBuilder.newBuilder().maximumSize(DEFAULT_CACHE_MAX_SIZE)
.expireAfterAccess(DEFAULT_CACHE_EXPIRE, TimeUnit.SECONDS));
this.windowIntervalInMs = windowIntervalInMs;
}
@Override
protected SlidingWindow loadData(String key) {
return new SlidingWindow.SlidingWindowBuilder()
.ofEventTypeClass(AttackType.class)
.ofBucketCount(10)
.ofWindowIntervalInMs(windowIntervalInMs)
.build();
}
}
}
增加手动添加/删除IP黑名单的接口。为了安全起见,添加是否允许手动操作的开关,常规时间建议关闭开关
/**
 * Endpoints for manually adding or removing an IP on the {@link AttackInterceptor}
 * blacklist. Both are disabled unless {@code web.attackInterceptor.canOpt} is
 * {@code true}.
 */
@RestController
@RequestMapping("")
public class WebController {

    /** Whether manual blacklist operations are allowed (off by default). */
    @Value("${web.attackInterceptor.canOpt:false}")
    private boolean classOptAttackBlackIp;

    @Autowired
    private AttackInterceptor attackInterceptor;

    /**
     * Removes an IP from the blacklist.
     *
     * @param ip IP to remove
     * @return a human-readable message describing the outcome
     */
    @RequestMapping(value = "/removeAttackBlackIp",
            method = {RequestMethod.GET, RequestMethod.POST})
    public String removeAttackBlackIp(String ip) {
        String rejection = rejectIfNotAllowed(ip);
        if (rejection != null) {
            return rejection;
        }
        if (attackInterceptor.removeBlockIp(ip)) {
            return ip + "已经从[AttackInterceptor]黑名单中移除!";
        }
        return ip + "不在[AttackInterceptor]黑名单中!";
    }

    /**
     * Adds an IP to the blacklist.
     *
     * @param ip IP to add
     * @return a human-readable message describing the outcome
     */
    @RequestMapping(value = "/addAttackBlackIp",
            method = {RequestMethod.GET, RequestMethod.POST})
    public String addAttackBlackIp(String ip) {
        String rejection = rejectIfNotAllowed(ip);
        if (rejection != null) {
            return rejection;
        }
        if (attackInterceptor.addBlockIp(ip)) {
            return ip + "已添加到[AttackInterceptor]黑名单中!";
        }
        return ip + "已经存在于[AttackInterceptor]黑名单中!";
    }

    /**
     * Shared precondition check for both endpoints.
     *
     * @param ip the request's {@code ip} parameter, {@code null} when it was not supplied
     * @return the message to send back when the operation must not proceed, or
     *         {@code null} when it may
     */
    private String rejectIfNotAllowed(String ip) {
        if (!classOptAttackBlackIp) {
            return "禁止手动操作";
        }
        // The parameter is optional in the mapping, so it can be absent; a null or
        // blank value must never reach the blacklist.
        if (ip == null || ip.trim().isEmpty()) {
            return "ip不能为空";
        }
        // NOTE(review): ip is echoed back in the response body unchanged; consider
        // validating its format if this endpoint can be reached from a browser.
        return null;
    }
}
配置Bean
/**
 * Web MVC configuration for servlet applications: exposes the {@link AttackInterceptor}
 * bean and registers it, together with a {@code LogInterceptor}, on the interceptor
 * chain.
 */
@Configuration
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET)
@ConditionalOnClass({Servlet.class, DispatcherServlet.class, WebMvcConfigurer.class})
public class WebConfiguration implements WebMvcConfigurer {

    /**
     * Declares the attack interceptor as a bean.
     *
     * @return a new {@link AttackInterceptor}
     */
    @Bean
    public AttackInterceptor getAttackInterceptor() {
        return new AttackInterceptor();
    }

    /**
     * Registers the interceptors. The log interceptor skips the listed error, static
     * resource and Swagger paths; the attack interceptor is registered without any
     * path exclusions.
     *
     * @param registry registry to add the interceptors to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new LogInterceptor())
                .excludePathPatterns(
                        "/error", "/druid/*", "/webjars/**", "/*.js", "/*.html",
                        "/*.img", "/swagger-resources/**", "/webjars/**", "/v2/**",
                        "/swagger-ui.html/**", "/csrf");

        // Goes through the @Bean method rather than calling the constructor directly.
        AttackInterceptor attackInterceptor = getAttackInterceptor();
        registry.addInterceptor(attackInterceptor);
    }
}
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。