TFRecordsからSparseTensorを読み込んでbatch化

はじめに

TensorFlow 0.11 で

  • TFRecordsからデータを読み込む
  • 読み込むのはSparseTensor
  • batch化もする

というケースの一例。

前提

素性が

feature_vector = [2.1, 0.0, 0.0, 4.9, 0.0, 1.5, 0.0, 0.0]

のような形をしていて、これをスパース表現したものとintのラベルのセットで1レコードが構成されているようなデータがあるとする。

さらにこれをTFRecordsに書き出したものがあるとして、これを読み込むところの話。

素性を書き出すときはindexのリスト[0, 3, 5]とvalueのリスト[2.1, 4.9, 1.5]を分けて格納することになる。と思う。これについては書かない。

読み込み

普通のTensorならparse_single_exampleしてからbatchを作るが、 SparseTensorの場合、batchを作ってからparse_exampleしなければならない。

例えばこんな感じ。

filenames = [filename]  # 今回は1ファイルだけ
filename_queue = tf.train.string_input_producer(filenames)
reader = tf.TFRecordReader()
# 1行read
_, serialized_example = reader.read(filename_queue)
# batchを作る
serialized_examples = tf.train.batch(
    [serialized_example], batch_size=batch_size)
# parse
features = tf.parse_example(
    serialized_examples,
    features={
        'ind': tf.VarLenFeature(tf.int64),
        'val': tf.VarLenFeature(tf.float64),
        'label': tf.FixedLenFeature([1], tf.int64)
    }
)

読み込んだ結果

f1 = [2.1, 0.0, 0.0, 4.9, 0.0, 1.5, 0.0, 0.0]
f2 = [0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 6.0]

の2レコードを読み込んでbatch_size = 2でbatch化すると、

features['ind'].indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
features['ind'].values = [0, 3, 5, 2, 7]
features['val'].indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]  # indと同じ
features['val'].values = [2.1, 4.9, 1.5, 0.8, 6.0]

つまり

ind[0, 0] = 0, val[0, 0] = 2.1
ind[0, 1] = 3, val[0, 1] = 4.9
ind[0, 2] = 5, val[0, 2] = 1.5
ind[1, 0] = 2, val[1, 0] = 0.8
ind[1, 1] = 7, val[1, 1] = 6.0

という形になる。

ただしレコード内での素性の格納順はTFRecords書き出し時の順番に依存するので、

features['ind'].indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
features['ind'].values = [5, 0, 3, 7, 2]
features['val'].indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]  # indと同じ
features['val'].values = [1.5, 2.1, 4.9, 6.0, 0.8]

のようになることもある。上記のように1ファイルのみの使用かつtf.train.shuffle_batchを使わない場合、レコードをまたいで入れ替わることはない。

ちなみに

そもそもほしいのは

f.indices = [[0, 0], [0, 3], [0, 5], [1, 2], [1, 7]]
f.values = [2.1, 4.9, 1.5, 0.8, 6.0]

なわけなので、上記の方法で得たindvalからこれを作る必要がある。

tf.reshapeなどを使えば普通にできるのだが、経験上tf.reshapeは結構遅い。 後の処理がindvalのままでできるなら、それに越したことはない。 このへんは別途書くかも。

そもそもparseした時点で上記のfの形になっていてくれれば苦労はないわけだが、その方法はわからない。