ハッシュ表にトライ(Trie on Hash)

ハッシュ表に挑戦したという意味ではなく,「ハッシュ表の上にトライを構築してみました」という話です.ダブル配列のアイデア(CHECK に遷移元のインデックスを保存)を利用しています.

ダブル配列は,小さな整数(基本的に Byte)をラベルとする場合には高性能なのですが,ラベルが大きくなると配置問題が難しくて悩みものです.配列がスパースになって,さらに構築も遅くなるので,できれば使いたくありません.

というわけで,ハッシュ表にトライをのっけてみました.これなら,ハッシュ関数を定義できるような任意のオブジェクトをラベルとして利用できます.

# 文字 n-gram の頻度集計を目的として作成しました.考えてみると,トリプル配列 → ダブル配列 → トリプル配列という流れになっています.

#ifndef ALGO_HASH_TRIE_H
#define ALGO_HASH_TRIE_H

#include <algorithm>
#include <functional>
#include <stack>
#include <vector>

namespace algo {

template <typename LabelType, typename ValueType, typename HashFuncType,
	class EqualFuncType = std::equal_to<LabelType> >
class HashTrie
{
public:
	typedef LabelType Label;
	typedef ValueType Value;
	typedef HashFuncType HashFunc;
	typedef EqualFuncType EqualFunc;

	typedef int Index;
	typedef std::pair<Index, bool> Result;

	class Node
	{
	public:
		Node() : from_(NONE), label_(), value_() {}

		void set_from(Index from) { from_ = from; }
		void set_label(const Label &label) { label_ = label; }
		void set_value(const Value &value) { value_ = value; }

		Index from() const { return from_; }
		const Label &label() const { return label_; }
		const Value &value() const { return value_; }

	private:
		Index from_;
		Label label_;
		Value value_;

		// Copyable.
	};

	typedef std::vector<Node> Table;

	static const Index NONE = -1;
	static const Index INITIAL_TABLE_SIZE = 2;

public:
	explicit HashTrie(const HashFunc &hash = HashFunc())
		: table_(), num_keys_(), hash_(hash), equal_(EqualFunc()) { Clear(); }

	// Adds transitions corresponding to the given labels.
	template <typename LabelRangeType>
	Result Add(Index from, const LabelRangeType &label_range)
	{
		Result result(from, false);
		for (typename LabelRangeType::const_iterator it = label_range.begin();
			it != label_range.end(); ++it)
			result = Add(result.first, *it);
		return result;
	}
	// Searches the next node, and if it does not exists, adds a new node.
	Result Add(Index from, const Label &label)
	{
		Index next = FindNext(from, label);
		bool is_new_node = table_[next].from() == NONE;
		if (is_new_node)
		{
			table_[next].set_from(from);
			table_[next].set_label(label);
			++num_keys_;

			// Expands the hash table if the filling rate >= 75%.
			Index upper_limit = table_size() - (table_size() >> 2);
			if ((num_keys_ + 1) >= upper_limit)
				ExpandHashTable(table_size() << 1, &next);
		}
		return Result(next, is_new_node);
	}

	// Follows transitions corresponding to the given labels.
	template <typename LabelRangeType>
	Index Find(Index from, const LabelRangeType &label_range) const
	{
		Index index = from;
		for (typename LabelRangeType::const_iterator it = label_range.begin();
			index != NONE && it != label_range.end(); ++it)
			index = Find(index, *it);
		return index;
	}
	// Searches the next node, and if it does not exists, returns NONE.
	Index Find(Index from, const Label &label) const
	{
		Index next = FindNext(from, label);
		if (table_[next].from() != NONE)
			return next;
		return NONE;
	}

	// Finds the next valid index.
	Index Next(Index index) const
	{
		if (index == NONE || index >= table_size())
			return NONE;
		while (++index < table_size())
		{
			if (table_[index].from() != NONE)
				return index;
		}
		return NONE;
	}
	// Restores an index list.
	template <typename IndexContainerType>
	void Restore(Index index, IndexContainerType *index_container) const
	{
		index_container->clear();
		while (index != 0)
		{
			index_container->push_back(index);
			index = table_[index].from();
		}
		std::reverse(index_container->begin(), index_container->end());
	}

	// Accesses a node of HashTrie (writable).
	Node &operator[](Index index) { return table_[index]; }
	// Accesses a node of HashTrie (read-only).
	const Node &operator[](Index index) const { return table_[index]; }

	// Returns the number of keys added.
	Index num_keys() const { return num_keys_; }
	// Returns the table size.
	Index table_size() const { return static_cast<Index>(table_.size()); }

	// Expands the hash table.
	void Reserve(Index new_table_size)
	{
		ExpandHashTable(new_table_size, NULL);
	}

	// Clears the table.
	void Clear()
	{
		Index current_table_size = std::max(INITIAL_TABLE_SIZE, table_size());
		table_.clear();
		table_.resize(current_table_size);
		table_[0].set_from(0);
		num_keys_ = 0;
	}

	// Swaps objects.
	void Swap(HashTrie *target)
	{
		table_.swap(target->table_);
		std::swap(num_keys_, target->num_keys_);
		std::swap(hash_, target->hash_);
		std::swap(equal_, target->equal_);
	}

private:
	Table table_;
	Index num_keys_;
	HashFunc hash_;
	EqualFunc equal_;

	// Disallows copies.
	HashTrie(const HashTrie &);
	HashTrie &operator=(const HashTrie &);

	// Expands the hash table.
	void ExpandHashTable(Index new_table_size, Index *watched_index)
	{
		std::stack<Index> index_stack;
		std::vector<Index> from_table(table_size(), NONE);

		Index old_table_size = table_size();
		Table old_table(new_table_size);
		table_.swap(old_table);
		table_[0].set_from(0);
		from_table[0] = 0;

		for (Index index = 1; index < old_table_size; ++index)
		{
			if (old_table[index].from() == NONE)
				continue;

			// Modified (2009-09-18).
			for (Index old_index = index; from_table[old_index] == NONE; )
			{
				index_stack.push(old_index);
				old_index = old_table[old_index].from();
			}

			while (!index_stack.empty())
			{
				Index old_index = index_stack.top();
				const Node &old_node = old_table[old_index];

				Index new_from = from_table[old_node.from()];
				Index new_index = FindNext(new_from, old_node.label());

				table_[new_index].set_from(new_from);
				table_[new_index].set_label(old_node.label());
				table_[new_index].set_value(old_node.value());

				from_table[old_index] = new_index;
				if (watched_index != NULL && old_index == *watched_index)
				{
					*watched_index = new_index;
					watched_index = NULL;
				}
				index_stack.pop();
			}
		}
	}

	// Finds the next node.
	Index FindNext(Index from, const Label &label) const
	{
		unsigned long long source_value =
			(static_cast<unsigned long long>(from) << 32) ^ hash_(label);
		Index next = Hash(source_value) % table_size();
		while (table_[next].from() != NONE)
		{
			if (next != 0)
			{
				if (from == table_[next].from() &&
					equal_(label, table_[next].label()))
					break;
			}
			next = (next + 1) % table_size();
		}
		return next;
	}

	// Calculates 32-bit hash value from a 64-bit integer.
	// http://www.concentric.net/~Ttwang/tech/inthash.htm
	static unsigned int Hash(unsigned long long x)
	{
		x = (~x) + (x << 18);
		x = x ^ (x >> 31);
		x = x * 21;
		x = x ^ (x >> 11);
		x = x + (x << 6);
		x = x ^ (x >> 22);
		return x;
	}
};

}  // namespace algo

#endif  // ALGO_HASH_TRIE_H

追記(2009-09-18):関数 ExpandHashTable のバグを修正しました.

整数をラベルにする場合,ハッシュ関数はラベルをそのまま返すだけで問題ありません.内部で from との合成にハッシュ関数を用いるため,むしろ冗長になってしまいます.ユーザ定義のハッシュ関数が返すべき型については,整数であれば何でも大丈夫なはずです.

また,トライのルートはインデックス 0 と対応しています.Add() や Find() でルートを始点にする場合,from に 0 を指定してください.Next() を使うときも,最初に 0 で呼び出してあげれば OK です.

# 毎度,n-yo さんにはお世話になっています.