SIMDのシフト演算のお話

SSE2やAVX2のレジスタ全体をシフトする命令はコンパイル時定数しかパラメータに受けとるものしか無いので、目的の処理を書くのに困る事がよくある。

XMMレジスタの場合はコンパイル時定数だけれど、128bit丸ごとを指定のバイト単位でシフトが出来る。

_mm_bslli_si128 pslldq 
_mm_bsrli_si128 psrldq

YMMレジスタの場合はコンパイル時定数だけれど、128bitレーンそれぞれをバイト単位でシフト出来る。

_mm256_bslli_epi128 vpslldq
_mm256_bsrli_epi128 vpsrldq

YMMレジスタをレーン跨いでシフト出来る方法をStackOverflowで発見。

http://stackoverflow.com/questions/20775005/8-bit-shift-operation-in-avx2-with-shifting-in-zeros

template <unsigned int N>
__m256i _mm256_shift_left(__m256i a)
{
	__m256i mask = _mm256_permute2x128_si256(a, a, _MM_SHUFFLE(0,0,3,0) );
	return _mm256_alignr_epi8(a,mask,16-N);
}

static
void print_mm256_bytes(__m256i bytes)
{
	for (int i = 0; i < 32; i++) {
		printf("%2d ",((unsigned char *)&bytes)[i]);
	}
	printf("\n");
}

void test()
{
	__m256i reg =  _mm256_set_epi8(32,31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15, 14,13,12,11,10,9,8,7,6,5,4,3,2,1);
	__m256i result;
	
	print_mm256_bytes(_mm256_shift_left<0>(reg));
	print_mm256_bytes(_mm256_shift_left<1>(reg));
	print_mm256_bytes(_mm256_shift_left<2>(reg));
	print_mm256_bytes(_mm256_shift_left<3>(reg));
	print_mm256_bytes(_mm256_shift_left<4>(reg));
	print_mm256_bytes(_mm256_shift_left<5>(reg));
	print_mm256_bytes(_mm256_shift_left<6>(reg));
	print_mm256_bytes(_mm256_shift_left<7>(reg));
	print_mm256_bytes(_mm256_shift_left<8>(reg));
	print_mm256_bytes(_mm256_shift_left<9>(reg));
	print_mm256_bytes(_mm256_shift_left<10>(reg));
	print_mm256_bytes(_mm256_shift_left<11>(reg));
	print_mm256_bytes(_mm256_shift_left<12>(reg));
	print_mm256_bytes(_mm256_shift_left<13>(reg));
	print_mm256_bytes(_mm256_shift_left<14>(reg));
	print_mm256_bytes(_mm256_shift_left<15>(reg));
	print_mm256_bytes(_mm256_shift_left<16>(reg));
}

実行してみると、確かに左シフトされた結果が得られた。

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
 0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
 0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
 0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
 0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
 0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
 0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22
 0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21
 0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
 0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16

バイト単位でコンパイル時定数なのは、元々備わっている命令も同様だし、しょうがない気もする。2命令だけで実現出来ているのだし…。
範囲が 0〜16までなのは少し不便だけれど、それ以上の場合には違う処理を呼び出すようにすれば良いだろう。どうせコンパイル時定数なんだから。

何故 _mm256_permute2x128_si256 と _mm256_alignr_epi8 で実現出来ているのか解らないので解読してみる。

SELECT4(src1, src2, control)
{
	CASE(control[1:0])
	0: tmp[127:0] := src1[127:0]
	1: tmp[127:0] := src1[255:128]
	2: tmp[127:0] := src2[127:0]
	3: tmp[127:0] := src2[255:128]
	ESAC
	
	IF control[3]
		tmp[127:0] := 0
	FI
	
	RETURN tmp[127:0]
}

__m256i _mm256_permute2x128_si256(__m256i a, __m256i b, const int imm8)
{
	dst[127:0] := SELECT4(a[255:0], b[255:0], imm8[3:0])
	dst[255:128] := SELECT4(a[255:0], b[255:0], imm8[7:4])
	dst[MAX:256] := 0
}

__m256i _mm256_alignr_epi8 (__m256i a, __m256i b, const int count)
{
	FOR j := 0 to 1
		i := j*128
		tmp[255:0] := ((a[i+127:i] << 128) OR b[i+127:i]) >> (count[7:0]*8)
		dst[i+127:i] := tmp[127:0]
	ENDFOR
	dst[MAX:256] := 0
}

template <unsigned int N>
__m256i _mm256_shift_left(__m256i a)
{
	__m256i mask = _mm256_permute2x128_si256(a, a, _MM_SHUFFLE(0,0,3,0) );
	return _mm256_alignr_epi8(a,mask,16-N);
}

疑似コードを見てもよくわからなかったのでデバッガで値の変化を見てみたところ理解出来た。

_mm256_permute2x128_si256 で
mask.m256i_i8[0:15] = 0
mask.m256i_i8[15:31] = a.m256i_i8[0:15]
となり、同じindexのレーンに隣り合う128bitが配置されるので
_mm256_alignr_epi8 で funnel shiftして完了

17〜32バイトのシフトにも対応出来ないか試してみて出来たけれど、なんかコードの見た目が酷い。


#include <stdio.h>
#include <intrin.h>
#include <immintrin.h>

template<bool> struct Range;

template<unsigned int N, typename = Range<true> >
struct mm256_shift_left_impl
{};

template<unsigned int N>
struct mm256_shift_left_impl<N, Range<(0 <= N && N <= 16)> >
{
	static __m256i doit(__m256i a)
	{
		__m256i mask = _mm256_permute2x128_si256(a, a, _MM_SHUFFLE(0,0,3,0) );
		return _mm256_alignr_epi8(a,mask,16-N);
	}
};

template<unsigned int N>
struct mm256_shift_left_impl<N, Range<(16 < N && N <= 32)> >
{
	static __m256i doit(__m256i a)
	{
		__m256i y1 = _mm256_slli_si256(a, N - 16);
		return _mm256_permute2x128_si256(y1, y1, _MM_SHUFFLE(0,0,3,0) );
	}
};

template <unsigned int N>
__m256i mm256_shift_left(__m256i a)
{
	return mm256_shift_left_impl<N>::doit(a);
}

static
void print_mm256_bytes(__m256i bytes)
{
	for (int i = 0; i < 32; i++) {
		printf("%2d ",((unsigned char *)&bytes)[i]);
	}
	printf("\n");
}

template <unsigned int N>
void test(__m256i reg)
{
	print_mm256_bytes(mm256_shift_left<N>(reg));
}


void test()
{
	__m256i reg =  _mm256_set_epi8(
		32,31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,
		16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1
	);
	
	test<0>(reg);
	test<1>(reg);
	test<2>(reg);
	test<3>(reg);
	test<4>(reg);
	test<5>(reg);
	test<6>(reg);
	test<7>(reg);
	test<8>(reg);
	test<9>(reg);
	test<10>(reg);
	test<11>(reg);
	test<12>(reg);
	test<13>(reg);
	test<14>(reg);
	test<15>(reg);
	test<16>(reg);
	test<17>(reg);
	test<18>(reg);
	test<19>(reg);
	test<20>(reg);
	test<21>(reg);
	test<22>(reg);
	test<23>(reg);
	test<24>(reg);
	test<25>(reg);
	test<26>(reg);
	test<27>(reg);
	test<28>(reg);
	test<29>(reg);
	test<30>(reg);
	test<31>(reg);
	test<32>(reg);
}

int main(int argc, char* argv[])
{
	test();
	return 0;
}

処理結果
 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
 0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
 0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
 0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
 0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
 0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
 0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22
 0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21
 0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
 0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12 13
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11 12
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10 11
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9 10
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8  9
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7  8
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6  7
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5  6
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4  5
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3  4
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2  3
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  2
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0