当前位置: 首页 > 知识库问答 >
问题:

Java中Eratosthenes的多线程分段筛

孙洋
2023-03-14

我正在尝试在Java中创建一个快速的素数生成器。人们(或多或少)认为,最快的方法是埃拉托斯特尼的分段筛:https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes。可以进一步实施许多优化以使其更快。截至目前,我的实现在大约1.6秒内生成50847534低于10 ^ 9的素数,但我希望让它更快,至少打破1秒的障碍。为了增加获得良好回复的机会,我将包括算法和代码的演练。

尽管如此,作为 TL;DR,我希望将多线程包含在代码中。

为了回答这个问题,我想区分厄拉多塞的“分段”和“传统”筛子。传统的筛子需要< code>O(n)空间,因此在输入范围上非常有限(它的极限)。然而,分段筛只需要O(n^0.5)空间,并且可以在大得多的限制下操作。(考虑到<代码> L1,主要的加速是使用缓存友好的分段

这让我想到了多线程——将工作分成几个线程,每个线程处理更少量的工作,以更好地利用CPU。根据我的理解,传统的筛子不能完全多线程化,因为它是顺序的。每个线程都依赖于前一个,使得整个想法不可行。但是分段的筛子可能确实(我认为)是多线程的。

与其直接进入我的问题,我认为首先介绍我的代码是很重要的,所以我在此包括我目前最快的分段筛实现。我在这方面非常努力。花了相当长的时间,慢慢调整和添加优化。代码并不简单。我想说,它相当复杂。因此,我假设读者熟悉我介绍的概念,例如轮子分解、质数、分段等等。我包含了注释,以便更容易理解。

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;

public class primeGen {

    public static long x = (long)Math.pow(10, 9); //limit
    public static int sqrtx;
    public static boolean [] sievingPrimes; //the sieving primes, <= sqrtx

    public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes
    public static int [] gaps; //the gaps, according to the wheel. will enable skipping multiples of the wheel primes
    public static int nextp; // the first prime > wheel primes
    public static int l; // the amount of gaps in the wheel

    public static void main(String[] args)
    {
        long startTime = System.currentTimeMillis();

        preCalc();  // creating the sieving primes and calculating the list of gaps

        int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // the will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) //due to segment size, the full list of gaps is never used **within just one segment** , and therefore this check is redundant. 
                              //should be used with bigger segment sizes or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        }

        System.out.println(pi);

        long endTime = System.currentTimeMillis();
        System.out.println("Solution took "+(endTime - startTime) + " ms");
    }

    public static boolean [] simpleSieve (int l)
    {
        long sqrtl = (long)Math.sqrt(l);
        boolean [] primes = new boolean [l/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtl ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= l ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l)
                g=0;
        }
        return primes;
    }

    public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    }

    public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT comprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is prime. if p is also prime, then a gap is identified, and is noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    }

}

目前,它在大约1.6秒内产生低于10^9的50847534个素数。这非常令人印象深刻,至少以我的标准来看,但我希望能更快,可能会打破< code>1秒的障碍。即便如此,我相信它还可以做得更快。

整个程序基于轮子分解:https://en.wikipedia.org/wiki/Wheel_factorization.我注意到我使用所有素数的轮子得到最快的结果,最高可达19

public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes

这意味着这些素数的倍数被跳过,导致搜索范围更小。然后在preCalc方法中计算我们需要的数字之间的间距。如果我们在搜索范围内的数字之间跳转,我们就会跳过基素数的倍数。

public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT comprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is prime. if p is also prime, then a gap is identified, and is noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    } 

preCalc方法的末尾,调用simpleSieve方法,高效筛选之前提到的所有筛选素数,素数

 public static boolean [] simpleSieve (int l)
    {
        long sqrtl = (long)Math.sqrt(l);
        boolean [] primes = new boolean [l/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtl ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= l ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l)
                g=0;
        }
        return primes;
    } 

最后,我们到达了算法的核心。我们首先枚举所有素数

 long pi = pisqrtx();`

使用以下方法:

public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    } 

然后,在初始化跟踪素数枚举的pi变量后,我们html" target="_blank">执行上述分段,从第一个素数开始枚举

 int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // the will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) //due to segment size, the full list of gaps is never used **within just one segment** , and therefore this check is redundant. 
                              //should be used with bigger segment sizes or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        } 

我也把它作为注释包括在内,但也会解释。由于段大小相对较小,因此我们不会仅浏览一个段中的整个间隙列表,并且检查它是多余的。(假设我们使用19轮)。但是在程序的更广泛范围概述中,我们将利用整个间隙数组,因此变量u必须跟随它,而不是意外地超过它:

 while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            } 

使用更高的限制最终会呈现更大的段,这可能会导致有必要检查我们甚至在段内也没有超过间隙列表。这,或者调整素数基础可能会对程序产生这种影响。不过,切换到位筛选可以在很大程度上提高段限制。

  • 作为一个重要的补充说明,我知道有效的分割需要L1

另外,我还想听听更多关于加速这个项目的方法,你有什么想法,我都想听!真的很想让它变得又快又高效。谢谢你!


共有3个答案

符修杰
2023-03-14

你对速度有多感兴趣?你会考虑使用c吗?

$ time ../c_code/segmented_bit_sieve 1000000000
50847534 primes found.

real    0m0.875s
user    0m0.813s
sys     0m0.016s
$ time ../c_code/segmented_bit_isprime 1000000000
50847534 primes found.

real    0m0.816s
user    0m0.797s
sys     0m0.000s

(在我配有i5的新笔记本电脑上)

第一个来自@Kim Walisch,使用了一组奇怪的质数候选人。

https://github.com/kimwalisch/primesieve/wiki/Segmented-sieve-of-Eratosthenes

第二个是我对Kim的调整,IsPrime[]也实现为位数组,读取起来稍微不太清晰,尽管由于内存占用减少,对于大N来说有点快。

我会仔细阅读你的帖子,因为无论使用什么语言,我都对素数和表现感兴趣。我希望这不是离题太远或为时过早。但我注意到我已经超出了你的绩效目标。

闾丘卓
2023-03-14

你熟悉托马斯·奥利维拉·席尔瓦的作品吗?他很快就实现了埃拉托斯特尼的筛选。

方茂
2023-03-14

这样的例子应该有助于你开始。

解决方案的概要:

  • 定义一个包含特定段的数据结构(“任务”),您也可以将所有不可变的共享数据放入其中,以获得额外的整洁。如果您足够小心,可以将一个通用可变数组以及段限制传递给所有任务,并且只更新这些限制内的数组部分。这更容易出错,但可以简化合并结果的步骤(AFAICT;嗯嗯
  • 定义存储 Task 计算结果的数据结构(“结果”)。即使您只是更新共享的结果结构,也可能需要指示到目前为止该结构的哪一部分已更新。
  • 创建一个接受任务、运行计算并将结果放入给定结果队列的 Runnable。
  • 为任务创建一个阻塞输入队列,并为结果创建一个队列。
  • 创建线程池演示器,其线程数接近计算机内核数。
  • 将所有任务提交到线程池执行器。它们将被安排在池中的线程上运行,并将结果放入输出队列中,不一定按顺序排列。
  • 等待线程池中的所有任务完成。
  • 清空输出队列,并将部分结果联接到最终结果中。

通过将结果连接到读取输出队列的单独任务中,或者甚至通过更新synchronized下的可变共享输出结构(取决于连接步骤所涉及的工作量),可以实现(也可以不实现)额外的加速。

希望这有帮助。

 类似资料:
  • 我找到了关于线程安全的代码,但它没有来自给出示例的人的任何解释。我想知道为什么如果我不在“count”之前设置“synchronized”变量,那么count值将是非原子的(总是=200是期望的结果)。谢谢

  • 我有一个关于JAVA多线程的问题。 我有一个jetty webapp与grpc-流式传输-客户端。一切都很好,但我如何建立一个模型来获取流式传输数据? webapp是用jsf构建的。因为我有一个控制器,它调用一个处理程序类来启动流: 此方法简单地启动客户端和流。 检查倒计时锁存器的实现仍然缺失。但在这种情况下,这并不重要。 响应如下:onNext()-方法提供流式数据: 图像数据简单地打印在屏幕上

  • 这是一个关于Java中多线程的初学者问题。 根据我的理解,当创建多个(用户)线程来运行程序或应用程序时,就没有父线程和子线程的概念。它们都是独立的用户线程。 因此,如果主线程完成执行,那么另一个线程(Thread2)仍将继续执行,因为在Thread2的执行线程完成之前,它不会被JVM杀死(https://docs.oracle.com/javase/6/docs/api/java/lang/Thr

  • 问题内容: 在多线程环境中使用Singleton类的首选方法是什么? 假设我有3个线程,并且所有这些线程都尝试同时访问单例类的方法- 如果不保持同步会怎样? 在内部使用 方法还是使用块是好的做法。 请告知是否还有其他出路。 问题答案: 从理论上讲,这项任务并不容易,因为您要使其真正成为线程安全的。 在此上找到了一篇非常不错的论文@ IBM 仅获取单例不需要任何同步,因为这只是读取。因此,只需同步S

  • 问题内容: 鉴于以下多态: 我们如何在没有昂贵的getInstance()方法同步和双重检查锁定争议的情况下使它保持线程安全和懒惰?这里提到了单例的有效方法,但似乎并没有扩展到多例。 问题答案: 使用Java 8,它甚至可以更简单:

  • 问题内容: 我有一个缓存,该缓存是使用Simeple HashMap实现的。喜欢 - 大部分时间都使用此缓存从中读取值。我有另一个重新加载缓存的方法,在这个方法的内部,我基本上创建了一个新的缓存,然后分配了引用。据我了解,对象引用的分配是Java中的Atomic。 我了解,如果我不将缓存声明为易失性,则其他线程将无法看到更改,但是对于我的用例而言,将缓存中的更改传播到其他线程不是时间紧迫的,它们可