写点什么

分布式 tensorflow 源码解读 3:lookup.index_table_from_tensor

  • 2019-11-28
  • 本文字数:9943 字

    阅读完需:约 33 分钟

分布式tensorflow源码解读3:lookup.index_table_from_tensor

背景

推荐排序 dnn 模型中经常会用一些特别稀疏的 id 特征,此时需要对这些 id 特征做 embedding 操作,一般在 tensorflow 中都会使用一个 shape=[id_index_size, embedding_size]的 Variable 矩阵做 embedding 参数,然后根据 id 特征的 index 去 Variable 矩阵中查表得到相应的 embedding 表示。这里需要注意的是:id_index_size 的大小一般都不会等于对应 id table 的元素个数,因为有很多 id 元素不在原始的 id table 表中,比如新上架的一些商品等。此时需要将 id_index_size 设置的大一些,以留一些位置给那些不在 id table 表的元素使用。那么从原始的 id 特征值映射成[0, id_index_size)间的 index 是怎么做的呢?


如果 id 量非常小的话,可在特征提取后把 id 排序一遍生成从 0 开始的连续 id 值,但在工业界的场景下 id 特征的量级往往是百万到上亿级别,很难做排序。幸好 tensorFlow 内部有一个函数可以将原始的 id 特征值映射到从 0 开始的 index,这个函数就是 lookup.index_table_from_tensor。

例子

首先举一个 index_table_from_tensor 的使用例子:


import tensorflow as tf
sess = tf.Session()vocabulary_list = tf.constant(["emerson", "lake", "palmer"])table = tf.contrib.lookup.index_table_from_tensor( vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])table.init.run(session=sess)ids = table.lookup(features)print(sess.run(ids))
返回值:[ 0 1 10 2 9 3 9]
复制代码


index_table_from_tensor 函数的作用是:如果在 vocabulary_list 里面的 id 特征值,则映射为从 0 到 len(table)-1 之间的 index,如果不在 vocabulary_list 里面,则通过 hash 函数映射到 len(table)至 len(table) + num_oov_buckets - 1 之间的区间中。


我们开始分析源代码,首先看下 index_table_from_tensor 的源码:


def index_table_from_tensor(vocabulary_list,                            num_oov_buckets=0,                            default_value=-1,                            hasher_spec=FastHashSpec,                            dtype=dtypes.string,                            name=None):  """Returns a lookup table that converts a string tensor into int64 IDs.  This operation constructs a lookup table to convert tensor of strings into  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D  tensor where each element is a key and corresponding index within the tensor  is the value.
Args: vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to indices. The type of this object must be castable to `dtype`. num_oov_buckets: The number of out-of-vocabulary buckets. default_value: The value to use for out-of-vocabulary feature values. Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for assignment of out-of-vocabulary buckets. dtype: The type of values passed to `lookup`. Only string and integers are supported. name: A name for this op (optional). Returns: The lookup table to map an input `Tensor` to index `int64` `Tensor`. Raises: ValueError: If `vocabulary_list` is invalid. ValueError: If `num_oov_buckets` is negative. """
“”“ 类型检查代码 ”“”
with ops.name_scope(name, "string_to_index") as feat_to_id_scope: # 获取由传入的vocabulary_list构造的hash表,keys是vocabulary_list里面的元素,values是 # 【0, len(num_elements)-1】 keys = ops.convert_to_tensor(vocabulary_list) num_elements = array_ops.size(keys) values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) if num_oov_buckets: table = IdTableWithHashBuckets( table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, name=feat_to_id_scope, key_dtype=dtype) return table
复制代码


从上面的代码片段可以看出,index_table_from_tensor 函数体中有三个比较重要的步骤:


根据 keys 和 values 创建 KeyValueTensorInitializer。


根据 KeyValueTensorInitializer 创建具体的 HashTable 类,对于不在 vocabulary_list 表里面的元素,统一返回默认值。


如果 num_oov_buckets>0,则会创建带有 Buckets 的 IdTableWithHashBuckets 类,用一个 hash 函数返回那些不在 vocabulary_list 表里面的 id 元素的 index 值。


首先我们来分析第一步的操作,先通过传入的 vocabulary_list 去构造 keys 和 values,然后传入到 KeyValueTensorInitializer 类中返回 table 初始化的 op,注意此时只是返回一个用于 table 初始化的 OP,还没有做具体的初始化工作(初始化要放到具体的 HashTable 中去完成)。下面看下 KeyValueTensorInitializer 的代码:


class KeyValueTensorInitializer(TableInitializerBase):  """Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): with ops.name_scope(name, "key_value_init", [keys, values]) as scope: self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") self._values = ops.convert_to_tensor( values, dtype=value_dtype, name="values") self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, self._values.dtype)
def initialize(self, table): """Initializes the given `table` with `keys` and `values` tensors. Args: table: The table to initialize. Returns: The operation that initializes the table. Raises: TypeError: when the keys and values data types do not match the table key and value data types. """ _check_table_dtypes(table, self._keys.dtype, self._values.dtype) with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: init_op = gen_lookup_ops.lookup_table_import_v2( table.table_ref, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op可看出table的初始化操作是通过lookup_table_import_v2这个op来完成,下面就分析下这个op的相关c代码:
REGISTER_OP("LookupTableImportV2") .Input("table_handle: resource") .Input("keys: Tin") .Input("values: Tout") .Attr("Tin: type") .Attr("Tout: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); return Status::OK(); });
// Clear the table and insert data.class LookupTableImportOp : public OpKernel { public: explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ .... ”“”
const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
int memory_used_before = 0; if (ctx->track_allocations()) { memory_used_before = table->MemoryUsed(); } OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); if (ctx->track_allocations()) { ctx->record_persistent_memory_allocation(table->MemoryUsed() - memory_used_before); } }};
复制代码


这里通过 table_handle 句柄来传入具体的 Table 类,上面举的例子中使用的是 HashTable 类,HashTable 类继承于 InitializableLookupTable 类,InitializableLookupTable 可理解为那些需要初始化操作的 table 集合,而 InitializableLookupTable 又继承于 LookupInterface 类。通过 table->ImportValues()方法来将 vocabulary_list 中构造的 keys 和 values 导入到 hash 表中,完成 hash 表的初始化工作。


通过 LookupTableImportV2 返回 init_op 之后,将该初始化 hash 表的 op 传给 HashTable 类,HashTable 类和其父类的代码如下:


class HashTable(InitializableLookupTableBase):
def __init__(self, initializer, default_value, shared_name=None, name=None): """Creates a non-initialized `HashTable` object. Creates a table, the type of its keys and values are specified by the initializer. Before using the table you will have to initialize it. After initialization the table will be immutable. Args: initializer: The table initializer to use. See `HashTable` kernel for supported key and value types. default_value: The value to use if a key is missing in the table. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. name: A name for the operation (optional). Returns: A `HashTable` object. """ with ops.name_scope(name, "hash_table", (initializer, default_value)) as scope: table_ref = gen_lookup_ops.hash_table_v2( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer) self._value_shape = self._default_value.get_shape()

class InitializableLookupTableBase(LookupInterface): """Initializable lookup table interface. An initializable lookup tables persist across different steps. """
def __init__(self, table_ref, default_value, initializer): """Construct a table object from a table reference. If requires a table initializer object (subclass of `TableInitializerBase`). It provides the table key and value types, as well as the op to initialize the table. The caller is responsible to execute the initialization op. Args: table_ref: The table reference, i.e. the output of the lookup table ops. default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ name = table_ref.op.name.split("/")[-1] super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, initializer.value_dtype, name) self._table_ref = table_ref self._default_value = ops.convert_to_tensor( default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) self._init = initializer.initialize(self)
@property def table_ref(self): """Get the underlying table reference.""" return self._table_ref
@property def default_value(self): """The default value of the table.""" return self._default_value
@property def init(self): """The table initialization op.""" return self._init
def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. The `default_value` is used for keys not present in the table. """ key_tensor = keys with ops.name_scope(name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, key_tensor, self._default_value, name=scope)
values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) else: return values
复制代码


InitializableLookupTableBase 类继承于 HashTable,InitializableLookupTableBase 类中有 init 操作和 look_up,所以表的初始化和 look_up 操作其实都是在父类中完成,其中表的初始化操作上面已经讨论过,我们就重点看下 look_up 操作的相关 OP:



lookup_table_find_v2
// Table lookup op. Perform the lookup operation on the given table.class LookupTableFindOp : public OpKernel { public: explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ ...... ”“”
const Tensor& key = ctx->input(1); const Tensor& default_value = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
TensorShape output_shape = key.shape(); output_shape.RemoveLastDims(table->key_shape().dims()); output_shape.AppendShape(table->value_shape()); Tensor* out; OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); }};
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp);REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), LookupTableFindOp);
# table->Find调用的是HashTable类的DoFind方法Status DoFind(const Tensor& key, Tensor* value, const Tensor& default_value) override { const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); }

复制代码


最后我们分析下 IdTableWithHashBuckets 类,该类可用一个 hash 函数返回那些不在 vocabulary_list 表里 id 元素的 index 值,先举个例子:


import tensorflow as tfnum_oov_buckets = 3input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])table = tf.IdTableWithHashBuckets(      tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),      num_oov_buckets)out = table.lookup(input_tensor).table.init.run()print(out.eval())
#filename依次为"emerson", "lake", "palmer"
复制代码


#结果为[0, 1, 2, 4, 7]


下面看下 IdTableWithHashBuckets 的源码:


class IdTableWithHashBuckets(LookupInterface):  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.  For example, if an instance of `IdTableWithHashBuckets` is initialized with a  string-to-id table that maps:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  The `IdTableWithHashBuckets` object will performs the following mapping:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  * `<other term> -> bucket_id`, where bucket_id will be between `3` and  `3 + num_oov_buckets - 1`, calculated by:  `hash(<term>) % num_oov_buckets + vocab_size`  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,  the lookup result is `[0, 1, 2, 4, 7]`.  If `table` is None, only out-of-vocabulary buckets are used.  """
def __init__(self, table, num_oov_buckets, hasher_spec=FastHashSpec, name=None, key_dtype=None): """Construct a `IdTableWithHashBuckets` object. Args: table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets (optional). name: A name for the operation (optional). key_dtype: Data type of keys passed to `lookup`. Defaults to `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must be string or integer, and must be castable to `table.key_dtype`. Raises: ValueError: when `table` in None and `num_oov_buckets` is not positive. TypeError: when `hasher_spec` is invalid. """ # If a name ends with a '/' it is a "name scope", remove all trailing '/' # characters to use as table name. “”“ 类型检查 “”“
self._num_oov_buckets = num_oov_buckets super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, name.split("/")[-1])
@property def init(self): """The table initialization op.""" if self._table: return self._table.init with ops.name_scope(None, "init"): return control_flow_ops.no_op()
@property def table_ref(self): if self._table is not None: return self._table.table_ref return None
def _get_string_to_hash_bucket_fn(self, hasher_spec): """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" if not isinstance(hasher_spec, HasherSpec): raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) if hasher_spec.hasher == "fasthash": return string_ops.string_to_hash_bucket_fast if hasher_spec.hasher == "legacy": return string_ops.string_to_hash_bucket if hasher_spec.hasher == "stronghash": return functools.partial( string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: # TODO(yleon): Consider moving this functionality to its own kernel. with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) buckets = str_to_hash_bucket( _as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") if self._table: ids = self._table.lookup(values) buckets = math_ops.add(buckets, self._table.size()) is_id_non_default = math_ops.not_equal(ids, self._table.default_value) ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) else: ids = buckets return ids
复制代码


上面的 lookup 操作为核心代码,大致流程为:先将当前待查找的所有 keys 通过 str_to_hash_bucket 这个 hash 函数映射成 int 类型的 bucket 号,然后为了防止和已经在表中的那些 index 冲突,需要统一加上 table.size()。然后在 table 中查找所有 keys 的 index,如果在 id table 表中会返回正确的 index,如果不在就返回默认值。然后将那些查找后等于默认值的 keys 的 index 换成通过 hash 函数得到的 bucket 号。这样存在的性能问题是当 id table 表较大时,会有很多时间浪费在没必要的 hash 操作上。所以一种优化方案是:先查表得到所有 keys 的 index,然后只 hash 那些映射值等于默认值的元素。


我自己基于这个思路实现了一版,暂且就称为 lookup_v2 吧,但是虽然这样不用为所有的 keys 做 hash,但也增加了几个操作,所以性能还不能保证。后续会考虑直接优化 c++层面的代码。下面是核心代码:


 def lookup_v2(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) if self._table: init_ids = self._table.lookup(values) is_id_default_idx = array_ops.where(math_ops.equal(init_ids, self._table.default_value)) hash_values = array_ops.gather(values, array_ops.squeeze(is_id_default_idx)) default_buckets = str_to_hash_bucket( _as_string(hash_values), num_buckets=self._num_oov_buckets, name="hash_bucket") default_buckets = math_ops.add(default_buckets, self._table.size()) default_buckets = control_flow_ops.cond(gen_math_ops.equal(array_ops.size(default_buckets), 1), lambda: array_ops.expand_dims(default_buckets, axis=0), lambda: default_buckets) ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx, default_buckets) else: ids = str_to_hash_bucket(_as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") return ids
复制代码


参考文献:


https://mathmach.com/2442aa9e/


https://github.com/tensorflow/tensorflow/blob/5b900cfe4b3b848f577315a0dde09a729f770e95/tensorflow/python/ops/lookup_ops.py#L1042


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/93116229


2019-11-28 08:003296

评论

发布
暂无评论
发现更多内容

逃离过度努力陷阱

FunTester

FunTester 湛卢 轻松主义

RDP是什么意思?有什么用?

行云管家

运维 网络协议 RDP

大咖说|阿里巴巴副总裁陈龙:数字技术将在绿色低碳转型中发挥关键作用

大咖说

阿里巴巴 数字化 碳中和

多个角度论证SeekTiger 生态核心STI的魅力

西柚子

crontab命令详细介绍教程,快来围观

CRMEB

大数据培训spark SQL中count(*)和count(1)源码分析

@零度

大数据开发 spark SQL

怒肝 JavaScript 数据结构 — 栈篇(三)

杨成功

数据结构 4月月更

netty系列之:netty中的自动解码器ReplayingDecoder

程序那些事

Java Netty 程序那些事 4月月更

Android C++系列:JNI常见问题

轻口味

c++ android 4月月更

三高Mysql - 搭建“三高”架构之复制

懒时小窝

MySQL MySQL 高可用

Tapdata PDK 生态共建计划启动!MongoDB、Doris、OceanBase、PolarDB等十余家厂商首批加入

MongoDB中文社区

三高Mysql - 搭建“三高”架构之扩展与切换

懒时小窝

MySQL MySQL 高可用

直播预告|年营业额百亿的企业都在如何做数字化转型

云智慧AIOps社区

数字化转型 AIOPS 解决方案 智能运维

ETL 和数仓建模的设计思路!

五分钟学大数据

4月月更

重磅!百度安全参编的国家标准《信息安全技术 术语》正式发布

百度开发者中心

俄乌战争下的国产数据库替换思考-墨天轮

墨天轮

数据库 oracle 达梦 gbase8a

盘点近期虎符交易所上线的项目

区块链前沿News

虎符交易所

“囤菜新宠”预制菜,会是生鲜电商的破局点吗?

易观分析

ETL调度软件TASKCTL核心调度节点安装

敏捷调度TASKCTL

kettle 调度引擎 ETL 任务队列 调度任务

@所有高校师生,2022全国大学生物联网设计竞赛火热开启,限量礼品等你来拿!

HarmonyOS开发者

HarmonyOS 物联网设计竞赛

eBPF Cilium实战(2) - 底层网络可观测性

北京好雨科技有限公司

Docker Kubernetes PaaS cilium

PHP项目微信提现功能代码详解

CRMEB

无需编程,基于微软mssql数据库零代码生成CRUD增删改查RESTful API接口

crudapi

低代码 API crud crudapi 增删改查

不再单调!快来自定义你的专属背景~

优麒麟

Linux 开源 操作系统 优麒麟 用户登录

如何通过Password Vault的XSS漏洞窃取用户密码信息

喀拉峻

网络安全 XSS

杭州等保测评公司有哪些?分别叫什么?如何能查到?

行云管家

等保 等级保护 等保测评 杭州

百度荣获 “2021年中国网络安全产业联盟数据安全工作委员会突出贡献奖”

百度开发者中心

零信任访问控制下企业ABAC的实施问题

Geek_2d6073

H5营销有什么优势?企业需要定制开发H5吗?

源字节1号

前端开发 后端开发 H5制作

云效研发效能度量体系,如何展示和解读交付效能数据

阿里云云效

阿里云 运维 研发管理 研发效能 研发团队

jackson学习之六:常用类注解

程序员欣宸

4月月更

分布式tensorflow源码解读3:lookup.index_table_from_tensor_文化 & 方法_Alex-zhai_InfoQ精选文章