我有一个程序,它花费大部分时间计算RGB值之间的欧几里德距离(无符号8位Word8的3元组)。我需要一个快速、无分支的无符号int绝对差分函数,这样
unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b
特别是,
unsigned_differencea b==unsigned_differenceb a
使用GHC 7.8中的新primops,我得出了以下结论:
-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))]
ghc-O2-S编译为
.Lc42U:
movq 7(%rbx),%rax
movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
movq 8(%rbp),%rbx
movq %rbx,%rcx
subq %rax,%rcx
cmpq %rax,%rbx
setg %dl
movzbl %dl,%edx
imulq %rcx,%rdx
movq %rax,%rcx
subq %rbx,%rcx
cmpq %rax,%rbx
setl %al
movzbl %al,%eax
imulq %rcx,%rax
addq %rdx,%rax
movq %rax,(%r12)
leaq -7(%r12),%rbx
addq $16,%rbp
jmp *(%rbp)
使用ghc-O2-fllvm-optlo-O3-S
编译会产生以下asm:
.LBB6_1:
movq 7(%rbx), %rsi
movq $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
movq 8(%rbp), %rcx
movq %rsi, %rdx
subq %rcx, %rdx
xorl %edi, %edi
subq %rsi, %rcx
cmovleq %rdi, %rcx
cmovgeq %rdi, %rdx
addq %rcx, %rdx
movq %rdx, 16(%rax)
movq 16(%rbp), %rax
addq $16, %rbp
leaq -7(%r12), %rbx
jmpq *%rax # TAILCALL
所以LLVM设法用(更有效?)条件移动指令替换比较。不幸的是,使用-fllvm
编译对我的程序的运行时几乎没有影响。
然而,这个功能有两个问题。
我已经分析并确认了fromIntegral:: Word8的使用-
我的版本使用2个比较、2个乘法和2个减法。我想知道是否有一种更有效的方法,使用位运算或SIMD指令,并利用我正在比较Word8这一事实
我之前标记了问题
C/C
以吸引那些更倾向于位操作的人的注意力。我的问题使用Haskell,但我接受在任何语言中实现正确方法的答案。
结论:
我决定使用
w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
where diff = fromIntegral a - fromIntegral b
mask = unsafeShiftR diff 15
因为它比我原来的
unsigned\u difference
函数快,而且实现简单。Haskell中的SIMD内部函数尚未成熟。因此,虽然SIMD版本更快,但我决定使用标量版本。
编辑:更改我的答案,我为此错误配置了优化。
我用C语言建立了一个快速测试平台,我发现
<代码>a-b(a
至少在我的设置中,头发更好吗。我的方法的优点是消除了比较。当不需要时,您的版本会隐式地处理a-b==0,就像它是一个单独的情况一样。
我和你的测试
我尝试了一种使用非分支绝对值的方法,结果更好。请注意,输入或输出是否被编译器认为是有符号的并不重要。它围绕大的无符号值循环,但由于它只需要处理小的值(如问题所述),因此应该足够了。
s32 diff = a - b;
u32 mask = diff >> 31;
return (diff + mask) ^ mask;
如果您针对的是一个带有SSE指令的系统,那么您可以使用它来提高性能。我用其他发布的方法对此进行了测试,这似乎是最快的方法。
区分大量值的示例结果:
diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms // branchless mul add
diff3: 54.495269 ms // branchless signed
diff4: 31.159628 ms // sse
diff5: 30.855885 ms // sse v2
下面是我的完整测试代码。我通过SSE内部函数使用SSE2指令,SSE2指令现在在x86ish CPU中广泛可用,SSE内部函数应该是可移植的(MSVC、GCC、Clang、Intel编译器等)。
注意事项:
如果您对代码或这种方法有任何问题/建议,请发表评论。
#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>
#include <emmintrin.h> // sse2
// branching
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
}
}
// max min
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
}
}
// branchless mul add
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
}
}
// branchless signed
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
// sse
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
auto pA = reinterpret_cast<const __m128i*>(a);
auto pB = reinterpret_cast<const __m128i*>(b);
auto pRes = reinterpret_cast<__m128i*>(res);
std::size_t i = 0;
for (std::size_t j = n / 16; j--; i++) {
__m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
__m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
_mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
}
for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
// sse v2
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
auto pA = reinterpret_cast<const __m128i*>(a);
auto pB = reinterpret_cast<const __m128i*>(b);
auto pRes = reinterpret_cast<__m128i*>(res);
std::size_t i = 0;
const std::size_t UNROLL = 2;
for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
__m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
__m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
__m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
__m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
_mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
_mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
}
for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
__m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
__m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
_mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
}
for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
int main() {
const std::size_t ALIGN = 16; // sse requires 16 bit align
const std::size_t N = 10 * 1024 * 1024 * 3;
auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
{ // fill with random values
std::mt19937 engine(std::random_device{}());
std::uniform_int<std::uint8_t> distribution(0, 255);
for (std::size_t i = 0; i < N; i++) {
a[i] = distribution(engine);
b[i] = distribution(engine);
}
}
auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results
LARGE_INTEGER f, t0, t1;
QueryPerformanceFrequency(&f);
QueryPerformanceCounter(&t0);
diff0(a, b, res0, N);
QueryPerformanceCounter(&t1);
printf("diff0: %.6f ms\n",
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);
#define TEST(diffX)\
QueryPerformanceCounter(&t0);\
diffX(a, b, resX, N);\
QueryPerformanceCounter(&t1);\
printf("%s: %.6f ms\n", #diffX,\
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
for (std::size_t i = 0; i < N; i++) {\
if (resX[i] != res0[i]) {\
printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
a[i], b[i], resX[i], res0[i]);\
break;\
}\
}
TEST(diff1);
TEST(diff2);
TEST(diff3);
TEST(diff4);
TEST(diff5);
_mm_free(a);
_mm_free(b);
_mm_free(res0);
_mm_free(resX);
getc(stdin);
return 0;
}
嗯,我试着做了一点基准测试。我使用标准作为基准,因为它进行适当的显著性测试。我还在这里使用QuickCheck来确保所有方法返回相同的结果。
我使用GHC 7.6.3(很遗憾,我无法包含primops函数)和O3进行编译:
ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff
首先,我们可以看到天真的实现和一点虚幻之间的区别:
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
mask = unsafeShiftR v 63
输出:
benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....
benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...
我使用了“Bit Twiddling Hacks在这里”中的绝对整数值技巧。不幸的是,我们需要强制转换,我认为仅在Word8
领域无法很好地解决问题,但无论如何使用本机整数类型似乎是明智的(不过绝对没有必要创建堆对象)。
这看起来并没有太大的区别,但我的测试设置也不完美:我将函数映射到一个大的随机值列表上,以排除分支预测,使分支版本看起来比实际更有效。这会导致内存中堆积内存,这可能会对计时产生很大影响。当我们减去维护列表的恒定开销时,我们很可能会看到比20%的加速更多的东西。
生成的程序集实际上相当不错(这是函数的内联版本):
.Lc4BB:
leaq 7(%rbx),%rax
movq 8(%rbp),%rbx
subq (%rax),%rbx
movq %rbx,%rax
sarq $63,%rax
movq $base_GHCziInt_I64zh_con_info,-8(%r12)
addq %rax,%rbx
xorq %rax,%rbx
movq %rbx,0(%r12)
leaq -7(%r12),%rbx
movq $s4z0_info,8(%rbp)
1个减法,1个加法,1个右移,1个异或,没有分支,如预期的那样。使用LLVM后端并没有显著改善运行时。
希望这是有用的,如果你想尝试更多的东西。
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce
import Test.QuickCheck hiding ((.&.))
import Criterion.Main
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b
absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b
absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b
absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
where v = a - b
mask = unsafeShiftR v 15
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
!mask = unsafeShiftR v 63
absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a
{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
{-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}
e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum
prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
where x' = e2e x
y' = e2e y
check = quickCheck prop_same1
>> quickCheck prop_same2
instance (Random x, Random y) => Random (x, y) where
random gen1 =
let (x, gen2) = random gen1
(y, gen3) = random gen2
in ((x,y),gen3)
main =
do check
!pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
defaultMain
[ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
, bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
, bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
]
, bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
, bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
]
{-, bgroup "absdiff_Int" [ bench "1" $ whnf (absdiff1_int 13) 14-}
{-, bench "2" $ whnf (absdiff3_int 13) 14-}
{-]-}
]
我从聚合魔法中找到了一个快速计算最大值的技巧。唯一的问题是整数,尽管我做了一些尝试,但我不知道如何为无符号整数创建一个版本。 有什么建议吗? 编辑 不要使用它,因为正如其他人所说,它会产生未定义的行为。对于任何现代体系结构,编译器都能够从返回(a)发出无分支条件移动指令
我正在简单的C程序中试验无符号int数据类型和主方法参数。作为一个实验,我写了一个程序,从命令行获取一个int数作为main方法的参数,并对该数和0之间的每个整数求和。 例如,程序计算 f(n) = (1 2 3... n) 当 n 时有效 我开始注意到的第一件事是当f(n) 我手动发现数学上的最大值,我的程序生成的结果将是有效的(例如,在整数溢出之前),对于有符号整数为65535,对于无符号in
我正在读一篇关于整数安全性的文章。以下是链接:http://ptgmedia.pearsoncmg.com/images/0321335724/samplechapter/seacord_ch05.pdf 在第166页,有这样一句话: 涉及无符号操作数的计算永远不会过流,因为不能由结果无符号整数类型表示的结果将被模化为比结果类型可以表示的最大值大一的数字。 这是什么意思?感谢您的回复。
全新的汇编需要一些无符号算术方面的帮助。从C程序转换是什么意思。 使用: Linux操作系统 美国国家科学院 x86(32位) 我想从用户那里读入一个数字。我希望这个号码没有签名。当我输入一个超过有符号整数限制的数字并使用信息寄存器时,我注意到我的寄存器存储的是负数,这意味着发生了溢出。(显然输入的数字低于max unsigned int)如何将此寄存器视为无符号,以便根据结果进行比较和跳转?
上个小节我们主要学习了 Go 语言中的整型 int 数据类型,本小节主要介绍了 Go 语言中处理无符号的整数的数据类型。 1. 定长类型 序号 类型 长度 1 uint8 0~255 2 uint16 0~65535 3 uint32 0~4294967295 4 uint64 0~18446744073709551615 2. 不定长类型 在 Go 语言中也实现了随着平台位数变化而变化的数据类型
问题内容: 我正在使用Flask-SQLAlchemy(MySQL)将门户迁移到Flask。以下是我用于为现有门户创建数据库的代码: 这是我尝试在SQLAlchemy中使用它的方式: 我的问题是,如何使SQLAlchemy模型指定为无符号整数? 问题答案: 在这种情况下(如MySQL这样的数据库其数据类型本身不是标准数据类型或具有非标准选项),你可以通过获取特定于方言的类型来访问这些类型/选项。对