背景
推荐排序 dnn 模型中经常会用一些特别稀疏的 id 特征,此时需要对这些 id 特征做 embedding 操作,一般在 tensorflow 中都会使用一个 shape=[id_index_size, embedding_size]的 Variable 矩阵做 embedding 参数,然后根据 id 特征的 index 去 Variable 矩阵中查表得到相应的 embedding 表示。这里需要注意的是:id_index_size 的大小一般都不会等于对应 id table 的元素个数,因为有很多 id 元素不在原始的 id table 表中,比如新上架的一些商品等。此时需要将 id_index_size 设置的大一些,以留一些位置给那些不在 id table 表的元素使用。那么从原始的 id 特征值映射成[0, id_index_size)间的 index 是怎么做的呢?
如果 id 量非常小的话,可在特征提取后把 id 排序一遍生成从 0 开始的连续 id 值,但在工业界的场景下 id 特征的量级往往是百万到上亿级别,很难做排序。幸好 tensorFlow 内部有一个函数可以将原始的 id 特征值映射到从 0 开始的 index,这个函数就是 lookup.index_table_from_tensor。
例子
首先举一个 index_table_from_tensor 的使用例子:
import tensorflow as tf
sess = tf.Session()
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.index_table_from_tensor(
vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])
table.init.run(session=sess)
ids = table.lookup(features)
print(sess.run(ids))
返回值:[ 0 1 10 2 9 3 9]
index_table_from_tensor 函数的作用是:如果在 vocabulary_list 里面的 id 特征值,则映射为从 0 到 len(table)-1 之间的 index,如果不在 vocabulary_list 里面,则通过 hash 函数映射到 len(table)至 len(table) + num_oov_buckets - 1 之间的区间中。
我们开始分析源代码,首先看下 index_table_from_tensor 的源码:
def index_table_from_tensor(vocabulary_list,
num_oov_buckets=0,
default_value=-1,
hasher_spec=FastHashSpec,
dtype=dtypes.string,
name=None):
"""Returns a lookup table that converts a string tensor into int64 IDs.
This operation constructs a lookup table to convert tensor of strings into
int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
tensor where each element is a key and corresponding index within the tensor
is the value.
Args:
vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
indices. The type of this object must be castable to `dtype`.
num_oov_buckets: The number of out-of-vocabulary buckets.
default_value: The value to use for out-of-vocabulary feature values.
Defaults to -1.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignment of out-of-vocabulary buckets.
dtype: The type of values passed to `lookup`. Only string and integers are
supported.
name: A name for this op (optional).
Returns:
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
ValueError: If `vocabulary_list` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
“”“
类型检查代码
”“”
with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
# 获取由传入的vocabulary_list构造的hash表,keys是vocabulary_list里面的元素,values是
# 【0, len(num_elements)-1】
keys = ops.convert_to_tensor(vocabulary_list)
num_elements = array_ops.size(keys)
values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = ""
with ops.name_scope(None, "hash_table") as hash_table_scope:
table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
init = KeyValueTensorInitializer(
table_keys,
values,
table_keys.dtype.base_dtype,
dtypes.int64,
name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
if num_oov_buckets:
table = IdTableWithHashBuckets(
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
name=feat_to_id_scope,
key_dtype=dtype)
return table
从上面的代码片段可以看出,index_table_from_tensor 函数体中有三个比较重要的步骤:
根据 keys 和 values 创建 KeyValueTensorInitializer。
根据 KeyValueTensorInitializer 创建具体的 HashTable 类,对于不在 vocabulary_list 表里面的元素,统一返回默认值。
如果 num_oov_buckets>0,则会创建带有 Buckets 的 IdTableWithHashBuckets 类,用一个 hash 函数返回那些不在 vocabulary_list 表里面的 id 元素的 index 值。
首先我们来分析第一步的操作,先通过传入的 vocabulary_list 去构造 keys 和 values,然后传入到 KeyValueTensorInitializer 类中返回 table 初始化的 op,注意此时只是返回一个用于 table 初始化的 OP,还没有做具体的初始化工作(初始化要放到具体的 HashTable 中去完成)。下面看下 KeyValueTensorInitializer 的代码:
class KeyValueTensorInitializer(TableInitializerBase):
"""Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
with ops.name_scope(name, "key_value_init", [keys, values]) as scope:
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
self._values = ops.convert_to_tensor(
values, dtype=value_dtype, name="values")
self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
self._values.dtype)
def initialize(self, table):
"""Initializes the given `table` with `keys` and `values` tensors.
Args:
table: The table to initialize.
Returns:
The operation that initializes the table.
Raises:
TypeError: when the keys and values data types do not match the table
key and value data types.
"""
_check_table_dtypes(table, self._keys.dtype, self._values.dtype)
with ops.name_scope(
self._name, values=(table.table_ref, self._keys,
self._values)) as scope:
init_op = gen_lookup_ops.lookup_table_import_v2(
table.table_ref, self._keys, self._values, name=scope)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
可看出table的初始化操作是通过lookup_table_import_v2这个op来完成,下面就分析下这个op的相关c代码:
REGISTER_OP("LookupTableImportV2")
.Input("table_handle: resource")
.Input("keys: Tin")
.Input("values: Tout")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
return Status::OK();
});
// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
public:
explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
lookup::LookupInterface* table;
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
“”“
....
”“”
const Tensor& keys = ctx->input(1);
const Tensor& values = ctx->input(2);
OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
int memory_used_before = 0;
if (ctx->track_allocations()) {
memory_used_before = table->MemoryUsed();
}
OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
memory_used_before);
}
}
};
这里通过 table_handle 句柄来传入具体的 Table 类,上面举的例子中使用的是 HashTable 类,HashTable 类继承于 InitializableLookupTable 类,InitializableLookupTable 可理解为那些需要初始化操作的 table 集合,而 InitializableLookupTable 又继承于 LookupInterface 类。通过 table->ImportValues()方法来将 vocabulary_list 中构造的 keys 和 values 导入到 hash 表中,完成 hash 表的初始化工作。
通过 LookupTableImportV2 返回 init_op 之后,将该初始化 hash 表的 op 传给 HashTable 类,HashTable 类和其父类的代码如下:
class HashTable(InitializableLookupTableBase):
def __init__(self, initializer, default_value, shared_name=None, name=None):
"""Creates a non-initialized `HashTable` object.
Creates a table, the type of its keys and values are specified by the
initializer.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
Args:
initializer: The table initializer to use. See `HashTable` kernel for
supported key and value types.
default_value: The value to use if a key is missing in the table.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
name: A name for the operation (optional).
Returns:
A `HashTable` object.
"""
with ops.name_scope(name, "hash_table", (initializer,
default_value)) as scope:
table_ref = gen_lookup_ops.hash_table_v2(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer)
self._value_shape = self._default_value.get_shape()
class InitializableLookupTableBase(LookupInterface):
"""Initializable lookup table interface.
An initializable lookup tables persist across different steps.
"""
def __init__(self, table_ref, default_value, initializer):
"""Construct a table object from a table reference.
If requires a table initializer object (subclass of `TableInitializerBase`).
It provides the table key and value types, as well as the op to initialize
the table. The caller is responsible to execute the initialization op.
Args:
table_ref: The table reference, i.e. the output of the lookup table ops.
default_value: The value to use if a key is missing in the table.
initializer: The table initializer to use.
"""
name = table_ref.op.name.split("/")[-1]
super(InitializableLookupTableBase,
self).__init__(initializer.key_dtype, initializer.value_dtype,
name)
self._table_ref = table_ref
self._default_value = ops.convert_to_tensor(
default_value, dtype=self._value_dtype)
self._default_value.get_shape().merge_with(tensor_shape.scalar())
self._init = initializer.initialize(self)
@property
def table_ref(self):
"""Get the underlying table reference."""
return self._table_ref
@property
def default_value(self):
"""The default value of the table."""
return self._default_value
@property
def init(self):
"""The table initialization op."""
return self._init
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
The `default_value` is used for keys not present in the table.
"""
key_tensor = keys
with ops.name_scope(name, "%s_Lookup" % self._name,
(self._table_ref, key_tensor,
self._default_value)) as scope:
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, key_tensor, self._default_value, name=scope)
values.set_shape(key_tensor.get_shape())
if isinstance(keys, sparse_tensor.SparseTensor):
return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
else:
return values
InitializableLookupTableBase 类继承于 HashTable,InitializableLookupTableBase 类中有 init 操作和 look_up,所以表的初始化和 look_up 操作其实都是在父类中完成,其中表的初始化操作上面已经讨论过,我们就重点看下 look_up 操作的相关 OP:
lookup_table_find_v2
// Table lookup op. Perform the lookup operation on the given table.
class LookupTableFindOp : public OpKernel {
public:
explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
lookup::LookupInterface* table;
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
“”“
......
”“”
const Tensor& key = ctx->input(1);
const Tensor& default_value = ctx->input(2);
OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
TensorShape output_shape = key.shape();
output_shape.RemoveLastDims(table->key_shape().dims());
output_shape.AppendShape(table->value_shape());
Tensor* out;
OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
}
};
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
LookupTableFindOp);
# table->Find调用的是HashTable类的DoFind方法
Status DoFind(const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
*table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
}
return Status::OK();
}
最后我们分析下 IdTableWithHashBuckets 类,该类可用一个 hash 函数返回那些不在 vocabulary_list 表里 id 元素的 index 值,先举个例子:
import tensorflow as tf
num_oov_buckets = 3
input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
table = tf.IdTableWithHashBuckets(
tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),
num_oov_buckets)
out = table.lookup(input_tensor).
table.init.run()
print(out.eval())
#filename依次为"emerson", "lake", "palmer"
#结果为[0, 1, 2, 4, 7]
下面看下 IdTableWithHashBuckets 的源码:
class IdTableWithHashBuckets(LookupInterface):
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
For example, if an instance of `IdTableWithHashBuckets` is initialized with a
string-to-id table that maps:
* `emerson -> 0`
* `lake -> 1`
* `palmer -> 2`
The `IdTableWithHashBuckets` object will performs the following mapping:
* `emerson -> 0`
* `lake -> 1`
* `palmer -> 2`
* `<other term> -> bucket_id`, where bucket_id will be between `3` and
`3 + num_oov_buckets - 1`, calculated by:
`hash(<term>) % num_oov_buckets + vocab_size`
If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
the lookup result is `[0, 1, 2, 4, 7]`.
If `table` is None, only out-of-vocabulary buckets are used.
"""
def __init__(self,
table,
num_oov_buckets,
hasher_spec=FastHashSpec,
name=None,
key_dtype=None):
"""Construct a `IdTableWithHashBuckets` object.
Args:
table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignation of out-of-vocabulary buckets (optional).
name: A name for the operation (optional).
key_dtype: Data type of keys passed to `lookup`. Defaults to
`table.key_dtype` if `table` is specified, otherwise `tf.string`.
Must be string or integer, and must be castable to `table.key_dtype`.
Raises:
ValueError: when `table` in None and `num_oov_buckets` is not positive.
TypeError: when `hasher_spec` is invalid.
"""
# If a name ends with a '/' it is a "name scope", remove all trailing '/'
# characters to use as table name.
“”“
类型检查
“”“
self._num_oov_buckets = num_oov_buckets
super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
name.split("/")[-1])
@property
def init(self):
"""The table initialization op."""
if self._table:
return self._table.init
with ops.name_scope(None, "init"):
return control_flow_ops.no_op()
@property
def table_ref(self):
if self._table is not None:
return self._table.table_ref
return None
def _get_string_to_hash_bucket_fn(self, hasher_spec):
"""Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
if not isinstance(hasher_spec, HasherSpec):
raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
if hasher_spec.hasher == "fasthash":
return string_ops.string_to_hash_bucket_fast
if hasher_spec.hasher == "legacy":
return string_ops.string_to_hash_bucket
if hasher_spec.hasher == "stronghash":
return functools.partial(
string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
values = keys
if self._num_oov_buckets == 0:
ids = self._table.lookup(values, name=name)
else:
# TODO(yleon): Consider moving this functionality to its own kernel.
with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
self._hasher_spec)
buckets = str_to_hash_bucket(
_as_string(values),
num_buckets=self._num_oov_buckets,
name="hash_bucket")
if self._table:
ids = self._table.lookup(values)
buckets = math_ops.add(buckets, self._table.size())
is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
else:
ids = buckets
return ids
上面的 lookup 操作为核心代码,大致流程为:先将当前待查找的所有 keys 通过 str_to_hash_bucket 这个 hash 函数映射成 int 类型的 bucket 号,然后为了防止和已经在表中的那些 index 冲突,需要统一加上 table.size()。然后在 table 中查找所有 keys 的 index,如果在 id table 表中会返回正确的 index,如果不在就返回默认值。然后将那些查找后等于默认值的 keys 的 index 换成通过 hash 函数得到的 bucket 号。这样存在的性能问题是当 id table 表较大时,会有很多时间浪费在没必要的 hash 操作上。所以一种优化方案是:先查表得到所有 keys 的 index,然后只 hash 那些映射值等于默认值的元素。
我自己基于这个思路实现了一版,暂且就称为 lookup_v2 吧,但是虽然这样不用为所有的 keys 做 hash,但也增加了几个操作,所以性能还不能保证。后续会考虑直接优化 c++层面的代码。下面是核心代码:
def lookup_v2(self, keys, name=None):
values = keys
if self._num_oov_buckets == 0:
ids = self._table.lookup(values, name=name)
else:
with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
self._hasher_spec)
if self._table:
init_ids = self._table.lookup(values)
is_id_default_idx = array_ops.where(math_ops.equal(init_ids, self._table.default_value))
hash_values = array_ops.gather(values, array_ops.squeeze(is_id_default_idx))
default_buckets = str_to_hash_bucket(
_as_string(hash_values),
num_buckets=self._num_oov_buckets,
name="hash_bucket")
default_buckets = math_ops.add(default_buckets, self._table.size())
default_buckets = control_flow_ops.cond(gen_math_ops.equal(array_ops.size(default_buckets), 1),
lambda: array_ops.expand_dims(default_buckets, axis=0),
lambda: default_buckets)
ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx, default_buckets)
else:
ids = str_to_hash_bucket(_as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket")
return ids
参考文献:
https://mathmach.com/2442aa9e/
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/93116229
更多内容推荐
MySQL 系列教程之(三) MySQL 基本概念和操作
MySQL 系列教程之(三) MySQL 基本概念和操作
2021-08-14
elasticsearch 实战三部曲之二:文档操作,java 基础填空题
"_id":"999"
2021-10-31
我从 LongAdder 中窥探到了高并发的秘籍,上面只写了两个字...
LongAdder 是怎么解决多线程操作热点 value 导致并发修改冲突很大这个问题的? 为什么高并发场景下 LongAdder 的 sum 方法不能返回一个准确的值? 为什么高并发场景下 LongAdder 的写性能比 AtomicLong 高?
2020-07-16
32|存储引擎:数据清洗与存储
这节课,我们一起写一个存储引擎,用它来处理数据的存储问题。
2022-12-22
如何通过哈希查找 JS 对象内存地址?
通过这一讲,你也能更好地理解哈希、散列表、字典,这些初学者都比较容易混淆的概念……
2022-10-22
五、HikariCP 源码分析之初始化分析二
HikariCP一直以高效著称,但是从来没有去研究过为什么会比其他的数据库连接池高效。后来为了排查一个数据库连接池的问题,就深入了解了一下HikariCP的源代码,然后就有了这个深入浅出的源码解析系列,不仅解释是什么,还让你知道为什么。
2022-07-29
如何在 iOS 中解决循环引用的问题
我们经常会遇到循环引用的问题
上古时代 Objective-C 中哈希表的实现
文章会介绍上古时代 Objective-C 哈希表
spring4.1.8 初始化源码学习三部曲之三:AbstractApplicationContext.refresh 方法
《spring4.1.8初始化源码学习三部曲》系列的终篇,重点是学习AbstractApplicationContext类的refresh()方法
2022-06-09
23|OpenClip:让我们搞清楚图片说了些什么
OpenClip:让我们搞清楚图片说了些什么?
2023-04-28
第四范式 OpenMLDB: 拓展 Spark 源码实现高性能 Join
OpenMLDB是针对AI场景优化的开源数据库项目,实现了数据与计算一致性的离线MPP场景和在线OLTP场景计算引擎。
字典的函数操作
2022-12-29
04|新时代模型性能大比拼,GPT-3 到底胜在哪里?
这一讲我们一起使用 Fasttext、T5-small 和 T5-base 这三个预训练模型,做零样本分类测试。
2023-03-27
TensorFlow Ranking 框架在海外推荐业务中的实践与应用
爱奇艺海外推荐业务引入TensorFlow Ranking(TFR)框架,并在此基础上进行了研究和改进,显著提升了推荐效果。本文将分享TFR框架在海外推荐业务中的实践和应用。
HBase Bulkload 实践探讨
本文来自《2019年有赞技术大礼包》系列。
19|协同过滤:召回算法中永远不落幕的经典
在前面的章节中,我们讲解了数据、算法以及简单的推荐服务,从本章开始,我们将开启一个全新的篇章:算法。
2023-05-29
推理性能提升一倍,TensorFlow Feature Column 性能优化实践
在CTR(Click Through Rate)点击率预估的推荐算法场景,TensorFlow Feature Column被广泛应用到实践中。
leetcode 769. Max Chunks To Make Sorted 最多能完成排序的块 (中等)
思路:从左往右遍历,同时记录当前的最大值,每当当前最大值等于数组位置时,我们可以多一次分割。
2022-08-06
微信扫码登录技术实现的简单思考
微信扫码登录是经常用到的的骚操作,但是,其实现的思路是怎样的,可能很多人都没有去思考过。
2021-03-26
推荐阅读
c++11 分边在两个 map 中执行相同操作,代码如何优化
2023-04-27
C++ 学习 ---cstdio 的源码学习分析 04- 创建临时文件函数 tmpfile
2022-09-20
11. FlinkSQL 的自定义 UDAF 函数
2023-09-08
数据库内核杂谈(三十六)- 向量数据库(4)quantization 和 HNSW
数据库3. 复杂查询:JOIN - USING 用法
2023-09-26
【内存操作函数内功修炼】memcpy + memmove + memcmp + memset(四)
2022-09-21
22|YouTubeDNN:召回算法的后起之秀(下)
2023-06-05
电子书
大厂实战PPT下载
换一换 张鑫 | 微软亚洲研究院 研发工程师
宋恺涛 | 微软亚洲研究院 高级研究员
邓波 | 淘天集团 高级技术专家
评论