首先TTL(transmittable-thread-local)是阿里开源用于解决线程池ThreadLocal的框架,详细介绍可以到官网查看。github开源地址
我写这篇文章主要是为了介绍一个冷门的使用场景————如何兼容二方三方包已有的ThreadLocal的场景。
我有一个二方包,用户管理用户信息代码如下,如果单线程模式,完全没有任何问题,我将用户封装到ThreadLocal对象中,需要时通过get方法进行获取。
public class UserContextUtil {
private static final ThreadLocal<User> currentUser = new ThreadLocal<>();
public static User getCurrentUser() {
return currentUser.get();
}
public static void setCurrentUser(User user) {
currentUser.set(user);
}
}
然而有个需求,需要不停调用IO操作且耗时严重,于是我引入线程池来解决这个问题。
这个时候就会发现通过UserContextUtil.getCurrentUser()
获取到的用户为null,因为ThreadLocal只是本地线程池对象,这个时候我们需要将ThreadLocal的内容进行传递到新的线程中。
Future<Object> submit = pool.submit(() -> {UserContextUtil.getCurrentUser()});
如果你已经查过资料应该可以找到以下解决方案:
InheritableThreadLocal:jdk提供的多线程ThreadLocal解决方案,但不能解决线程池的问题
TransmittableThreadLocal:阿里提供的线程池解决方案,实际也是基InheritableThreadLocal实现。再此基础上提供了TtlRunnable和TtlCallable两种包装类,在TtlRunnable和TtlCallable中实际上对ThreadLocal参数进行了capture,replay 和restore三个阶段
所以根据以上方案我们得出只要将UserContextUtil 中new ThreadLocal 改为new TransmittableThreadLocal即可,并且调用线程池时:
public class UserContextUtil {
private static final ThreadLocal<User> currentUser = new TransmittableThreadLocal<>();
public static User getCurrentUser() {
return currentUser.get();
}
public static void setCurrentUser(User user) {
currentUser.set(user);
}
}
-------------------------------------------------
Future<Object> submit = pool.submit(TtlCallable.get(() -> {UserContextUtil.getCurrentUser()}));
但是这样问题很明显那就是必须改掉原有代码UserContextUtil,如果这时候UserContextUtil代码并不是我们自己在维护而是引入三方库,那么这种方法将不可取。
首先说一下思路:
1、我们在开启线程池的时候将所有ThreadLocal 注册到TTL中。
(1. 获取所有线程
2、重写线程池,并重写submit、execute等方法,使调用时可以自动将Callable转为TtlCallable对象(TtlRunnable同理)
3、封装线程池工具类,统一线程创建入口
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import java.lang.ref.WeakReference;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.Collectors;
/**
* @author huibin.wei
* @date 2022/1/14 2:41 下午
*/
public class TtlThreadPoolExecutor extends ThreadPoolExecutor {
private Field threadLocalsField;
private Field tableField;
public TtlThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
}
public List<ThreadLocal> getThreadLocal() {
Thread thread = Thread.currentThread();
try {
Object threadLocalMap = getThreadLocalsField().get(thread);
Field tableField = getTableField();
Object[] table = (Object[]) tableField.get(threadLocalMap);
List<ThreadLocal> collect = Arrays.stream(table)
.filter(o -> o != null)
.map(entry -> ((WeakReference<ThreadLocal>) entry).get())
.filter(o -> o != null)
.collect(Collectors.toList());
return collect;
} catch (NoSuchFieldException e) {
throw new PlatformException(CommonErrorEnum.SYSTEM_ERROR, e);
} catch (IllegalAccessException e) {
throw new PlatformException(CommonErrorEnum.SYSTEM_ERROR, e);
}
}
private Field getThreadLocalsField() throws NoSuchFieldException {
if (threadLocalsField == null) {
synchronized (this) {
if (threadLocalsField == null) {
threadLocalsField = Thread.class.getDeclaredField("threadLocals");
threadLocalsField.setAccessible(true);
}
}
}
return threadLocalsField;
}
private Field getTableField() throws NoSuchFieldException, IllegalAccessException {
if (tableField == null) {
synchronized (this) {
if (tableField == null) {
tableField = getThreadLocalsField().get(Thread.currentThread()).getClass().getDeclaredField("table");
tableField.setAccessible(true);
}
}
}
return tableField;
}
@Override
public void execute(Runnable command) {
List<ThreadLocal> local = getThreadLocal();
for (ThreadLocal threadLocal : local) {
TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
}
super.execute(TtlRunnable.get(command));
}
@Override
public Future<?> submit(Runnable task) {
List<ThreadLocal> local = getThreadLocal();
for (ThreadLocal threadLocal : local) {
TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
}
return super.submit(TtlRunnable.get(task));
}
@Override
public <T> Future<T> submit(Runnable task, T result) {
List<ThreadLocal> local = getThreadLocal();
for (ThreadLocal threadLocal : local) {
TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
}
return super.submit(TtlRunnable.get(task), result);
}
@Override
public <T> Future<T> submit(Callable<T> task) {
List<ThreadLocal> local = getThreadLocal();
for (ThreadLocal threadLocal : local) {
TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
}
return super.submit(TtlCallable.get(task));
}
}