Custom Data Readers

Edited by 小牛 (Xiaoniu)
2023-12-01

PREREQUISITES:

We divide the task of supporting a file format into two pieces:

  • File formats: We use a Reader Op to read a record (which can be any string) from a file.
  • Record formats: We use decoder or parsing Ops to turn a string record into tensors usable by TensorFlow.

For example, to read a CSV file, we use a Reader for text files followed by an Op that parses CSV data from a line of text.
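The split between the two pieces can be sketched in plain Python, without TensorFlow. A "reader" yields raw string records (here, lines) from a file, and a separate "decoder" parses one record into typed values. The function names read_lines and decode_csv_record are hypothetical stand-ins for a text-line Reader Op and a CSV-parsing Op, and an in-memory io.StringIO stands in for a file on disk.

```python
import csv
import io
from typing import Iterator, List


def read_lines(f: io.TextIOBase) -> Iterator[str]:
    """Reader role: yield each line of the file as an opaque string record."""
    for line in f:
        yield line.rstrip("\n")


def decode_csv_record(record: str) -> List[float]:
    """Decoder role: parse one CSV line into numeric fields."""
    row = next(csv.reader([record]))
    return [float(field) for field in row]


# Stand-in for a CSV file; the reader knows nothing about CSV syntax.
data = io.StringIO("1,2,3\n4.5,5,6\n")
parsed = [decode_csv_record(r) for r in read_lines(data)]
```

Keeping the two roles separate means the same line reader could feed a different decoder, and the same decoder could parse records from a different reader.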

Contents

Custom Data Readers

  • Writing a Reader for a file format
  • Writing an Op for a record format

Writing a Reader for a file format

A Reader is something that reads records from a file. TensorFlow already includes several Reader Ops, such as the TextLineReader and TFRecordReader used in the examples below.

All Readers expose the same interface; the only differences are in their constructors. The most important method is read. It takes a queue argument, from which it gets a filename whenever it needs one (for example, when the read op first runs, or when the previous read consumed the last record of a file). It produces two scalar tensors: a string key and a string value.

To create a new reader called SomeReader, you will need to:

  1. In C++, define a subclass of tensorflow::ReaderBase called SomeReader.
  2. In C++, register a new reader op and kernel with the name "SomeReader".
  3. In Python, define a subclass of tf.ReaderBase called SomeReader.

You can put all the C++ code in a file in tensorflow/core/user_ops/some_reader_op.cc. The code to read a file will live in a descendant of the C++ ReaderBase class, which is defined in tensorflow/core/kernels/reader_base.h. You will need to implement the following methods:

  • OnWorkStartedLocked: open the next file
  • ReadLocked: read a record or report EOF/error
  • OnWorkFinishedLocked: close the current file, and
  • ResetLocked: get a clean slate after, e.g., an error

These methods have names ending in "Locked" since ReaderBase makes sure to acquire a mutex before calling any of these methods, so you generally don't have to worry about thread safety (though that only protects the members of the class, not global state).
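To make the division of labor concrete, here is a simplified Python model of the contract described above. It is not TensorFlow code. The base class holds a lock while calling the four hooks, pulls a new filename from a queue only when it has no current file, and keeps looping past end-of-file until a record is produced. The class names MiniReaderBase and InMemoryLineReader, the snake_case hook names, and the in-memory FILES dictionary are all hypothetical. The real ReaderBase is a C++ class with more bookkeeping.

```python
import queue
import threading
from typing import Dict, List, Optional, Tuple


class MiniReaderBase:
    """Simplified Python model of the ReaderBase contract (not TensorFlow code)."""

    def __init__(self) -> None:
        self._mu = threading.Lock()
        self._work: Optional[str] = None  # filename being read, if any
        self.num_records_produced = 0

    def current_work(self) -> Optional[str]:
        return self._work

    def read(self, filename_queue: queue.Queue) -> Tuple[str, str]:
        # Every *_locked hook runs while self._mu is held.
        with self._mu:
            while True:
                if self._work is None:
                    # Pull a filename only when one is needed (queue.Empty if none left).
                    self._work = filename_queue.get_nowait()
                    self.on_work_started_locked()
                # Errors raised by read_locked propagate to the caller unchanged.
                key, value, produced, at_end = self.read_locked()
                if not produced and not at_end:
                    raise RuntimeError("read_locked must set produced or at_end")
                if at_end:
                    self.on_work_finished_locked()
                    self._work = None
                if produced:
                    self.num_records_produced += 1
                    return key, value

    def reset(self) -> None:
        with self._mu:
            self.reset_locked()
            self._work = None
            self.num_records_produced = 0

    # Hooks that a concrete reader overrides.
    def on_work_started_locked(self) -> None:
        raise NotImplementedError

    def read_locked(self) -> Tuple[str, str, bool, bool]:
        raise NotImplementedError

    def on_work_finished_locked(self) -> None:
        raise NotImplementedError

    def reset_locked(self) -> None:
        raise NotImplementedError


class InMemoryLineReader(MiniReaderBase):
    """Example subclass; FILES is a read-only stand-in for the file system."""

    FILES: Dict[str, List[str]] = {"a.txt": ["x", "y"], "b.txt": ["z"]}

    def on_work_started_locked(self) -> None:
        self._lines = list(self.FILES[self.current_work()])  # "open" the file
        self._line_number = 0

    def read_locked(self) -> Tuple[str, str, bool, bool]:
        if self._line_number >= len(self._lines):
            return "", "", False, True  # EOF: nothing produced
        self._line_number += 1
        key = f"{self.current_work()}:{self._line_number}"
        return key, self._lines[self._line_number - 1], True, False

    def on_work_finished_locked(self) -> None:
        self._lines = []  # "close" the file

    def reset_locked(self) -> None:
        self._lines = []
        self._line_number = 0


filenames: queue.Queue = queue.Queue()
for name in ["a.txt", "b.txt"]:
    filenames.put(name)
reader = InMemoryLineReader()
records = [reader.read(filenames) for _ in range(3)]
```

In this model, the third read() call hits end-of-file on a.txt, closes it, and moves on to b.txt within the same call. A fourth call would raise queue.Empty because no filenames remain.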

For OnWorkStartedLocked, the name of the file to open is the value returned by the current_work() method. ReadLocked has this signature:

Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)

If ReadLocked successfully reads a record from the file, it should fill in:

  • *key: with an identifier for the record that a human could use to find this record again. For example, you can include the filename from current_work() and append a record number.
  • *value: with the contents of the record.
  • *produced: set to true.

If you hit the end of a file (EOF), set *at_end to true. In either case, return Status::OK(). If there is an error, simply return it using one of the helper functions from tensorflow/core/lib/core/errors.h without modifying any arguments.
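Python has no out-parameters, so the following hypothetical sketch returns the four ReadLocked outputs as a tuple (key, value, produced, at_end). It signals errors by raising an exception instead of returning a Status. Each call reads one fixed-size record from a binary stream. The fixed-size format, the function name read_fixed_length_record, and the "filename:index" key layout are illustrative assumptions.

```python
import io
from typing import BinaryIO, Tuple


def read_fixed_length_record(
    f: BinaryIO, filename: str, record_index: int, record_bytes: int
) -> Tuple[str, bytes, bool, bool]:
    """Hypothetical analogue of ReadLocked for fixed-size records.

    Returns (key, value, produced, at_end). On a truncated record it raises
    ValueError instead of returning partial output.
    """
    data = f.read(record_bytes)
    if not data:
        # Clean end of file: nothing produced, at_end set.
        return "", b"", False, True
    if len(data) < record_bytes:
        raise ValueError(
            f"{filename}: record {record_index} truncated "
            f"({len(data)} of {record_bytes} bytes)"
        )
    # The key lets a human find the record again: filename plus record number.
    return f"{filename}:{record_index}", data, True, False


f = io.BytesIO(b"abcdef")
first = read_fixed_length_record(f, "data.bin", 0, 3)
second = read_fixed_length_record(f, "data.bin", 1, 3)
end = read_fixed_length_record(f, "data.bin", 2, 3)
```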

Next you will create the actual Reader op. It will help to be familiar with the how-to on adding an op. The main steps are:

  • Register the op.
  • Define and register an OpKernel.

To register the op, you will use a REGISTER_OP call defined in tensorflow/core/framework/op.h. Reader ops never take any input and always have a single output with type Ref(string). They should always call SetIsStateful() and have string attrs named container and shared_name. You may optionally define additional attrs for configuration or include documentation in a Doc. See tensorflow/core/ops/io_ops.cc for examples, such as this one:

#include "tensorflow/core/framework/op.h"

REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");

To define an OpKernel, Readers can use the shortcut of descending from ReaderOpKernel, defined in tensorflow/core/framework/reader_op_kernel.h, and implement a constructor that calls SetReaderFactory. After defining your class, you will need to register it using REGISTER_KERNEL_BUILDER(...). An example with no attrs:

#include "tensorflow/core/framework/reader_op_kernel.h"

class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};

REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);

An example with attrs:

#include "tensorflow/core/framework/reader_op_kernel.h"

class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};

REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);
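The pattern in these kernels can be sketched in Python using hypothetical classes, with no TensorFlow involved. The constructor validates the attr once, then stores a closure that captures the validated value, so the reader object itself can be built later. Here a plain dict stands in for OpKernelConstruction::GetAttr, and the sketch builds the reader on first use. LineReader and LineReaderOp are illustrative names loosely modeled on the TextLineReader example, not real TensorFlow classes.

```python
from typing import Callable, Dict, List, Optional


class LineReader:
    """Hypothetical reader that drops the first skip_header_lines lines."""

    def __init__(self, name: str, skip_header_lines: int) -> None:
        self.name = name
        self.skip_header_lines = skip_header_lines

    def records(self, lines: List[str]) -> List[str]:
        return lines[self.skip_header_lines:]


class LineReaderOp:
    """Mimics a ReaderOpKernel constructor: validate attrs, then store a factory."""

    def __init__(self, name: str, attrs: Dict[str, int]) -> None:
        # Default of 0 mirrors "skip_header_lines: int = 0" in the op registration.
        skip_header_lines = attrs.get("skip_header_lines", 0)
        if skip_header_lines < 0:
            raise ValueError(f"skip_header_lines must be >= 0 not {skip_header_lines}")
        # Like SetReaderFactory: capture the validated attr; build the reader later.
        self._factory: Callable[[], LineReader] = lambda: LineReader(name, skip_header_lines)
        self._reader: Optional[LineReader] = None

    def reader(self) -> LineReader:
        if self._reader is None:
            self._reader = self._factory()
        return self._reader


op = LineReaderOp("lines", {"skip_header_lines": 1})
body = op.reader().records(["header", "a", "b"])
```

Validating in the constructor means a bad attr value fails when the kernel is created, before any file is opened.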

The last step is to add the Python wrapper. You will import tensorflow.python.ops.io_ops in tensorflow/python/user_ops/user_ops.py and add a descendant of io_ops.ReaderBase.

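from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_user_ops
from tensorflow.python.ops import io_ops

class SomeReader(io_ops.ReaderBase):

    def __init__(self, name=None):
        # Create the reader handle with the generated wrapper for the "SomeReader" op.
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)


# Readers have no gradient, and the reader handle is a scalar.
ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)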

You can see some examples in tensorflow/python/ops/io_ops.py.

Writing an Op for a record format

Generally this is an ordinary op that takes a scalar string record as input, so follow the instructions for adding an Op. You may optionally take a scalar string key as input and include it in error messages that report improperly formatted data. That way, users can more easily track down where the bad data came from.
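At its core, a record-format op is a function from one string record (plus an optional key) to typed values. This plain-Python sketch parses a hypothetical "x,y" record into two floats and puts the key in every error message. The function name and record layout are assumptions for illustration. A real op would report errors through a Status rather than by raising ValueError.

```python
from typing import Tuple


def parse_point_record(record: str, key: str = "") -> Tuple[float, float]:
    """Hypothetical record-format op: parse "x,y" into two floats.

    The optional key (for example "points.csv:7" from a Reader) is included in
    errors so users can trace malformed data back to its source.
    """
    fields = record.split(",")
    where = f" (record {key})" if key else ""
    if len(fields) != 2:
        raise ValueError(f"expected 2 fields, got {len(fields)}{where}: {record!r}")
    try:
        return float(fields[0]), float(fields[1])
    except ValueError:
        raise ValueError(f"non-numeric field{where}: {record!r}") from None
```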

Ops useful for decoding records include tf.decode_csv, tf.parse_single_example, tf.decode_raw, tf.decode_jpeg, and tf.decode_png.

Note that it can be useful to use multiple Ops to decode a particular record format. For example, you may have an image saved as a string in a tf.train.Example protocol buffer. Depending on the format of that image, you might take the corresponding output from a tf.parse_single_example op and call tf.decode_jpeg, tf.decode_png, or tf.decode_raw. It is common to take the output of tf.decode_raw and use tf.slice and tf.reshape to extract pieces.
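The decode_raw, slice, and reshape chain can be illustrated without TensorFlow. The sketch below assumes a hypothetical record layout: one label byte followed by a 2x3 image of uint8 pixels. It decodes the bytes into integers, much as tf.decode_raw would with a uint8 output type, then slices and reshapes in plain Python. The helper names decode_raw_uint8 and reshape are hypothetical.

```python
from typing import List


def decode_raw_uint8(record: bytes) -> List[int]:
    """Interpret each byte of the record as one uint8 value."""
    return list(record)


def reshape(flat: List[int], rows: int, cols: int) -> List[List[int]]:
    """Row-major reshape of a flat list into rows x cols."""
    if len(flat) != rows * cols:
        raise ValueError(f"cannot reshape {len(flat)} values into {rows}x{cols}")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


# Hypothetical record: label byte 7, then six pixel bytes.
record = bytes([7, 10, 20, 30, 40, 50, 60])
values = decode_raw_uint8(record)
label = values[0]                  # slice out the label
image = reshape(values[1:], 2, 3)  # slice out the pixels, then reshape to 2x3
```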