使用TTL实现ThreadLocal的copy——兼容二方三方包已有的ThreadLocal的场景

习胤运
2023-12-01

说明

首先TTL(transmittable-thread-local)是阿里开源用于解决线程池ThreadLocal的框架,详细介绍可以到官网查看。github开源地址
我写这篇文章主要是为了介绍一个冷门的使用场景————如何兼容二方三方包已有的ThreadLocal的场景。

案例:

我有一个二方包,用户管理用户信息代码如下,如果单线程模式,完全没有任何问题,我将用户封装到ThreadLocal对象中,需要时通过get方法进行获取。

public class UserContextUtil {
    private static final ThreadLocal<User> currentUser = new ThreadLocal<>();
    public static User getCurrentUser() {
        return currentUser.get();
    }
    public static void setCurrentUser(User user) {
        currentUser.set(user);
    }
}

存在问题

然而有个需求,需要不停调用IO操作且耗时严重,于是我引入线程池来解决这个问题。
这个时候就会发现通过UserContextUtil.getCurrentUser()获取到的用户为null,因为ThreadLocal只是本地线程池对象,这个时候我们需要将ThreadLocal的内容进行传递到新的线程中。

Future<Object> submit = pool.submit(() -> {UserContextUtil.getCurrentUser()});

常规方案

方案汇总

如果你已经查过资料应该可以找到以下解决方案
InheritableThreadLocal:jdk提供的多线程ThreadLocal解决方案,但不能解决线程池的问题
TransmittableThreadLocal:阿里提供的线程池解决方案,实际也是基InheritableThreadLocal实现。再此基础上提供了TtlRunnable和TtlCallable两种包装类,在TtlRunnable和TtlCallable中实际上对ThreadLocal参数进行了capture,replay 和restore三个阶段

  1. capture方法:抓取线程(线程A)的所有TTL值。
  2. replay方法:在另一个线程(线程B)中,回放在capture方法中抓取的TTL值,并返回 回放前TTL值的备份
  3. restore方法:恢复线程B执行replay方法之前的TTL值(即备份)

方案详情

所以根据以上方案我们得出只要将UserContextUtil 中new ThreadLocal 改为new TransmittableThreadLocal即可,并且调用线程池时:

public class UserContextUtil {
    private static final ThreadLocal<User> currentUser = new TransmittableThreadLocal<>();
    public static User getCurrentUser() {
        return currentUser.get();
    }
    public static void setCurrentUser(User user) {
        currentUser.set(user);
    }
}
-------------------------------------------------
Future<Object> submit = pool.submit(TtlCallable.get(() -> {UserContextUtil.getCurrentUser()}));

但是这样问题很明显那就是必须改掉原有代码UserContextUtil,如果这时候UserContextUtil代码并不是我们自己在维护而是引入三方库,那么这种方法将不可取。

二方三方包解决方案

首先说一下思路:
1、我们在开启线程池的时候将所有ThreadLocal 注册到TTL中。
(1. 获取所有线程
2、重写线程池,并重写submit、execute等方法,使调用时可以自动将Callable转为TtlCallable对象(TtlRunnable同理)
3、封装线程池工具类,统一线程创建入口

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import java.lang.ref.WeakReference;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.Collectors;

/**
 * @author huibin.wei
 * @date 2022/1/14 2:41 下午
 */
public class TtlThreadPoolExecutor extends ThreadPoolExecutor {
    private Field threadLocalsField;
    private Field tableField;

    public TtlThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
    }


    public List<ThreadLocal> getThreadLocal() {
        Thread thread = Thread.currentThread();
        try {
            Object threadLocalMap = getThreadLocalsField().get(thread);
            Field tableField = getTableField();
            Object[] table = (Object[]) tableField.get(threadLocalMap);
            List<ThreadLocal> collect = Arrays.stream(table)
                    .filter(o -> o != null)
                    .map(entry -> ((WeakReference<ThreadLocal>) entry).get())
                    .filter(o -> o != null)
                    .collect(Collectors.toList());
            return collect;
        } catch (NoSuchFieldException e) {
            throw new PlatformException(CommonErrorEnum.SYSTEM_ERROR, e);
        } catch (IllegalAccessException e) {
            throw new PlatformException(CommonErrorEnum.SYSTEM_ERROR, e);
        }
    }

    private Field getThreadLocalsField() throws NoSuchFieldException {
        if (threadLocalsField == null) {
            synchronized (this) {
                if (threadLocalsField == null) {
                    threadLocalsField = Thread.class.getDeclaredField("threadLocals");
                    threadLocalsField.setAccessible(true);
                }
            }
        }
        return threadLocalsField;
    }

    private Field getTableField() throws NoSuchFieldException, IllegalAccessException {
        if (tableField == null) {
            synchronized (this) {
                if (tableField == null) {
                    tableField = getThreadLocalsField().get(Thread.currentThread()).getClass().getDeclaredField("table");
                    tableField.setAccessible(true);
                }
            }
        }

        return tableField;
    }

    @Override
    public void execute(Runnable command) {
        List<ThreadLocal> local = getThreadLocal();

        for (ThreadLocal threadLocal : local) {
            TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
        }
        super.execute(TtlRunnable.get(command));

    }

    @Override
    public Future<?> submit(Runnable task) {
        List<ThreadLocal> local = getThreadLocal();

        for (ThreadLocal threadLocal : local) {
            TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
        }
        return super.submit(TtlRunnable.get(task));

    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        List<ThreadLocal> local = getThreadLocal();

        for (ThreadLocal threadLocal : local) {
            TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
        }
        return super.submit(TtlRunnable.get(task), result);

    }
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        List<ThreadLocal> local = getThreadLocal();

        for (ThreadLocal threadLocal : local) {
            TransmittableThreadLocal.Transmitter.registerThreadLocalWithShadowCopier(threadLocal);
        }
        return super.submit(TtlCallable.get(task));

    }
}

 类似资料: