
Source Code Analysis of awaitility, an Asynchronous Operation Verification Tool

国胤
2023-12-01

1. Background

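An earlier article, "Quick Start with awaitility, an Asynchronous Verification Tool", introduced awaitility. The tool neatly solves the problem of verifying asynchronous operations and wraps many convenient usages; its core idea is to poll for a result within a given time limit. This article walks through the source code to show how the tool works internally. Understanding the design ideas behind it helps when solving similar problems at work.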

2. Core Source Code Walkthrough

2-1. Example

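    import static java.util.concurrent.TimeUnit.MILLISECONDS;
    import static java.util.concurrent.TimeUnit.SECONDS;
    import static org.awaitility.Awaitility.with;
    // Awaitility 3.x, matching the source analysed below (4.x moved this constant to org.awaitility.Durations)
    import static org.awaitility.Duration.ONE_HUNDRED_MILLISECONDS;

    import java.util.concurrent.Callable;

    import org.junit.Test;

    public class AwaitilityTest {

        interface CounterService {
            void run();

            int getCount();
        }

        // Asynchronous task: increments count by 1 every second, 5 times in total
        class CounterServiceImpl implements CounterService {
            private volatile int count = 0;

            public void run() {
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            for (int index = 0; index < 5; index++) {
                                Thread.sleep(1000);
                                count += 1; // only this thread writes count, so volatile is enough
                            }
                        } catch (InterruptedException e) {
                            throw new RuntimeException(e);
                        }
                    }
                }).start();
            }

            public int getCount() {
                return count;
            }
        }

        @Test
        public void testAsynchronousPoll() {
            final CounterService service = new CounterServiceImpl();
            service.run();

            // Poll every 100 ms (pollInterval), starting after a 50 ms delay (pollDelay), for at most 6 s.
            // count never exceeds 5, so this condition is never met and the wait ends with a
            // ConditionTimeoutException (the error message quoted at the end of this article).
            with().atMost(6, SECONDS).and().pollInterval(ONE_HUNDRED_MILLISECONDS).and().pollDelay(50, MILLISECONDS).await("count is greater 6").until(
                    new Callable<Boolean>() {
                        @Override
                        public Boolean call() throws Exception {
                            return service.getCount() == 6;
                        }
                    });
        }
    }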

2-2. Core Source Code Analysis

The calls used in the example:

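// Set the timeout: 6 s
atMost(6, SECONDS)

// Set the poll interval: 100 ms
pollInterval(ONE_HUNDRED_MILLISECONDS)

// Set the poll delay (wait before the first evaluation): 50 ms
pollDelay(50, MILLISECONDS)

// Set the alias, which appears in the timeout message
await("count is greater 6")

// Connector with no effect other than readability
and()

// Wait until the java.util.concurrent.Callable returns true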
until(
    new Callable<Boolean>() {
    @Override
    public Boolean call() throws Exception {
        return service.getCount() == 6;
    }
});

ConditionFactory lives in the package:

package org.awaitility.core

ConditionFactory is a factory for Condition objects. Awaitility.with() creates a new ConditionFactory instance from the default settings:

public static ConditionFactory with() {
        return new ConditionFactory(defaultWaitConstraint, defaultPollInterval, defaultPollDelay,
                defaultCatchUncaughtExceptions, defaultExceptionIgnorer, defaultConditionEvaluationListener);
}

This ConditionFactory constructor delegates to the full constructor with a null alias:

/**
     * Instantiates a new condition factory.
     *
     * @param timeoutConstraint       the timeout
     * @param pollInterval            the poll interval
     * @param pollDelay               The delay before the polling starts
     * @param exceptionsIgnorer       the ignore exceptions
     * @param catchUncaughtExceptions the catch uncaught exceptions
     */
    public ConditionFactory(WaitConstraint timeoutConstraint, PollInterval pollInterval, Duration pollDelay,
                            boolean catchUncaughtExceptions, ExceptionIgnorer exceptionsIgnorer,
                            ConditionEvaluationListener conditionEvaluationListener) {
        this(null, timeoutConstraint, pollInterval, pollDelay, catchUncaughtExceptions, exceptionsIgnorer,
                conditionEvaluationListener);
    }

The default values passed to it are static fields of the Awaitility class:

/**
     * The default poll interval (fixed 100 ms).
     */
    private static volatile PollInterval defaultPollInterval = DEFAULT_POLL_INTERVAL;

    /**
     * The default wait constraint (10 seconds).
     */
    private static volatile WaitConstraint defaultWaitConstraint = AtMostWaitConstraint.TEN_SECONDS;

    /**
     * The default poll delay
     */
    private static volatile Duration defaultPollDelay = DEFAULT_POLL_DELAY;

    /**
     * Catch all uncaught exceptions by default?
     */
    private static volatile boolean defaultCatchUncaughtExceptions = true;

    /**
     * Ignore caught exceptions by default?
     */
    private static volatile ExceptionIgnorer defaultExceptionIgnorer = new PredicateExceptionIgnorer(new Predicate<Exception>() {
        public boolean matches(Exception e) {
            return false;
        }
    });

    /**
     * Default listener of condition evaluation results.
     */
    private static volatile ConditionEvaluationListener defaultConditionEvaluationListener = null;

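So by default the timeout is 10 s and the poll interval is a fixed 100 ms; uncaught exceptions are caught, and no caught exception is ignored.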

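ConditionFactory's instance methods "replace" these values by building a new instance. For example, atMost(6, SECONDS) wraps its arguments in a Duration and ends up in atMost(Duration):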

public ConditionFactory atMost(Duration timeout) {
        return new ConditionFactory(alias, timeoutConstraint.withMaxWaitTime(timeout), pollInterval, pollDelay,
                catchUncaughtExceptions, exceptionsIgnorer, conditionEvaluationListener);
}

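Nothing is mutated here: a new ConditionFactory is returned whose final field timeoutConstraint is the previous constraint with its maximum wait time set to the user's new Duration(6, SECONDS); every other field is copied unchanged.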

and() simply returns this, so it only makes the chain read more naturally.

When pollInterval(ONE_HUNDRED_MILLISECONDS) is then called:

public ConditionFactory pollInterval(Duration pollInterval) {
        return new ConditionFactory(alias, timeoutConstraint, pollInterval, pollDelay, catchUncaughtExceptions,
                exceptionsIgnorer, conditionEvaluationListener);
    }

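Again a new object is created: ONE_HUNDRED_MILLISECONDS is assigned to the final field pollInterval, while previously set values are carried over; for example, timeoutConstraint still holds the new Duration(6, SECONDS) set in the previous step. In the same way, pollDelay(50, MILLISECONDS) assigns new Duration(50, MILLISECONDS) to the final field pollDelay, and await("count is greater 6") assigns "count is greater 6" to the final field alias.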

In the resulting ConditionFactory instance, the fields are:

timeoutConstraint: maximum wait time 6 s
pollInterval: 100 ms
pollDelay: 50 ms
alias: "count is greater 6"
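The copy-on-write chaining described above can be sketched with JDK types only. MiniConditionFactory below is a hypothetical stand-in, not Awaitility's class: it uses java.time.Duration, keeps just four fields, and assumes a zero default poll delay. Each method returns a fresh instance with one field changed, and and() returns this.

```java
import java.time.Duration;

// Hypothetical, simplified stand-in for ConditionFactory: an immutable fluent builder.
final class MiniConditionFactory {
    final String alias;
    final Duration atMost;
    final Duration pollInterval;
    final Duration pollDelay;

    MiniConditionFactory(String alias, Duration atMost, Duration pollInterval, Duration pollDelay) {
        this.alias = alias;
        this.atMost = atMost;
        this.pollInterval = pollInterval;
        this.pollDelay = pollDelay;
    }

    // Defaults as described above: 10 s timeout, 100 ms interval (zero poll delay is an assumption here)
    static MiniConditionFactory with() {
        return new MiniConditionFactory(null, Duration.ofSeconds(10), Duration.ofMillis(100), Duration.ZERO);
    }

    // Each "setter" returns a new instance; the receiver is never modified.
    MiniConditionFactory atMost(Duration timeout) {
        return new MiniConditionFactory(alias, timeout, pollInterval, pollDelay);
    }

    MiniConditionFactory pollInterval(Duration interval) {
        return new MiniConditionFactory(alias, atMost, interval, pollDelay);
    }

    MiniConditionFactory pollDelay(Duration delay) {
        return new MiniConditionFactory(alias, atMost, pollInterval, delay);
    }

    MiniConditionFactory await(String newAlias) {
        return new MiniConditionFactory(newAlias, atMost, pollInterval, pollDelay);
    }

    // Pure syntactic sugar, like ConditionFactory.and()
    MiniConditionFactory and() {
        return this;
    }
}
```

Because the original factory is never modified, a partially configured factory can safely serve as a template for several waits, which is plausibly one motivation for this design.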

The until() method of ConditionFactory:

public void until(Callable<Boolean> conditionEvaluator) {
        until(new CallableCondition(conditionEvaluator, generateConditionSettings()));
}

private <T> T until(Condition<T> condition) {
        return condition.await();
    }

generateConditionSettings() copies the final fields of ConditionFactory into a plain JavaBean, ConditionSettings:

new ConditionSettings(alias, catchUncaughtExceptions, timeoutConstraint, pollInterval, actualPollDelay,
                conditionEvaluationListener, exceptionsIgnorer);

until() then instantiates CallableCondition, whose constructor is:

public CallableCondition(final Callable<Boolean> matcher, ConditionSettings settings) {
        conditionEvaluationHandler = new ConditionEvaluationHandler<Object>(null, settings);
        ConditionEvaluationWrapper conditionEvaluationWrapper = new ConditionEvaluationWrapper(matcher, settings, conditionEvaluationHandler);
        conditionAwaiter = new ConditionAwaiter(conditionEvaluationWrapper, settings) {
            @SuppressWarnings("rawtypes")
            @Override
            protected String getTimeoutMessage() {
                if (timeout_message != null) {
                    return timeout_message;
                }
                final String timeoutMessage;
                if (matcher == null) {
                    timeoutMessage = "";
                } else {
                    final Class<? extends Callable> type = matcher.getClass();
                    final Method enclosingMethod = type.getEnclosingMethod();
                    if (type.isAnonymousClass() && enclosingMethod != null) {
                        timeoutMessage = String.format("Condition returned by method \"%s\" in class %s was not fulfilled",
                                enclosingMethod.getName(), enclosingMethod.getDeclaringClass().getName());
                    } else {
                        final String message;
                        if (isLambdaClass(type)) {
                            message = "with " + generateLambdaErrorMessagePrefix(type, false);
                        } else {
                            message = type.getName();
                        }
                        timeoutMessage = String.format("Condition %s was not fulfilled", message);
                    }
                }
                return timeoutMessage;
            }
        };
    }
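The anonymous-class branch of getTimeoutMessage() relies on plain JDK reflection: Class.isAnonymousClass() and Class.getEnclosingMethod(). The sketch below (class and method names are hypothetical) builds the message the same way, which is how the final error message can name the test method testAsynchronousPoll.

```java
import java.lang.reflect.Method;
import java.util.concurrent.Callable;

final class EnclosingMethodDemo {
    // Hypothetical stand-in for a test method that declares an anonymous Callable.
    static Callable<Boolean> makeCondition() {
        return new Callable<Boolean>() {
            @Override
            public Boolean call() {
                return false;
            }
        };
    }

    // Builds a message like the anonymous-class branch of getTimeoutMessage() does.
    static String describe(Callable<Boolean> matcher) {
        Class<?> type = matcher.getClass();
        Method enclosingMethod = type.getEnclosingMethod();
        if (type.isAnonymousClass() && enclosingMethod != null) {
            return String.format("Condition returned by method \"%s\" in class %s was not fulfilled",
                    enclosingMethod.getName(), enclosingMethod.getDeclaringClass().getName());
        }
        return String.format("Condition %s was not fulfilled", type.getName());
    }
}
```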

The constructor also creates a ConditionAwaiter (as an anonymous subclass that overrides getTimeoutMessage()). The ConditionAwaiter constructor:

 public ConditionAwaiter(final ConditionEvaluator conditionEvaluator,
                            final ConditionSettings conditionSettings) {
        if (conditionEvaluator == null) {
            throw new IllegalArgumentException("You must specify a condition (was null).");
        }
        if (conditionSettings == null) {
            throw new IllegalArgumentException("You must specify the condition settings (was null).");
        }
        if (conditionSettings.shouldCatchUncaughtExceptions()) {
            Thread.setDefaultUncaughtExceptionHandler(this);
        }
        this.conditionSettings = conditionSettings;
        this.latch = new CountDownLatch(1);
        this.conditionEvaluator = conditionEvaluator;
        this.executor = initExecutorService();
    }
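When catchUncaughtExceptions is enabled (the default), the constructor installs the ConditionAwaiter as the JVM-wide default UncaughtExceptionHandler, so exceptions escaping other threads reach it. A JDK-only sketch of that mechanism, with a hypothetical handler that simply records the exception (UncaughtHandlerDemo is not part of Awaitility):

```java
import java.util.concurrent.atomic.AtomicReference;

final class UncaughtHandlerDemo {
    // Runs a worker thread that throws, and returns what the default handler captured.
    static Throwable captureFromWorker() {
        Thread.UncaughtExceptionHandler previous = Thread.getDefaultUncaughtExceptionHandler();
        AtomicReference<Throwable> captured = new AtomicReference<>();
        // Like ConditionAwaiter, register a JVM-wide default handler.
        Thread.setDefaultUncaughtExceptionHandler((thread, e) -> captured.set(e));
        try {
            Thread worker = new Thread(() -> {
                throw new IllegalStateException("boom in worker");
            });
            worker.start();
            worker.join(); // the handler runs on the worker before it terminates
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            Thread.setDefaultUncaughtExceptionHandler(previous); // restore global state
        }
        return captured.get();
    }
}
```

The default handler is global state for the whole JVM; the sketch restores the previous handler afterwards, which the constructor shown above does not do.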

until(Condition) then calls the await() method of the CallableCondition instance:

 public Void await() {
        conditionAwaiter.await(conditionEvaluationHandler);
        return null;
    }

which delegates to ConditionAwaiter.await():

public <T> void await(final ConditionEvaluationHandler<T> conditionEvaluationHandler) {
        final Duration pollDelay = conditionSettings.getPollDelay();
        final Duration maxWaitTime = conditionSettings.getMaxWaitTime();
        final Duration minWaitTime = conditionSettings.getMinWaitTime();

        final long maxTimeout = maxWaitTime.getValue();
        final TimeUnit maxTimeoutUnit = maxWaitTime.getTimeUnit();

        long pollingStarted = System.currentTimeMillis() - pollDelay.getValueInMS();
        pollSchedulingThread(conditionEvaluationHandler, pollDelay, maxWaitTime).start();

        try {
            try {
                final boolean finishedBeforeTimeout;
                if (maxWaitTime == Duration.FOREVER) {
                    latch.await();
                    finishedBeforeTimeout = true;
                } else {
                    finishedBeforeTimeout = latch.await(maxTimeout, maxTimeoutUnit);
                }

                Duration evaluationDuration =
                        new Duration(System.currentTimeMillis() - pollingStarted, TimeUnit.MILLISECONDS)
                                .minus(pollDelay);

                if (throwable != null) {
                    throw throwable;
                } else if (!finishedBeforeTimeout) {
                    final String maxWaitTimeLowerCase = maxWaitTime.getTimeUnitAsString();
                    final String message;
                    if (conditionSettings.hasAlias()) {
                        message = String.format("Condition with alias '%s' didn't complete within %s %s because %s.",
                                conditionSettings.getAlias(), maxTimeout, maxWaitTimeLowerCase, Introspector.decapitalize(getTimeoutMessage()));
                    } else {
                        message = String.format("%s within %s %s.", getTimeoutMessage(), maxTimeout, maxWaitTimeLowerCase);
                    }

                    final ConditionTimeoutException e;

                    // Not all systems support deadlock detection so ignore if ThreadMXBean & ManagementFactory is not in classpath
                    if (existInCP("java.lang.management.ThreadMXBean") && existInCP("java.lang.management.ManagementFactory")) {
                        java.lang.management.ThreadMXBean bean = java.lang.management.ManagementFactory.getThreadMXBean();
                        Throwable cause = this.cause;
                        try {
                            long[] threadIds = bean.findDeadlockedThreads();
                            if (threadIds != null) {
                                cause = new DeadlockException(threadIds);
                            }
                        } catch (UnsupportedOperationException ignored) {
                            // findDeadLockedThreads() not supported on this VM,
                            // don't init cause and move on.
                        }
                        e = new ConditionTimeoutException(message, cause);
                    } else {
                        e = new ConditionTimeoutException(message, this.cause);
                    }

                    throw e;
                } else if (evaluationDuration.compareTo(minWaitTime) < 0) {
                    String message = String.format("Condition was evaluated in %s %s which is earlier than expected " +
                                    "minimum timeout %s %s", evaluationDuration.getValue(), evaluationDuration.getTimeUnit(),
                            minWaitTime.getValue(), minWaitTime.getTimeUnit());
                    throw new ConditionTimeoutException(message);
                }
            } finally {
                executor.shutdown();
                if (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
                    try {
                        executor.shutdownNow();
                        executor.awaitTermination(1, TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        CheckedExceptionRethrower.safeRethrow(e);
                    }
                }
            }
        } catch (Throwable e) {
            CheckedExceptionRethrower.safeRethrow(e);
        }
    }
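On timeout, the deadlock check uses the standard ThreadMXBean API: findDeadlockedThreads() returns the IDs of deadlocked threads, or null if there are none, and may throw UnsupportedOperationException on VMs without that support. A minimal sketch reproducing only that check (DeadlockCheck is a hypothetical name):

```java
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;

final class DeadlockCheck {
    // Returns the IDs of deadlocked threads, or an empty array if none (or if unsupported).
    static long[] deadlockedThreadIds() {
        ThreadMXBean bean = ManagementFactory.getThreadMXBean();
        try {
            long[] ids = bean.findDeadlockedThreads();
            return ids == null ? new long[0] : ids;
        } catch (UnsupportedOperationException ignored) {
            // Same fallback as ConditionAwaiter: skip detection on VMs that do not support it
            return new long[0];
        }
    }
}
```

In ConditionAwaiter a non-null result is wrapped in a DeadlockException and attached as the cause of the ConditionTimeoutException, so a timeout caused by a deadlock is reported as such.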

ConditionAwaiter has a CountDownLatch field:

private final CountDownLatch latch;

which the constructor initializes as:

this.latch = new CountDownLatch(1);

From the CountDownLatch Javadoc:

A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

A CountDownLatch is initialized with a given count. The await methods block until the current count reaches zero due to invocations of the countDown() method, after which all waiting threads are released and any subsequent invocations of await return immediately.

For details, see: <https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/CountDownLatch.html>
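A small demonstration of the two latch behaviours that await() relies on, using only java.util.concurrent (LatchDemo is hypothetical): a worker that counts down in time makes await(timeout, unit) return true, and a worker that is too slow makes it return false.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

final class LatchDemo {
    // Returns what latch.await(timeout, unit) reports when a worker counts down after workerDelayMs.
    static boolean awaitWorker(long workerDelayMs, long timeoutMs) {
        CountDownLatch latch = new CountDownLatch(1); // same initial count as ConditionAwaiter
        Thread worker = new Thread(() -> {
            try {
                Thread.sleep(workerDelayMs);
                latch.countDown(); // "condition fulfilled"
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        worker.setDaemon(true);
        worker.start();
        try {
            // true if the count reached zero in time, false if the timeout elapsed first
            return latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }
}
```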

ConditionAwaiter.await() creates and starts a polling (scheduling) thread:

 pollSchedulingThread(conditionEvaluationHandler, pollDelay, maxWaitTime).start();
 private <T> Thread pollSchedulingThread(final ConditionEvaluationHandler<T> conditionEvaluationHandler,
                                            final Duration pollDelay, final Duration maxWaitTime) {
        final long maxTimeout = maxWaitTime.getValue();
        final TimeUnit maxTimeoutUnit = maxWaitTime.getTimeUnit();

        return new Thread(new Runnable() {
            public void run() {
                int pollCount = 0;
                try {
                    conditionEvaluationHandler.start();
                    if (!pollDelay.isZero()) {
                        Thread.sleep(pollDelay.getValueInMS());
                    }
                    Duration pollInterval = pollDelay;

                    while (!executor.isShutdown()) {
                        if (conditionCompleted()) {
                            break;
                        }
                        pollCount = pollCount + 1;
                        Future<?> future = executor.submit(new ConditionPoller(pollInterval));
                        if (maxWaitTime == Duration.FOREVER) {
                            future.get();
                        } else {
                            future.get(maxTimeout, maxTimeoutUnit);
                        }
                        pollInterval = conditionSettings.getPollInterval().next(pollCount, pollInterval);
                        Thread.sleep(pollInterval.getValueInMS());
                    }
                } catch (Throwable e) {
                    throwable = e;
                }
            }
        }, "awaitility-poll-scheduling");
    }
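The pacing of this loop can be modelled without threads. In the sketch below, PollSchedule is hypothetical and a BiFunction stands in for PollInterval.next(pollCount, previous). It mirrors the loop: sleep for pollDelay, seed the interval with pollDelay, then after each submitted evaluation ask for the next interval and sleep. Ignoring evaluation time, the example's settings (50 ms delay, fixed 100 ms interval) start evaluations at about 50, 150 and 250 ms. The doubling strategy is invented purely for comparison; Awaitility's own non-fixed strategies (such as FibonacciPollInterval) behave differently.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

final class PollSchedule {
    // Approximate start times (ms) of the first `polls` evaluations, ignoring evaluation cost.
    // next.apply(pollCount, previousMs) plays the role of PollInterval.next(pollCount, previous).
    static List<Long> startTimes(long pollDelayMs, int polls, BiFunction<Integer, Long, Long> next) {
        List<Long> times = new ArrayList<>();
        long now = pollDelayMs;        // the thread first sleeps for pollDelay
        long interval = pollDelayMs;   // the loop seeds pollInterval with pollDelay
        for (int pollCount = 1; pollCount <= polls; pollCount++) {
            times.add(now);            // a ConditionPoller would be submitted here
            interval = next.apply(pollCount, interval);
            now += interval;           // Thread.sleep(pollInterval) before the next iteration
        }
        return times;
    }

    static BiFunction<Integer, Long, Long> fixed(long ms) {
        return (count, previous) -> ms;
    }

    // Hypothetical strategy for comparison only: double the previous interval each time.
    static BiFunction<Integer, Long, Long> doubling() {
        return (count, previous) -> Math.max(1L, previous * 2);
    }
}
```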

Inside the while loop, each evaluation is handed to the executor:

Future<?> future = executor.submit(new ConditionPoller(pollInterval));

The task being submitted, ConditionPoller:

private class ConditionPoller implements Runnable {
        private final Duration delayed;

        /**
         * @param delayed The duration of the poll interval
         */
        public ConditionPoller(Duration delayed) {
            this.delayed = delayed;
        }


        public void run() {
            try {
                ConditionEvaluationResult result = conditionEvaluator.eval(delayed);
                if (result.isSuccessful()) {
                    latch.countDown();
                } else if (result.hasThrowable()) {
                    cause = result.getThrowable();
                }
            } catch (Exception e) {
                if (!conditionSettings.shouldExceptionBeIgnored(e)) {
                    throwable = e;
                    latch.countDown();
                }
            }
        }
    }
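Stripped of handlers and listeners, one ConditionPoller run amounts to the sketch below. MiniPoller is hypothetical; its Predicate stands in for conditionSettings.shouldExceptionBeIgnored(e), and the result.hasThrowable() branch is omitted.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;

// Hypothetical, simplified counterpart of ConditionPoller: a single evaluation of the condition.
final class MiniPoller implements Runnable {
    private final Callable<Boolean> condition;
    private final CountDownLatch latch;
    private final AtomicReference<Throwable> failure; // plays the role of the `throwable` field
    private final Predicate<Exception> ignore;        // stands in for shouldExceptionBeIgnored(e)

    MiniPoller(Callable<Boolean> condition, CountDownLatch latch,
               AtomicReference<Throwable> failure, Predicate<Exception> ignore) {
        this.condition = condition;
        this.latch = latch;
        this.failure = failure;
        this.ignore = ignore;
    }

    @Override
    public void run() {
        try {
            if (Boolean.TRUE.equals(condition.call())) {
                latch.countDown(); // condition met: release the waiting thread
            }
            // condition not met: do nothing, the scheduling loop will poll again
        } catch (Exception e) {
            if (!ignore.test(e)) {
                failure.set(e);    // remember the error for the waiting thread
                latch.countDown(); // and release it so the error can be rethrown
            }
            // ignored exceptions are treated like "not met yet"
        }
    }
}
```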

The actual evaluation is: ConditionEvaluationResult result = conditionEvaluator.eval(delayed);

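Here conditionEvaluator is the ConditionEvaluationWrapper (an implementation of ConditionEvaluator) that CallableCondition passed in when it instantiated ConditionAwaiter: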

conditionAwaiter = new ConditionAwaiter(conditionEvaluationWrapper, settings)

The eval() method of ConditionEvaluationWrapper:

public ConditionEvaluationResult eval(Duration pollInterval) throws Exception {
            boolean conditionFulfilled = matcher.call();
            if (conditionFulfilled) {
                conditionEvaluationHandler.handleConditionResultMatch(getMatchMessage(matcher, settings.getAlias()), true, pollInterval);
            } else {
                conditionEvaluationHandler.handleConditionResultMismatch(getMismatchMessage(matcher, settings.getAlias()), false, pollInterval);

            }
            return new ConditionEvaluationResult(conditionFulfilled);
        }

Here, boolean conditionFulfilled = matcher.call(); invokes the user-supplied Callable, and call() returns the computed result.

matcher is assigned in the ConditionEvaluationWrapper constructor:

     ConditionEvaluationWrapper(Callable<Boolean> matcher, ConditionSettings settings, ConditionEvaluationHandler<Object> conditionEvaluationHandler) {

            this.matcher = matcher;
            this.settings = settings;
            this.conditionEvaluationHandler = conditionEvaluationHandler;
        }

In this example, matcher is:

new Callable<Boolean>() {
                    @Override
                    public Boolean call() throws Exception {
                        return service.getCount() == 6;
                    }
                }

If the condition is satisfied, ConditionPoller calls latch.countDown():

 ConditionEvaluationResult result = conditionEvaluator.eval(delayed);
                if (result.isSuccessful()) {
                    latch.countDown();
                } else if (result.hasThrowable()) {
                    cause = result.getThrowable();
                }

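This makes latch.getCount() == 0, so the while loop breaks at its next conditionCompleted() check.
Otherwise, if the condition is not yet satisfied, each iteration of the while loop sleeps: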

Thread.sleep(pollInterval.getValueInMS());

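Each iteration of the while loop submits one evaluation to the executor thread pool. With Duration.FOREVER it waits for that evaluation indefinitely; otherwise it waits at most maxTimeout for the result.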

 Future<?> future = executor.submit(new ConditionPoller(pollInterval));
 if (maxWaitTime == Duration.FOREVER) {
         future.get();
 } else {
         future.get(maxTimeout, maxTimeoutUnit);
 }
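Each poll is bounded by future.get(maxTimeout, maxTimeoutUnit): if a single evaluation hangs, get() throws TimeoutException, which the catch (Throwable e) block of the scheduling thread stores in throwable. The JDK-only sketch below (BoundedEvaluation is hypothetical) isolates that bound and reports which outcome occurred.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

final class BoundedEvaluation {
    // Submits one evaluation and waits at most timeoutMs for it, as the polling loop does.
    // Returns "done" or the simple name of the exception future.get(...) threw.
    static String evaluate(Runnable evaluation, long timeoutMs) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<?> future = executor.submit(evaluation);
            future.get(timeoutMs, TimeUnit.MILLISECONDS);
            return "done";
        } catch (TimeoutException e) {
            return "TimeoutException";     // the evaluation took longer than the bound
        } catch (ExecutionException e) {
            return "ExecutionException";   // the evaluation itself threw
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "InterruptedException";
        } finally {
            executor.shutdownNow();        // interrupts a still-running evaluation
        }
    }
}
```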

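Meanwhile, after starting the polling thread, the calling thread blocks on the latch until its count reaches zero (or the timeout elapses):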

latch.await(): Causes the current thread to wait until the latch has counted down to zero.

latch.await(long timeout, TimeUnit unit): waits at most timeout; returns true if the count reached zero and false if the waiting time elapsed before the count reached zero.

                final boolean finishedBeforeTimeout;
                if (maxWaitTime == Duration.FOREVER) {
                    latch.await();
                    finishedBeforeTimeout = true;
                } else {
                    finishedBeforeTimeout = latch.await(maxTimeout, maxTimeoutUnit);
                }

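Finally, if finishedBeforeTimeout is false, the timeout message is assembled (including the alias, if one was set) and a ConditionTimeoutException is thrown.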

The resulting error message:

org.awaitility.core.ConditionTimeoutException: Condition with alias 'count is greater 6' didn't complete within 6 seconds because condition returned by method "testAsynchronousPoll" in class org.awaitility.AwaitilityTest was not fulfilled.

3. Summary

The key steps are as follows.

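A CountDownLatch field, latch, is defined and initialized as new CountDownLatch(1).

A polling thread is started. Its body runs a while loop that first checks whether latch.getCount() is 0 and, if so, breaks out of the loop. Otherwise it submits a task to the executor thread pool that evaluates the asynchronous condition; if the condition holds, the task calls latch.countDown(), which brings latch.getCount() to 0 so the next iteration exits the loop. At the end of each iteration the thread sleeps via Thread.sleep(pollInterval.getValueInMS()).

Meanwhile the calling thread blocks in finishedBeforeTimeout = latch.await(maxTimeout, maxTimeoutUnit) for at most maxTimeout (in maxTimeoutUnit). If the condition is never met, latch.getCount() never reaches 0, latch.await(...) returns false, and a ConditionTimeoutException carrying the error message is thrown.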
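Putting the pieces together, the summarized flow can be sketched with java.util.concurrent alone. MiniAwaiter is hypothetical and heavily simplified: no aliases, listeners, minimum wait time, exception ignoring or deadlock detection, and it returns false on timeout where Awaitility throws ConditionTimeoutException.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical, heavily simplified version of the flow summarised above (not Awaitility's API).
final class MiniAwaiter {
    // Returns true if the condition became true within atMostMs, false on timeout.
    static boolean await(Callable<Boolean> condition, long atMostMs, long pollDelayMs, long pollIntervalMs) {
        CountDownLatch latch = new CountDownLatch(1);
        AtomicReference<Throwable> failure = new AtomicReference<>();
        ExecutorService executor = Executors.newSingleThreadExecutor();

        // Polling thread: sleep pollDelay, then submit one evaluation per iteration.
        Thread poller = new Thread(() -> {
            try {
                Thread.sleep(pollDelayMs);
                while (!executor.isShutdown() && latch.getCount() != 0) {
                    Future<?> future = executor.submit(() -> {
                        try {
                            if (Boolean.TRUE.equals(condition.call())) {
                                latch.countDown(); // success: release the waiting thread
                            }
                        } catch (Exception e) {
                            failure.set(e);        // error: record it and release the waiter
                            latch.countDown();
                        }
                    });
                    future.get(atMostMs, TimeUnit.MILLISECONDS); // bound a single evaluation
                    Thread.sleep(pollIntervalMs);
                }
            } catch (Throwable t) {
                failure.compareAndSet(null, t); // like `throwable = e` in the original
            }
        }, "mini-poll-scheduling");
        poller.setDaemon(true);
        poller.start();

        try {
            // The calling thread blocks until the latch hits zero or atMostMs elapses.
            boolean finishedBeforeTimeout = latch.await(atMostMs, TimeUnit.MILLISECONDS);
            if (failure.get() != null) {
                throw new IllegalStateException("Condition evaluation failed", failure.get());
            }
            return finishedBeforeTimeout;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            executor.shutdownNow(); // also makes the polling loop stop
        }
    }
}
```

For example, with a counter incremented 5 times as in CounterServiceImpl (but every 20 ms), MiniAwaiter.await(() -> counter.get() == 5, 2000, 5, 10) returns true, while waiting for counter.get() == 6 returns false once the timeout elapses.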
