Inspired by Robert's answer, polynomial multiplication in ARM NEON can be utilised to scatter the bits: squaring a polynomial over GF(2) has no cross terms, so vmull_p8(b, b) moves bit i of each byte to bit 2i of the corresponding 16-bit lane, interleaving the bits with zeros.
inline poly8x16_t mull_lo(poly8x16_t a) {
auto b = vget_low_p8(a);
return vreinterpretq_p8_p16(vmull_p8(b,b));
}
inline poly8x16_t mull_hi(poly8x16_t a) {
auto b = vget_high_p8(a);
return vreinterpretq_p8_p16(vmull_p8(b,b));
}
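// Each carry-less squaring doubles the spacing between the original bits:
// after a the bits of each source byte sit at even positions (stride 2),
// after b/c at stride 4, and after d..g every byte holds a single source bit
// at bit position 0.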
auto a = mull_lo(word);
auto b = mull_lo(a), c = mull_hi(a);
auto d = mull_lo(b), e = mull_hi(b);
auto f = mull_lo(c), g = mull_hi(c);
Then vsli_n_p8 (shift left and insert) can be used to combine the bits pairwise:
auto ab = vsli_n_p8(vget_high_p8(d), vget_low_p8(d), 1);
auto cd = vsli_n_p8(vget_high_p8(e), vget_low_p8(e), 1);
auto ef = vsli_n_p8(vget_high_p8(f), vget_low_p8(f), 1);
auto gh = vsli_n_p8(vget_high_p8(g), vget_low_p8(g), 1);
auto abcd = vsli_n_p8(ab, cd, 2);
auto efgh = vsli_n_p8(ef, gh, 2);
return vsli_n_p8(abcd, efgh, 4);
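A minimal sketch of how the fragment might be called; the wrapper name bit_interleave8 and the zero-padded load are my assumptions, and the wrapper stands for a function whose body is exactly the statements above, taking word as its parameter and ending in the return:

#include <arm_neon.h>
#include <stdint.h>

// Hypothetical wrapper around the fragment above.
poly8x8_t bit_interleave8(poly8x16_t word);

void demo(const uint8_t rows[8], uint8_t out[8])
{
    // only the low 8 bytes of word are consumed (mull_lo(word)),
    // so the high half can be anything; zero it here
    poly8x16_t word = vreinterpretq_p8_u8(vcombine_u8(vld1_u8(rows), vdup_n_u8(0)));
    vst1_u8(out, vreinterpret_u8_p8(bit_interleave8(word)));
}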
Clang optimizes this code, but it avoids the pmull2 form (vmull_high_p8), instead making heavy use of ext q0, q0, q0, #8 to materialise vget_high_p8.
An iterative approach is possibly not only faster, but it also uses fewer registers and SIMDifies for 2x or more throughput.
// transpose bits within 2x2 blocks, first 4 rows shown
// x = a b|c d|e f|g h      a i|c k|e m|g o  | byte 0
//     i j|k l|m n|o p  ->  b j|d l|f n|h p  | byte 1
//     q r|s t|u v|w x      q A|s C|u E|w G  | byte 2
//     A B|C D|E F|G H      r B|t D|v F|x H  | byte 3 ...
// ----------------------
auto a = (x & 0x00aa00aa00aa00aaull);
auto b = (x & 0x5500550055005500ull);
auto c = (x & 0xaa55aa55aa55aa55ull) | (a << 7) | (b >> 7);
// transpose 2x2 blocks (first 4 rows shown)
// aa bb cc dd aa ii cc kk
// ee ff gg hh -> ee mm gg oo
// ii jj kk ll bb jj dd ll
// mm nn oo pp ff nn hh pp
auto d = (c & 0x0000cccc0000ccccull);
auto e = (c & 0x3333000033330000ull);
auto f = (c & 0xcccc3333cccc3333ull) | (d << 14) | (e >> 14);
// Final transpose of 4x4 bit blocks
auto g = (f & 0x00000000f0f0f0f0ull);
auto h = (f & 0x0f0f0f0f00000000ull);
x = (f & 0xf0f0f0f00f0f0f0full) | (g << 28) | (h >> 28);
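Packaged as a self-contained scalar routine (the function name and the uint64_t typing of the autos are mine), together with a check against the bit-by-bit definition of the transpose, where byte r of x is row r and bit c of that byte is column c:

#include <assert.h>
#include <stdint.h>

static inline uint64_t transpose8x8_swar(uint64_t x)
{
    // transpose bits within 2x2 blocks
    uint64_t a = x & 0x00aa00aa00aa00aaull;
    uint64_t b = x & 0x5500550055005500ull;
    uint64_t c = (x & 0xaa55aa55aa55aa55ull) | (a << 7) | (b >> 7);
    // transpose 2x2 blocks of bit pairs
    uint64_t d = c & 0x0000cccc0000ccccull;
    uint64_t e = c & 0x3333000033330000ull;
    uint64_t f = (c & 0xcccc3333cccc3333ull) | (d << 14) | (e >> 14);
    // final transpose of 4x4 bit blocks
    uint64_t g = f & 0x00000000f0f0f0f0ull;
    uint64_t h = f & 0x0f0f0f0f00000000ull;
    return (f & 0xf0f0f0f00f0f0f0full) | (g << 28) | (h >> 28);
}

static void check_transpose8x8(uint64_t x)
{
    uint64_t y = transpose8x8_swar(x);
    for (int r = 0; r < 8; ++r)
        for (int c = 0; c < 8; ++c)
            assert(((y >> (8 * c + r)) & 1) == ((x >> (8 * r + c)) & 1));
}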
On ARM each step can now be composed of 3 instructions (the vreinterpret casts between the u8/u16/u32 views are omitted for brevity):
auto tmp = vrev16_u8(x);
tmp = vshl_u8(tmp, plus_minus_1); // 0xff01ff01ff01ff01ull
x = vbsl_u8(mask_1, x, tmp); // 0xaa55aa55aa55aa55ull
tmp = vrev32_u16(x);
tmp = vshl_u16(tmp, plus_minus_2); // 0xfefe0202fefe0202ull
x = vbsl_u8(mask_2, x, tmp); // 0xcccc3333cccc3333ull
tmp = vrev64_u32(x);
tmp = vshl_u32(tmp, plus_minus_4); // 0xfcfcfcfc04040404ull
x = vbsl_u8(mask_4, x, tmp); // 0xf0f0f0f00f0f0f0full
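For completeness, a fully typed sketch of the same three steps with the casts spelled out; the function name and the use of vcreate_* to build the constants are my choices, while the constants themselves are the ones quoted in the comments above:

#include <arm_neon.h>
#include <stdint.h>

static inline uint8x8_t transpose8x8_neon(uint8x8_t x)
{
    // per-lane shift counts (positive = left, negative = right) and vbsl masks
    const int8x8_t  plus_minus_1 = vcreate_s8(0xff01ff01ff01ff01ull);
    const int16x4_t plus_minus_2 = vcreate_s16(0xfefe0202fefe0202ull);
    const int32x2_t plus_minus_4 = vcreate_s32(0xfcfcfcfc04040404ull);
    const uint8x8_t mask_1 = vcreate_u8(0xaa55aa55aa55aa55ull);
    const uint8x8_t mask_2 = vcreate_u8(0xcccc3333cccc3333ull);
    const uint8x8_t mask_4 = vcreate_u8(0xf0f0f0f00f0f0f0full);

    // 2x2 bit blocks: vrev16 swaps adjacent rows, vshl shifts them by +/-1,
    // vbsl keeps the bits of x selected by mask_1 and takes the rest from tmp
    uint8x8_t tmp = vshl_u8(vrev16_u8(x), plus_minus_1);
    x = vbsl_u8(mask_1, x, tmp);

    // 2x2 blocks of bit pairs: vrev32 swaps row pairs, shift by +/-2
    uint16x4_t tmp16 = vshl_u16(vrev32_u16(vreinterpret_u16_u8(x)), plus_minus_2);
    x = vbsl_u8(mask_2, x, vreinterpret_u8_u16(tmp16));

    // final 4x4 bit blocks: vrev64 swaps row quads, shift by +/-4
    uint32x2_t tmp32 = vshl_u32(vrev64_u32(vreinterpret_u32_u8(x)), plus_minus_4);
    x = vbsl_u8(mask_4, x, vreinterpret_u8_u32(tmp32));

    return x;
}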