Custom Data Readers
PREREQUISITES:
- Some familiarity with C++.
- Must have downloaded TensorFlow source, and be able to build it.
We divide the task of supporting a file format into two pieces:
- File formats: We use a Reader Op to read a record (which can be any string) from a file.
- Record formats: We use decoder or parsing Ops to turn a string record into tensors usable by TensorFlow.
For example, to read a CSV file, we use a Reader for text files followed by an Op that parses CSV data from a line of text.
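The two-stage split described above can be sketched in plain Python. This is only an illustration of the division of labor, not TensorFlow code: the function names read_text_records and parse_csv_record and the sample file contents are hypothetical.

```python
import csv
import io


def read_text_records(text):
    """Stage 1 (file format): yield one string record per line of text."""
    for line in io.StringIO(text):
        yield line.rstrip("\n")


def parse_csv_record(record):
    """Stage 2 (record format): turn one string record into typed values."""
    fields = next(csv.reader([record]))
    return [float(field) for field in fields]


# Hypothetical file contents, used only for illustration.
file_contents = "1.0,2.0,3.0\n4.0,5.0,6.0\n"
rows = [parse_csv_record(record) for record in read_text_records(file_contents)]
```

Keeping the stages separate means the same line reader can feed a different parser, and the same parser can consume records from a different file format.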
Contents
Custom Data Readers
- Writing a Reader for a file format
- Writing an Op for a record format
Writing a Reader for a file format
A Reader is something that reads records from a file. There are some examples of Reader Ops already built into TensorFlow:
- tf.TFRecordReader (source in kernels/tf_record_reader_op.cc)
- tf.FixedLengthRecordReader (source in kernels/fixed_length_record_reader_op.cc)
- tf.TextLineReader (source in kernels/text_line_reader_op.cc)
You can see that these all expose the same interface; the only differences are in their constructors. The most important method is read. It takes a queue argument, which is where it gets filenames to read from whenever it needs one (e.g. when the read op first runs, or when the previous read reads the last record from a file). It produces two scalar tensors: a string key and a string value.
To create a new reader called SomeReader, you will need to:
- In C++, define a subclass of tensorflow::ReaderBase called SomeReader.
- In C++, register a new reader op and kernel with the name "SomeReader".
- In Python, define a subclass of tf.ReaderBase called SomeReader.
You can put all the C++ code in a file in tensorflow/core/user_ops/some_reader_op.cc. The code to read a file will live in a descendant of the C++ ReaderBase class, which is defined in tensorflow/core/kernels/reader_base.h. You will need to implement the following methods:
- OnWorkStartedLocked: open the next file
- ReadLocked: read a record or report EOF/error
- OnWorkFinishedLocked: close the current file, and
- ResetLocked: get a clean slate after, e.g., an error
These methods have names ending in "Locked" since ReaderBase makes sure to acquire a mutex before calling any of these methods, so you generally don't have to worry about thread safety (though that only protects the members of the class, not global state).
For OnWorkStartedLocked, the name of the file to open is the value returned by the current_work() method. ReadLocked has this signature:
Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
If ReadLocked successfully reads a record from the file, it should fill in:
- *key: with an identifier for the record that a human could use to find this record again. You can include the filename from current_work() and append a record number or similar.
- *value: with the contents of the record.
- *produced: set to true.
If you hit the end of a file (EOF), set *at_end to true. In either case, return Status::OK(). If there is an error, simply return it using one of the helper functions from tensorflow/core/lib/core/errors.h without modifying any arguments.
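The method contract above can be modeled in plain Python to show how the four methods fit together. This is a sketch only, not TensorFlow code: the class SketchLineReader, the driver read_all, the in-memory "filesystem" dictionary, and the "filename:line" key format are all assumptions made for illustration. The real implementation is a C++ subclass of ReaderBase, and errors there are returned as a Status rather than raised.

```python
from collections import deque


class SketchLineReader:
    """Pure-Python stand-in for a ReaderBase subclass (not TensorFlow code)."""

    def __init__(self, files):
        self._files = files  # hypothetical in-memory filesystem: name -> text
        self._lines = None
        self._line_number = 0
        self._current_work = None

    def on_work_started(self, filename):
        # Mirrors OnWorkStartedLocked: open the next file.
        self._current_work = filename
        self._lines = self._files[filename].splitlines()
        self._line_number = 0

    def read(self):
        # Mirrors ReadLocked: returns (key, value, produced, at_end).
        if self._line_number >= len(self._lines):
            return None, None, False, True  # EOF: only at_end is set
        value = self._lines[self._line_number]
        self._line_number += 1
        key = "%s:%d" % (self._current_work, self._line_number)
        return key, value, True, False

    def on_work_finished(self):
        # Mirrors OnWorkFinishedLocked: close the current file.
        self._lines = None

    def reset(self):
        # Mirrors ResetLocked: return to a clean slate.
        self._lines = None
        self._line_number = 0
        self._current_work = None


def read_all(reader, filename_queue):
    """Hypothetical driver: pull filenames from a queue until it is empty."""
    records = []
    while filename_queue:
        reader.on_work_started(filename_queue.popleft())
        while True:
            key, value, produced, at_end = reader.read()
            if at_end:
                break
            if produced:
                records.append((key, value))
        reader.on_work_finished()
    return records


files = {"a.txt": "x\ny\n", "b.txt": "z\n"}
records = read_all(SketchLineReader(files), deque(["a.txt", "b.txt"]))
```

Each key combines the filename with a one-based line number, which follows the advice above that a human should be able to find the record again.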
Next you will create the actual Reader op. It will help if you are familiar with the how-to on adding an op. The main steps are:
- Register the op.
- Define and register an OpKernel.
To register the op, you will use a REGISTER_OP call defined in tensorflow/core/framework/op.h. Reader ops never take any input and always have a single output with type Ref(string). They should always call SetIsStateful(), and have string container and shared_name attrs. You may optionally define additional attrs for configuration or include documentation in a Doc. For examples, see tensorflow/core/ops/io_ops.cc, e.g.:
#include "tensorflow/core/framework/op.h"

REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");
To define an OpKernel, Readers can use the shortcut of descending from ReaderOpKernel, defined in tensorflow/core/framework/reader_op_kernel.h, and implement a constructor that calls SetReaderFactory. After defining your class, you will need to register it using REGISTER_KERNEL_BUILDER(...). An example with no attrs:
#include "tensorflow/core/framework/reader_op_kernel.h"

class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};

REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);
An example with attrs:
#include "tensorflow/core/framework/reader_op_kernel.h"

class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};

REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);
The last step is to add the Python wrapper. You will import tensorflow.python.ops.io_ops in tensorflow/python/user_ops/user_ops.py and add a descendant of io_ops.ReaderBase.
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops

# gen_user_ops is assumed to be imported already elsewhere in user_ops.py.


class SomeReader(io_ops.ReaderBase):

    def __init__(self, name=None):
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)


ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)
You can see some examples in tensorflow/python/ops/io_ops.py.
Writing an Op for a record format
Generally this is an ordinary op that takes a scalar string record as input, so follow the instructions for adding an Op. You may optionally take a scalar string key as input, and include that key in error messages reporting improperly formatted data. That way users can more easily track down where the bad data came from.
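The idea of passing the key through to error messages can be sketched in plain Python. This is an illustration only, not a TensorFlow op: the function parse_int_pair, the "two integers separated by a comma" record format, and the sample keys are all hypothetical.

```python
def parse_int_pair(record, key=None):
    """Parse a record like "3,4" into two ints, naming the key on failure."""
    parts = record.split(",")
    try:
        if len(parts) != 2:
            raise ValueError("expected 2 fields, got %d" % len(parts))
        return int(parts[0]), int(parts[1])
    except ValueError as error:
        # Include the key, when given, so the bad record can be located.
        where = " (record key: %s)" % key if key is not None else ""
        raise ValueError("Malformed record %r%s: %s" % (record, where, error))


pair = parse_int_pair("3,4", key="data.csv:1")

try:
    parse_int_pair("3,x", key="data.csv:7")
except ValueError as error:
    message = str(error)
```

Because the key is optional, the same parser still works when the caller has no key to offer. It only reports less context on failure.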
Examples of Ops useful for decoding records:
- tf.parse_single_example
- tf.decode_csv
- tf.decode_raw
Note that it can be useful to use multiple Ops to decode a particular record format. For example, you may have an image saved as a string in a tf.train.Example protocol buffer. Depending on the format of that image, you might take the corresponding output from a tf.parse_single_example op and call tf.decode_jpeg, tf.decode_png, or tf.decode_raw. It is common to take the output of tf.decode_raw and use tf.slice and tf.reshape to extract pieces.
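The decode, slice, and reshape pattern can be sketched with the Python standard library alone. This is an analogue for illustration, not the TensorFlow ops themselves, and the record layout (one label byte followed by a 2x3 grid of uint8 pixels) is a hypothetical example.

```python
import struct

# Hypothetical record layout: 1 label byte, then a 2x3 grid of uint8 pixels.
HEIGHT, WIDTH = 2, 3
record = bytes([7, 10, 11, 12, 20, 21, 22])

# Analogue of tf.decode_raw: reinterpret the string as a flat run of uint8s.
flat = list(struct.unpack("%dB" % len(record), record))

# Analogue of tf.slice: separate the label from the pixel data.
label = flat[0]
pixels = flat[1:1 + HEIGHT * WIDTH]

# Analogue of tf.reshape: regroup the flat pixels into rows.
image = [pixels[row * WIDTH:(row + 1) * WIDTH] for row in range(HEIGHT)]
```

In a TensorFlow graph the same three steps would be expressed with tf.decode_raw, tf.slice, and tf.reshape operating on tensors rather than Python lists.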