当前位置: 首页 > 知识库问答 >
问题:

有效计算三个无符号整数的平均值(无溢出)

宋博易
2023-03-14

存在一个“3个长整数的平均值”问题,该问题特别涉及三个有符号整数的平均值的有效计算。

然而,使用无符号整数允许进行不适用于前一个问题中所述场景的额外优化。这个问题是关于三个无符号整数的平均值的有效计算,其中平均值向零舍入,也就是说,用我想要计算的数学术语⌊ (a b c)/3⌋.

计算该平均值的简单方法是

 avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;

首先,现代优化编译器将除法转换为带倒数加移位的乘法,模运算转换为反乘法和减法,其中反乘法可以使用许多架构上可用的scale\u-add惯用法,例如x86\u 64上的lea,ARM上的lsl n,NVIDIA GPU上的iscadd。

在尝试以适用于许多常见平台的通用方式优化上述内容时,我观察到整数运算的成本通常在逻辑关系中≤ (添加| sub)≤ 转移≤ 缩放添加≤ 这里的mul.Cost指的是所有延迟、吞吐量限制和功耗。当处理的整数类型比本机寄存器宽度宽时,任何此类差异都会变得更加明显,例如在32位处理器上处理uint64\u t数据时。

因此,我的优化策略是最大限度地减少指令数,并在可能的情况下用“廉价”操作替换“昂贵”操作,同时不增加寄存器压力,并为大量无序处理器保留可利用的并行性。

第一个观察是,我们可以通过首先应用产生总和值和进位值的CSA(进位保存加法器)将三个操作数的总和减少为两个操作数的总和,其中进位值的权重是总和值的两倍。大多数处理器上基于软件的CSA的成本是五个逻辑。一些处理器,如NVIDIA GPU,有一个LOP3指令,可以一举计算三个操作数的任意逻辑表达式,在这种情况下,CSA浓缩为两个LOP3s(注意:我还没有说服CUDA编译器发出这两个LOP3s;它目前产生四个LOP3s!)。

第二个观察结果是,因为我们正在计算除以3的模,所以我们不需要反乘法来计算它。我们可以改为使用股息%3((股息/3)股息)

最后,对于校正项(a%3 b%3 c%3)/3中的除3,我们不需要一般除3的代码。由于被除数非常小,在[0,6]中,我们可以将x/3简化为(3*x)/8,只需要一个小数加一个移位。

下面的代码显示了我当前正在进行的工作。使用编译器资源管理器检查为各种平台生成的代码显示了我所期望的紧凑代码(当使用O3编译时)。

然而,在使用Intel 13.x编译器的Ivy Bridge x86\u 64机器上对代码进行计时时,一个缺陷变得很明显:与简单版本相比,我的代码提高了延迟(uint64\u数据从18个周期提高到15个周期),但吞吐量恶化(uint64\u数据从每6.8个周期一个结果提高到每8.5个周期一个结果)。更仔细地看汇编代码,很明显这是为什么:我基本上设法将代码从大致的三向并行性降低到大致的双向并行性。

是否有一种普遍适用的优化技术,有利于通用处理器,特别是所有类型的x86和ARM以及GPU,可以保留更多的并行性?或者,是否有一种优化技术可以进一步减少整体操作数量以弥补减少的并行性?更正项(在下面的代码中)的计算似乎是一个很好的目标。简化(carry_mod_3sum_mod_3)/2看起来很诱人,但对九种可能的组合之一提供了不正确的结果。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

#define BENCHMARK           (1)
#define SIMPLE_COMPUTATION  (0)

#if BENCHMARK
#define T uint64_t
#else // !BENCHMARK
#define T uint8_t
#endif // BENCHMARK

T average_of_3 (T a, T b, T c) 
{
    T avg;

#if SIMPLE_COMPUTATION
    avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
#else // !SIMPLE_COMPUTATION
    /* carry save adder */
    T a_xor_b = a ^ b;
    T sum = a_xor_b ^ c;
    T carry = (a_xor_b & c) | (a & b);
    /* here 2 * carry + sum = a + b + c */
    T sum_div_3 = (sum / 3);                                   // {MUL|MULHI}, SHR
    T sum_mod_3 = (sum + sum_div_3) & 3;                       // ADD, AND

    if (sizeof (size_t) == sizeof (T)) { // "native precision" (well, not always)
        T two_carry_div_3 = (carry / 3) * 2;                   // MULHI, ANDN
        T two_carry_mod_3 = (2 * carry + two_carry_div_3) & 6; // SCALE_ADD, AND
        T head = two_carry_div_3 + sum_div_3;                  // ADD
        T tail = (3 * (two_carry_mod_3 + sum_mod_3)) / 8;      // ADD, SCALE_ADD, SHR
        avg = head + tail;                                     // ADD
    } else {
        T carry_div_3 = (carry / 3);                           // MUL, SHR
        T carry_mod_3 = (carry + carry_div_3) & 3;             // ADD, AND
        T head = (2 * carry_div_3 + sum_div_3);                // SCALE_ADD
        T tail = (3 * (2 * carry_mod_3 + sum_mod_3)) / 8;      // SCALE_ADD, SCALE_ADD, SHR
        avg = head + tail;                                     // ADD
    }
#endif // SIMPLE_COMPUTATION
    return avg;
}

#if !BENCHMARK
/* Test correctness on 8-bit data exhaustively. Should catch most errors */
int main (void)
{
    T a, b, c, res, ref;
    a = 0;
    do {
        b = 0;
        do {
            c = 0;
            do {
                res = average_of_3 (a, b, c);
                ref = ((uint64_t)a + (uint64_t)b + (uint64_t)c) / 3;
                if (res != ref) {
                    printf ("a=%08x  b=%08x  c=%08x  res=%08x  ref=%08x\n", 
                            a, b, c, res, ref);
                    return EXIT_FAILURE;
                }
                c++;
            } while (c);
            b++;
        } while (b);
        a++;
    } while (a);
    return EXIT_SUCCESS;
}

#else // BENCHMARK

#include <math.h>

// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

#define N  (3000000)
int main (void)
{
    double start, stop, elapsed = INFINITY;
    int i, k;
    T a, b;
    T avg0  = 0xffffffff,  avg1 = 0xfffffffe;
    T avg2  = 0xfffffffd,  avg3 = 0xfffffffc;
    T avg4  = 0xfffffffb,  avg5 = 0xfffffffa;
    T avg6  = 0xfffffff9,  avg7 = 0xfffffff8;
    T avg8  = 0xfffffff7,  avg9 = 0xfffffff6;
    T avg10 = 0xfffffff5, avg11 = 0xfffffff4;
    T avg12 = 0xfffffff2, avg13 = 0xfffffff2;
    T avg14 = 0xfffffff1, avg15 = 0xfffffff0;

    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    printf ("a=%016llx b=%016llx avg=%016llx", 
            (uint64_t)a, (uint64_t)b, (uint64_t)avg0);
    printf ("\rlatency:    each average_of_3() took  %.6e seconds\n", 
            elapsed / 16 / N);


    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0  = average_of_3 (a, b, avg0);
            avg1  = average_of_3 (a, b, avg1);
            avg2  = average_of_3 (a, b, avg2);
            avg3  = average_of_3 (a, b, avg3);
            avg4  = average_of_3 (a, b, avg4);
            avg5  = average_of_3 (a, b, avg5);
            avg6  = average_of_3 (a, b, avg6);
            avg7  = average_of_3 (a, b, avg7);
            avg8  = average_of_3 (a, b, avg8);
            avg9  = average_of_3 (a, b, avg9);
            avg10 = average_of_3 (a, b, avg10);
            avg11 = average_of_3 (a, b, avg11);
            avg12 = average_of_3 (a, b, avg12);
            avg13 = average_of_3 (a, b, avg13);
            avg14 = average_of_3 (a, b, avg14);
            avg15 = average_of_3 (a, b, avg15);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    printf ("a=%016llx b=%016llx avg=%016llx", (uint64_t)a, (uint64_t)b, 
            (uint64_t)(avg0 + avg1 + avg2 + avg3 + avg4 + avg5 + avg6 + avg7 + 
                       avg8 + avg9 +avg10 +avg11 +avg12 +avg13 +avg14 +avg15));
    printf ("\rthroughput: each average_of_3() took  %.6e seconds\n", 
            elapsed / 16 / N);

    return EXIT_SUCCESS;
}

#endif // BENCHMARK

共有3个答案

南宫泓
2023-03-14

我已经回答了你链接到的问题,所以我只回答了这个问题不同的部分:性能。

如果你真的关心性能,那么答案是:

( a + b + c ) / 3

因为您关心性能,所以您应该对正在处理的数据的大小有直觉。您不应该担心只有3个值的加法(乘法是另一个问题)溢出,因为如果您的数据已经足够大,可以使用所选数据类型的高位,那么您无论如何都有溢出的危险,应该使用更大的整数类型。如果您在uint64\u t上溢出,那么您应该真正扪心自问,为什么您需要准确地数到18个五分之一,也许可以考虑使用浮点或双精度。

现在,我已经说了所有这些,我将给你我的实际答复:这无关紧要。这个问题在现实生活中不会出现,当它出现时,性能并不重要。

如果你在SIMD中做了一百万次,这可能是一个真正的性能问题,因为在那里,你真的被激励使用较小宽度的整数,你可能需要最后一点空间,但这不是你的问题。

高朝明
2023-03-14

我不确定它是否符合您的要求,但可能只需计算结果,然后修复溢出的错误即可:

T average_of_3 (T a, T b, T c)
{
    T r = ((T) (a + b + c)) / 3;
    T o = (a > (T) ~b) + ((T) (a + b) > (T) (~c));
    if (o) r += ((T) 0x5555555555555555) << (o - 1);
    T rem = ((T) (a + b + c)) % 3;
    if (rem >= (3 - o)) ++r;
    return r;
}

[编辑]这是我能想到的最好的分支和比较少的版本。在我的机器上,这个版本实际上比njuffa的代码具有略高的吞吐量\u builtin\u add\u overflow(x,y,r)受gcc和clang支持,如果总和x y溢出*r和0的类型,则返回1,否则返回0,因此o的计算相当于第一个版本中的可移植代码,但至少gcc使用内置代码生成了更好的代码。

T average_of_3 (T a, T b, T c)
{
    T r = ((T) (a + b + c)) / 3;
    T rem = ((T) (a + b + c)) % 3;
    T dummy;
    T o = __builtin_add_overflow(a, b, &dummy) + __builtin_add_overflow((T) (a + b), c, &dummy);
    r += -((o - 1) & 0xaaaaaaaaaaaaaaab) ^ 0x5555555555555555;
    r += (rem + o + 1) >> 2;
    return r;
}

周泰
2023-03-14

让我把帽子扔进拳击场。我认为,在这里不要做太棘手的事情。

#include <stdint.h>

uint64_t average_of_three(uint64_t a, uint64_t b, uint64_t c) {
  uint64_t hi = (a >> 32) + (b >> 32) + (c >> 32);
  uint64_t lo = hi + (a & 0xffffffff) + (b & 0xffffffff) + (c & 0xffffffff);
  return 0x55555555 * hi + lo / 3;
}

下面讨论了不同的拆分,下面是一个以三个位AND为代价保存乘法的版本:

T hi = (a >> 2) + (b >> 2) + (c >> 2);
T lo = (a & 3) + (b & 3) + (c & 3);
avg = hi + (hi + lo) / 3;
 类似资料:
  • 问题内容: 我希望我能弄清楚。我需要生成一个平均值为AVG_AMT(整数)的表,并且没有小数。它可以舍入或截断。这张桌子真的没关系。 这是我试图写的: 有什么建议? 问题答案:

  • 我正在读一篇关于整数安全性的文章。以下是链接:http://ptgmedia.pearsoncmg.com/images/0321335724/samplechapter/seacord_ch05.pdf 在第166页,有这样一句话: 涉及无符号操作数的计算永远不会过流,因为不能由结果无符号整数类型表示的结果将被模化为比结果类型可以表示的最大值大一的数字。 这是什么意思?感谢您的回复。

  • 我正在简单的C程序中试验无符号int数据类型和主方法参数。作为一个实验,我写了一个程序,从命令行获取一个int数作为main方法的参数,并对该数和0之间的每个整数求和。 例如,程序计算 f(n) = (1 2 3... n) 当 n 时有效 我开始注意到的第一件事是当f(n) 我手动发现数学上的最大值,我的程序生成的结果将是有效的(例如,在整数溢出之前),对于有符号整数为65535,对于无符号in

  • 我需要写一个程序来计算用户输入的整数的奇偶平均数。用户键入“完成”以完成。输出将显示奇数的平均值和偶数的平均值。 我有一个while循环程序,可以计算数字的和,我正试图增加奇数和偶数和的额外要求。这是代码: 下面是我修改的代码,对奇数和偶数进行排序,然后对每组进行平均。 预期: 实际:

  • 问题内容: 已关闭 。这个问题需要细节或说明。它当前不接受答案。 想改善这个问题吗? 添加详细信息并通过编辑此帖子来澄清问题。 11个月前关闭。 改善这个问题 我有一个清单: 我想要另一个具有三个值均值的列表,因此新列表为: 新列表中只有6个值,因为第一个元素中只有18个元素。 我正在寻找一种精巧的方法来完成此操作,并为大量列表提供最少的步骤。 问题答案: 您可以在3个间隔中迭代使用for循环

  • 未定义行为的一个例子是在flow上的整数行为 有没有一个历史的或者(甚至更好!)造成这种差异的技术原因是什么?