ThreadLocal之两则应用场景

ThreadLocal之两则应用场景

应用场景两则

  • 1、利用ThreadLocal管理登录用户信息实现随用随取
  • 2、可以记录Controller各个请求的执行时间

场景1spring boot实现步骤:

  • 1、新建实现WebMvcConfigurerAdapter的类,重写addInterceptors方法,registry添加实现了HandlerInterceptorAdapter类的拦截器,
    同时设置需要拦截与不需要拦截的路径正则表达式;
  • 2、在拦截器类中从session中取出用户信息,存入ThreadLocal中;
  • 3、引入spring-boot-starter-aop依赖,设置切入点(Pointcut)为controller中的各个方法,在AOP的后置最终通知(@After)中删除ThreadLocal中的信息
    注意:当然需要在登录时将登录用户信息存入session中

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
package com.yc.aop;

import com.yc.controller.UserController;
import com.yc.entity.UserEntity;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.*;
import org.springframework.stereotype.Component;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;

/**
* @author Yue Chang
* @version 1.0
* @className: UserEntityAspect
* @description: 用户信息设置切面类
* @date 2018年06月21日 19:27
*/
@Aspect
@Component
public class UserEntityAspect {

private static Logger logger = LoggerFactory.getLogger(UserEntityAspect.class);

private static ThreadLocal<UserEntity> threadLocal = new ThreadLocal<>();

@Pointcut("execution(public * com.yc.controller.*.*(..))")
public void userEntityPointcut(){}

@Before("userEntityPointcut()")
public void doBefore(JoinPoint joinPoint) throws Throwable {
// 接收到请求
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = attributes.getRequest();
// 从session里面获取对应的用户信息,设置到ThreadLocal中
Object object = request.getSession().getAttribute(UserController.USER_SESSION_KEY);
if (null != object && object instanceof UserEntity) {
UserEntity userEntity = (UserEntity) object;
threadLocal.set(userEntity);
}

}

@After(value = "userEntityPointcut()")
public void doAfter() throws Throwable {

// 访问结束,清除ThreadLocal中的值,避免产生OOM
threadLocal.remove();
}

/**
* 获取ThreadLocal中的用户信息
*/
public static UserEntity getUserEntity() {
return threadLocal.get();
}
}

场景2spring boot实现步骤:

  • 1、引入spring-boot-starter-aop依赖,设置切入点(Pointcut)为controller中的各个方法;
  • 2、在前通知中(@Before)中把调用开始时间设置到ThreadLocal中;
  • 3、在后置最终通知(@After)中获得调用时间信息并删除ThreadLocal中的信息

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package com.yc.aop;

import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Arrays;

/**
* @author Yue Chang
* @version 1.0
* @className: UserEntityAspect
* @description: 时间记录切面类
* @date 2018年06月21日 19:27
*/
@Aspect
@Component
public class VisitTimeAspect {

private static Logger logger = LoggerFactory.getLogger(VisitTimeAspect.class);

private static ThreadLocal<Long> threadLocal = new ThreadLocal<Long>(){
// 设置初始值
@Override
protected Long initialValue() {
return System.currentTimeMillis();
}
};

@Pointcut("execution(public * com.yc.controller.*.*(..))")
public void visitTimePointcut(){}

@Before("visitTimePointcut()")
public void doBefore() throws Throwable {
// 设置开始访问时间
threadLocal.set(System.currentTimeMillis());
}

@After(value = "visitTimePointcut()")
public void doAfter(JoinPoint joinPoint) throws Throwable {

try {
// 接收到请求,记录请求内容
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = attributes.getRequest();

// 获得请求开始系统时间
Long startTime = threadLocal.get();
// 计算访问时间
double second = (System.currentTimeMillis() - startTime) / 1000.0;

// 记录下请求内容
logger.info("URL : " + request.getRequestURL().toString() + ", cost : " + second + "s ");
logger.info("HTTP_METHOD : " + request.getMethod());
logger.info("IP : " + request.getRemoteAddr());
logger.info("CLASS_METHOD : " + joinPoint.getSignature().getDeclaringTypeName() + "." + joinPoint.getSignature().getName());
logger.info("ARGS : " + Arrays.toString(joinPoint.getArgs()));
} finally {
// 访问结束,清除ThreadLocal中的值,避免产生OOM
threadLocal.remove();
}
}
}