BitVector の select に SSE2 を使ってみる
marisa-trie に SSE2, SSSE3 を使ってみた - やた@はてな日記,marisa-trie に SSE2, SSSE3 を使ってみた(続き) - やた@はてな日記 の続きです.
はじめに
BitVector の select に SSE2 を導入することで高速化したわけですが,本記事では select のどこで SSE2 命令を使うようにしたのかを説明します.
まずは,本記事における BitVector と select の定義を説明します.それから,select の実装について軽く紹介してから,SSE2 命令を導入した箇所を説明して,導入前後の実装を比較します.
定義
BitVector
BitVector はビット (0/1) の配列です.本記事では,64-bit 環境を想定しているので,64-bit 整数の配列によって実現します.たとえば,長さ 65,536 bits の BitVector であれば,64-bit 整数 1,024 個からなる配列により表現されます.
なお,64-bit 整数の最下位ビット(LSB)を先頭のビット,最上位ビット(MSB)を末尾のビットと考えます.
// Return the "i"th bit. std::uint64_t get_ith_bit(const std::uint64_t *bits, std::size_t i) { return (bits[(i - 1) / 64] >> ((i - 1) % 64)) & 1; }
select
select は BitVector に含まれる i 番目の 1 になっているビットを検索して,その位置を返す関数です.marisa-trie においては,select の実行時間が検索時間に占める割合が高いため,select の高速化が検索速度の向上につながります.
// Return the position of the "ith" 1 from "bits". std::size_t select(const std::uint64_t *bits, std::size_t i); // For example, if "bits" are "10010110010...", // select(bits, 1) returns 1, select(bits, 2) returns 4, // and select(bits, 3) returns 6.
SSE2 の導入箇所
select の実装においては,BitVector 全体から徐々に範囲を絞り込んでいき,最終的には 64-bit 整数に含まれる i 番目の 1 になっているビットを検索する問題に行き着きます.そして,この最終段階において SSE2 を導入します.
// Find the position of the "i"th 1 in a 64-bit integer "bits". std::size_t find_ith_1_bit(std::uint64_t bits, std::size_t i); // For example, // find_ith_1_bit(1 << 8, 1) returns 9, find_ith_bit(-1, 64) returns 64, // and find_ith_1_bit(0x1111, 2) returns 5.
以降,find_ith_1_bit に絞って議論を進めていきます.
旧実装(SSE2 なし)
SSE2 導入前の実装では,下位 8, 16, 24, ..., 64 bits に含まれる 1 の数を計算し,それらを使って二分探索をおこなうことで,8 bits まで範囲を絞り込んでいました.8 bits まで絞り込んだ後は,事前に計算しておける程度の組み合わせしか存在しないため,あらかじめ計算しておいたテーブルを使って解決します.
本記事のポイントは,二分探索の条件分岐です.分岐予測に失敗したときのコストが 20 clocks くらいと考えれば,予測の難しい分岐は厄介な存在であり,分岐を取り除くことができれば速くなりそうだと考えられます.
#include <cstdint> constexpr std::uint64_t MASK_55 = 0x5555555555555555ULL; constexpr std::uint64_t MASK_33 = 0x3333333333333333ULL; constexpr std::uint64_t MASK_0F = 0x0F0F0F0F0F0F0F0FULL; constexpr std::uint64_t MASK_01 = 0x0101010101010101ULL; constexpr std::uint8_t SELECT_TABLE[8][256] = { // Omitted. }; std::size_t find_ith_1_bit(std::uint64_t bits, std::size_t i) { // See http://en.wikipedia.org/wiki/Hamming_weight for details. std::uint64_t counts; counts = bits - ((bits >> 1) & MASK_55); counts = (counts & MASK_33) + ((counts >> 2) & MASK_33); counts = (counts + (counts >> 4)) & MASK_0F; counts *= MASK_01; // Binary search. std::uint8_t offset; if (i <= ((counts >> 24) & 0xFF)) { if (i <= ((counts >> 8) & 0xFF)) { if (i <= (counts & 0xFF)) { offset = 0; } else { offset = 8; bits >>= 8; i -= (counts & 0xFF); } } else if (i <= ((counts >> 16) & 0xFF)) { offset = 16; bits >>= 16; i -= ((counts >> 8) & 0xFF); } else { offset = 24; bits >>= 24; i -= ((counts >> 16) & 0xFF); } } else if (i <= ((counts >> 40) & 0xFF)) { if (i <= ((counts >> 32) & 0xFF)) { offset = 32; bits >>= 32; i -= ((counts >> 24) & 0xFF); } else { offset = 40; bits >>= 40; i -= ((counts >> 32) & 0xFF); } } else if (i <= ((counts >> 48) & 0xFF)) { offset = 48; bits >>= 48; i -= ((counts >> 40) & 0xFF); } else { offset = 56; bits >>= 56; i -= ((counts >> 48) & 0xFF); } return offset + SELECT_TABLE[i - 1][bits & 0xFF] + 1; }
新実装(SSE2 あり)
旧実装に SSE2 を導入して条件分岐を取り除いたのが新実装です.単純に見比べてやれば,条件分岐がなくなっていることが分かります.SSE2 命令と対応する関数は以下の通りです.
- r = _mm_cvtsi64_si128(a)
- XMM レジスタの下位 64-bit に a を格納して,上位 64-bit を 0 で埋めます.
- r.qwords[0] = a, r.qwords[1] = 0;
- http://msdn.microsoft.com/ja-jp/library/bb514072%28v=vs.100%29.aspx
- r = _mm_cmpgt_epi8(a, b)
- a, b の各 byte を 8-bit 符号あり整数として比較します.
- 返り値は a > b となる byte を 0xFF で埋め,それ以外の部分を 0x00 で埋めたものとなります.
- r.bytes[i] = (a.bytes[i] > b.bytes[i]) ? 0xFF : 0x00;
- http://msdn.microsoft.com/ja-jp/library/wf45zt2b%28v=vs.100%29.aspx
- r = _mm_movemask_epi8(a)
- a を構成する 16 bytes から MSB のみを集めて 16-bit 整数を作成します.
- r.bits[i] = (a.bytes[i] & 0x80) ? 1 : 0;
- http://msdn.microsoft.com/ja-jp/library/s090c8fk%28v=vs.100%29.aspx
i の値域は 0 以上 64 以下なので,i * MASK_01 は i の下位 8-bit を 8 個並べてできる 64-bit 整数になります.一方,counts には "bits" の下位 8, 16, 24, ..., 64 bits に含まれる 1 の数が格納されています.そのため,これらの各 byte を比較することで,どの byte に i 番目の 1 が含まれるのかを求めることができます.
#include <emmintrin.h> // For SSE2 instructions. constexpr std::uint8_t POPCNT_TABLE[256] = { // Omitted. }; std::size_t find_ith_1_bit(std::uint64_t bits, std::size_t i) { // See http://en.wikipedia.org/wiki/Hamming_weight for details. std::uint64_t counts; counts = bits - ((bits >> 1) & MASK_55); counts = (counts & MASK_33) + ((counts >> 2) & MASK_33); counts = (counts + (counts >> 4)) & MASK_0F; counts *= MASK_01; __m128i x = _mm_cvtsi64_si128(i * MASK_01); __m128i y = _mm_cvtsi64_si128(counts); x = _mm_cmpgt_epi8(x, y); const std::uint8_t offset = POPCNT_TABLE[_mm_movemask_epi8(x)]; bits >>= offset; i -= ((counts << 8) >> offset) & 0xFF; return offset + SELECT_TABLE[i - 1][bits & 0xFF] + 1; }
おわりに
途中で PMOVMSKB より BSR を使った方が速そうなことに気づいたのですが,時間もないのでそのままにしました.次回,BSR を使った結果を記事にしようと思います.
(追記 2013-02-07)分岐をなくせば速くなるという保証はありません.たとえば,入力パターンが特殊で分岐予測が失敗しないのであれば,SSE2 を使わないほうが速くなりそうです.
おまけ
本記事のソースコードにおいて省略したテーブルの中身を以下に示します.
constexpr std::uint8_t SELECT_TABLE[8][256] = { { 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0 }, { 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1 }, { 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2 }, { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, 7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, 7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3 }, { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4 }, { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5 }, { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6 }, { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 } };
constexpr std::uint8_t POPCNT_TABLE[256] = { 0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, 32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64 };