マージ用に優先順序付きキューを少しだけ効率化

データの規模が大きくなってマージのコストが見過ごせなくなってきたため,少しでも効率を良くするべく,優先順序付きキュー(std::priority_queue)に手をつけてみました.

# 最後のマージは並列化できないので深刻な問題です.後,ヒープは実装が楽だから,というのも理由の一つです.

std::priority_queue を用いたマージで気になる点は,ヒープキューなら pop() と push() をまとめて処理できるはずなのに,pop() と push() を別々に呼び出さなければならないことです.

  • pop(): 先頭の要素を削除します.
    • 先頭の要素を取り除いた後,下にある要素を上にずらすことで隙間を埋めます.
  • push(): 新しい要素を追加します.
    • 末尾に新しい要素を付け足した後,その要素を上に移動することでヒープにおける要素の大小関係を維持します.
  • pop() + push(): 先頭の要素を削除して,新しい要素を追加します.
    • 先頭の要素を新しい要素で置き換えた後,その要素を下に移動することでヒープにおける要素の大小関係を維持します.

つまり,pop() と push() を別々に呼び出すと要素を 2 回も移動しなければならないのに対して,pop() と push() をまとめれば要素の移動を 1 回に抑えられます.結果として,多少は速くなるはずというわけです.

そして,実際に試してみたところ,マージにかかる時間を約 25% 短縮することができました.入力は gzip で圧縮されているファイルが 11 個で,各ファイルのサイズは 100MB 前後,マージは行単位という設定(N-gram コーパスのマージ処理)です.

// std::priority_queue を使った場合
$ time merge-ngrams-old ngrams/*.gz > /dev/null
real  4m9.242s
user  4m45.300s
sys  0m8.590s

// 自作の HeapQueue を使った場合
$ time merge-ngrams-new ../web-corpus/ngrams/*.gz > /dev/null
real  3m0.163s
user  3m41.420s
sys  0m8.180s


今回の実験に用いた優先順序付きキューの実装は以下のようになっています.手抜きしているので,使う場合は要注意です.ポイントは,popPush() を使うことにより,pop() と push() を別々に呼び出す無駄をなくせることです.

#include <functional>
#include <utility>
#include <vector>

// T:
//  キューに格納する要素の型を指定します.
// LessThan:
//  T を比較するための関数を指定します.
//  デフォルトでは T::operator<() を使います.
template <typename T, typename LessThan = std::less<T> >
class HeapQueue
{
public:
  HeapQueue() : buf_() {}
  ~HeapQueue() {}

  // キューを空っぽの状態にします.
  void clear() { buf_.clear(); }

  // 新しい要素を追加します.
  void push(const T &value);
  // 先頭の要素を削除します.
  void pop();

  // 先頭の要素を削除して,新しい要素を追加します.
  void popPush(const T &value);

  // 先頭の要素を取得します.削除はしません.
  const T &top() const { return buf_[getRootIndex()]; }

  // キューが空っぽなら true,空っぽじゃなければ false を返します. 
  bool empty() const { return buf_.empty(); }
  // キューに格納されている要素の数を返します.
  std::size_t size() const { return buf_.size(); }

private:
  std::vector<T> buf_;

  static std::size_t getRootIndex() { return 0; }
  static std::size_t getParentIndex(std::size_t index)
  { return (index - 1) / 2; }
  static std::size_t getChildIndex(std::size_t index)
  { return (index * 2) + 1; }

  // Disallows copies.
  HeapQueue(const HeapQueue &);
  HeapQueue &operator=(const HeapQueue &);
};

template <typename T, typename LessThan>
void HeapQueue<T, LessThan>::push(const T &value)
{
  std::size_t index = buf_.size();
  buf_.resize(buf_.size() + 1);

  // 末尾の要素から上方向に辿り,新しい要素の格納場所を用意します.
  while (index > getRootIndex())
  {
    std::size_t parent_index = getParentIndex(index);
    if (!LessThan()(value, buf_[parent_index]))
      break;

    buf_[index] = buf_[parent_index];
    index = parent_index;
  }
  buf_[index] = value;
}

template <typename T, typename LessThan>
void HeapQueue<T, LessThan>::pop()
{
  const T &value = buf_.back();
  popPush(value);
  buf_.pop_back();
}

template <typename T, typename LessThan>
void HeapQueue<T, LessThan>::popPush(const T &value)
{
  // ルートから下方向に辿り,新しい要素の格納場所を用意します.
  std::size_t index = getRootIndex();
  for ( ; ; )
  {
    std::size_t child_index = getChildIndex(index);
    if (child_index >= buf_.size())
      break;

    if (child_index + 1 < buf_.size() &&
      LessThan()(buf_[child_index + 1], buf_[child_index]))
      ++child_index;

    if (!LessThan()(buf_[child_index], value))
      break;

    buf_[index] = buf_[child_index];
    index = child_index;
  }
  buf_[index] = value;
}