Adding a New Op
PREREQUISITES:
- Some familiarity with C++.
- Must have downloaded the TensorFlow source and be able to build it.
If you'd like to incorporate an operation that isn't covered by the existing library, you can create a custom Op. To incorporate your custom Op, you'll need to:
- Register the new Op in a C++ file. The Op registration is independent of the implementation, and describes the semantics of how the Op is invoked. For example, it defines the Op name, and specifies its inputs and outputs.
- Implement the Op in C++. This implementation is called a "kernel", and there can be multiple kernels for different architectures (e.g. CPUs, GPUs) or input / output types.
- Create a Python wrapper. This wrapper is the public API to create the Op. A default wrapper is generated from the Op registration, which can be used directly or added to.
- Optionally, write a function to compute gradients for the Op.
- Optionally, write a function that describes the input and output shapes for the Op. This allows shape inference to work with your Op.
- Test the Op, typically in Python.
Contents
- Define the Op's interface
- Implement the kernel for the Op
- Generate the client wrapper
  - The Python Op wrapper
  - The C++ Op wrapper
- Verify it works
- Validation
- Op registration
  - Attrs
  - Attr types
  - Polymorphism
  - Inputs and Outputs
  - Backwards compatibility
- GPU Support
- Implement the gradient in Python
- Implement a shape function in Python
Define the Op's interface
You define the interface of an Op by registering it with the TensorFlow system. In the registration, you specify the name of your Op, its inputs (types and names) and outputs (types and names), as well as docstrings and any attrs the Op might require.
To see how this works, suppose you'd like to create an Op that takes a tensor of `int32`s and outputs a copy of the tensor, with all but the first element set to zero. Create the file `tensorflow/core/user_ops/zero_out.cc` and add a call to the `REGISTER_OP` macro that defines the interface for such an Op:

#include "tensorflow/core/framework/op.h"

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32");
This `ZeroOut` Op takes one tensor `to_zero` of 32-bit integers as input, and outputs a tensor `zeroed` of 32-bit integers.

A note on naming: The name of the Op should be unique and CamelCase. Names starting with an underscore (`_`) are reserved for internal use.
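The registration can also carry documentation for the Op and its inputs and outputs. Although not required, you could attach it with the `Doc()` method, as in this sketch (the doc text here is illustrative, not the library's official description):

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .Doc(R"doc(
Zeros out all but the first element of the input tensor.

to_zero: The tensor to copy.
zeroed: A copy of `to_zero` with all but the first element set to zero.
)doc");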
Implement the kernel for the Op
After you define the interface, provide one or more implementations of the Op. To create one of these kernels, create a class that extends `OpKernel` and overrides the `Compute` method. The `Compute` method provides one `context` argument of type `OpKernelContext*`, from which you can access useful things like the input and output tensors.
Add your kernel to the file you created above. The kernel might look something like this:
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output(0) = input(0);
}
};
After implementing your kernel, you register it with the TensorFlow system. In the registration, you specify different constraints under which this kernel will run. For example, you might have one kernel made for CPUs, and a separate one for GPUs.
To do this for the `ZeroOut` Op, add the following to `zero_out.cc`:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
Once you build and reinstall TensorFlow, the TensorFlow system can reference and use the Op when requested.
Generate the client wrapper
The Python Op wrapper
Python Op wrappers are created automatically in `bazel-genfiles/tensorflow/python/ops/gen_user_ops.py` for all Ops placed in the `tensorflow/core/user_ops` directory when you build TensorFlow. Those Ops are imported into `tensorflow/python/user_ops/user_ops.py` with the statement:
from tensorflow.python.ops.gen_user_ops import *
You may optionally use your own function instead. To do this, first hide the generated code for that Op by adding its name to the `hidden` list in the `"user_ops"` rule in `tensorflow/python/BUILD`:
tf_gen_op_wrapper_py(
    name = "user_ops",
    hidden = [
        "Fact",
    ],
    require_shape_functions = False,
)
List your Op next to `"Fact"`. Next, add your replacement function to `tensorflow/python/user_ops/user_ops.py`. Typically your function will call the generated function to actually add the Op to the graph. The hidden version of the generated function will be in the `gen_user_ops` package and start with an underscore (`"_"`). For example:

def my_fact():
  """Example of overriding the generated code for an Op."""
  return gen_user_ops._fact()
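Following the same pattern, a replacement wrapper for `ZeroOut` might look like the sketch below, assuming `"ZeroOut"` has been added to the `hidden` list (the hidden generated function would then be `_zero_out`):

from tensorflow.python.ops import gen_user_ops

def zero_out(to_zero, name=None):
  """Hypothetical replacement wrapper for the generated zero_out op."""
  # Any Python-side validation or preprocessing would go here; the hidden
  # generated function is what actually adds the op to the graph.
  return gen_user_ops._zero_out(to_zero, name=name)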
The C++ Op wrapper
C++ Op wrappers are created automatically for all Ops placed in the `tensorflow/core/user_ops` directory when you build TensorFlow. For example, Ops in `tensorflow/core/user_ops/zero_out.cc` will generate wrappers in `bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc}`.

All generated wrappers for user Ops are automatically imported into `tensorflow/cc/ops/standard_ops.h` with the statement:
#include "tensorflow/cc/ops/user_ops.h"
Verify it works
A good way to verify that you've successfully implemented your Op is to write a test for it. Create the file `tensorflow/python/kernel_tests/zero_out_op_test.py` with the contents:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    with self.test_session():
      result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
Then run your test:
$ bazel test tensorflow/python:zero_out_op_test
Validation
The example above assumed that the Op applied to a tensor of any shape. What if it only applied to vectors? That means adding a check to the `OpKernel` implementation above.
void Compute(OpKernelContext* context) override {
  // Grab the input tensor
  const Tensor& input_tensor = context->input(0);

  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
              errors::InvalidArgument("ZeroOut expects a 1-D vector."));
  // ...
}
This asserts that the input is a vector, and returns having set the `InvalidArgument` status if it isn't. The `OP_REQUIRES` macro takes three arguments:

- The `context`, which can either be an `OpKernelContext` or `OpKernelConstruction` pointer (see `tensorflow/core/framework/op_kernel.h`), for its `SetStatus()` method.
- The condition. For example, there are functions for validating the shape of a tensor in `tensorflow/core/public/tensor_shape.h`.
- The error itself, which is represented by a `Status` object; see `tensorflow/core/public/status.h`. A `Status` has both a type (frequently `InvalidArgument`, but see the list of types) and a message. Functions for constructing an error may be found in `tensorflow/core/lib/core/errors.h`.

Alternatively, if you want to test whether a `Status` object returned from some function is an error, and if so return it, use `OP_REQUIRES_OK`. Both of these macros return from the function on error.
Op registration
Attrs
Ops can have attrs, whose values are set when the Op is added to a graph. These are used to configure the Op, and their values can be accessed both within the kernel implementation and in the types of inputs and outputs in the Op registration. Prefer using an input instead of an attr when possible, since inputs are more flexible. They can change every step, be set using a feed, etc. Attrs are used for things that can't be done with inputs: any configuration that affects the signature (number or type of inputs or outputs) or that can't change from step-to-step.
You define an attr when you register the Op, by specifying its name and type using the `Attr` method, which expects a spec of the form:

<name>: <attr-type-expr>

where `<name>` begins with a letter and can be composed of alphanumeric characters and underscores, and `<attr-type-expr>` is a type expression of the form described below.
For example, if you'd like the `ZeroOut` Op to preserve a user-specified index, instead of only the 0th element, you can register the Op like so:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");
Your kernel can then access this attr in its constructor via the `context` parameter:
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is non-negative
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }

  void Compute(OpKernelContext* context) override {
    // ...
  }

 private:
  int preserve_index_;
};
This attr can then be used in the `Compute` method:
void Compute(OpKernelContext* context) override {
  // ...

  // Check that preserve_index is in range
  OP_REQUIRES(context, preserve_index_ < input.dimension(0),
              errors::InvalidArgument("preserve_index out of range"));

  // Set all the elements of the output tensor to 0
  const int N = input.size();
  for (int i = 0; i < N; i++) {
    output_flat(i) = 0;
  }

  // Preserve the requested input value
  output_flat(preserve_index_) = input(preserve_index_);
}
To preserve backwards compatibility, you should specify a default value when adding an attr to an existing op:
REGISTER_OP("ZeroOut") .Attr("preserve_index: int = 0") .Input("to_zero: int32") .Output("zeroed: int32");
Attr types
The following types are supported in an attr:
- `string`: Any sequence of bytes (not required to be UTF-8).
- `int`: A signed integer.
- `float`: A floating point number.
- `bool`: True or false.
- `type`: One of the (non-ref) values of `DataType`.
- `shape`: A `TensorShapeProto`.
- `tensor`: A `TensorProto`.
- `list(<type>)`: A list of `<type>`, where `<type>` is one of the above types. Note that `list(list(<type>))` is invalid.

See also `op_def_builder.cc:FinalizeAttr` for a definitive list.
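On the kernel side, `OpKernelConstruction::GetAttr` is overloaded for the C++ types corresponding to each attr type, so a constructor can read them into ordinary C++ variables. A minimal sketch (the op name and attr names here are hypothetical):

class AttrTypesExampleOp : public OpKernel {
 public:
  explicit AttrTypesExampleOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // Each GetAttr overload fills in a C++ value of the matching type.
    OP_REQUIRES_OK(context, context->GetAttr("s", &s_));          // string
    OP_REQUIRES_OK(context, context->GetAttr("i", &i_));          // int
    OP_REQUIRES_OK(context, context->GetAttr("l_int", &l_int_));  // list(int)
  }

  void Compute(OpKernelContext* context) override { /* ... */ }

 private:
  string s_;
  int i_;
  std::vector<int32> l_int_;
};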
Default values & constraints
Attrs may have default values, and some types of attrs can have constraints. To define an attr with constraints, you can use the following `<attr-type-expr>`s:

- `{'<string1>', '<string2>'}`: The value must be a string that has either the value `<string1>` or `<string2>`. The name of the type, `string`, is implied when you use this syntax. This emulates an enum:

  REGISTER_OP("EnumExample")
      .Attr("e: {'apple', 'orange'}");

- `{<type1>, <type2>}`: The value is of type `type`, and must be one of `<type1>` or `<type2>`, where `<type1>` and `<type2>` are supported tensor types. You don't specify that the type of the attr is `type`; this is implied when you have a list of types in `{...}`. For example, in this case the attr `t` is a type that must be an `int32`, a `float`, or a `bool`:

  REGISTER_OP("RestrictedTypeExample")
      .Attr("t: {int32, float, bool}");

- There are shortcuts for common type constraints:
  - `numbertype`: Type `type` restricted to the numeric (non-string and non-bool) types.
  - `realnumbertype`: Like `numbertype` without complex types.
  - `quantizedtype`: Like `numbertype` but just the quantized number types.

  The specific lists of types allowed by these are defined by the functions (like `NumberTypes()`) in `tensorflow/core/framework/types.h`. In this example the attr `t` must be one of the numeric types:

  REGISTER_OP("NumberType")
      .Attr("t: numbertype");

  For this op:

  tf.number_type(t=tf.int32)  # Valid
  tf.number_type(t=tf.bool)   # Invalid

- `int >= <n>`: The value must be an int whose value is greater than or equal to `<n>`, where `<n>` is a natural number. For example, the following Op registration specifies that the attr `a` must have a value that is at least `2`:

  REGISTER_OP("MinIntExample")
      .Attr("a: int >= 2");

- `list(<type>) >= <n>`: A list of type `<type>` whose length is greater than or equal to `<n>`. For example, the following Op registration specifies that the attr `a` is a list of types (either `int32` or `float`), and that there must be at least 3 of them:

  REGISTER_OP("TypeListExample")
      .Attr("a: list({int32, float}) >= 3");
To set a default value for an attr (making it optional in the generated code), add `= <default>` to the end, as in:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");
The supported syntax of the default value is what would be used in the proto representation of the resulting GraphDef definition.
Here are examples for how to specify a default for all types:
REGISTER_OP("AttrDefaultExampleForAllTypes")
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
Note in particular that the values of type `type` use the `DT_*` names for the types.
Polymorphism
Type Polymorphism
For ops that can take different types as input or produce different output types, you can specify an attr in an input or output type in the Op registration. Typically you would then register an `OpKernel` for each supported type.

For instance, if you'd like the `ZeroOut` Op to work on `float`s in addition to `int32`s, your Op registration might look like:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");
Your Op registration now specifies that the input's type must be `float` or `int32`, and that its output will be the same type, since both have type `T`.
A note on naming: Inputs, outputs, and attrs generally should be given snake_case names. The one exception is attrs that are used as the type of an input or in the type of an input. Those attrs can be inferred when the op is added to the graph and so don't appear in the op's function. For example, this last definition of `ZeroOut` will generate a Python function that looks like:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """
If `to_zero` is passed an `int32` tensor, then `T` is automatically set to `int32` (well, actually `DT_INT32`). Those inferred attrs are given Capitalized or CamelCase names.

Compare this with an op that has a type attr that determines the output type:
REGISTER_OP("StringToNumber") .Input("string_tensor: string") .Output("output: out_type") .Attr("out_type: {float, int32}"); .Doc(R"doc( Converts each string in the input Tensor to the specified numeric type. )doc");
In this case, the user has to specify the output type, as in the generated Python:
def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
        Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
#include "tensorflow/core/framework/op_kernel.h"
class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};
// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the Op registration above) must be "int32" to use this kernel.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);
To preserve backwards compatibility, you should specify a default value when adding an attr to an existing op:
REGISTER_OP("ZeroOut") .Attr("T: {float, int32} = DT_INT32") .Input("to_zero: T") .Output("zeroed: T")
Let's say you wanted to add more types, say `double`:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Instead of writing another `OpKernel` with redundant code as above, often you will be able to use a C++ template instead. You will still have one kernel registration (`REGISTER_KERNEL_BUILDER` call) per overload.
template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};
// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the Op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);
If you have more than a couple overloads, you can put the registration in a macro.
#include "tensorflow/core/framework/op_kernel.h"
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
Depending on the list of types you are registering the kernel for, you may be able to use a macro provided by `tensorflow/core/framework/register_types.h`:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
List Inputs and Outputs
In addition to being able to accept or produce different types, ops can consume or produce a variable number of tensors.
In the next example, the attr `T` holds a list of types, and is used as the type of both the input `in` and the output `out`. The input and output are lists of tensors of that type (and the number and types of tensors in the output are the same as in the input, since both have type `T`).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");
You can also place restrictions on what types can be specified in the list. In this next case, the input is a list of `float` and `double` tensors. The Op accepts, for example, input types `(float, double, float)`, and in that case the output type would also be `(float, double, float)`.

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");
If you want all the tensors in a list to be of the same type, you might do something like:
REGISTER_OP("IntListInputExample")
.Attr("N: int")
.Input("in: N * int32")
.Output("out: int32");
This accepts a list of `int32` tensors, and uses an `int` attr `N` to specify the length of the list.

This can be made type polymorphic as well. In the next example, the input is a list of tensors (with length `"N"`) of the same (but unspecified) type (`"T"`), and the output is a single tensor of matching type:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");
By default, tensor lists have a minimum length of 1. You can change that default using a `">="` constraint on the corresponding attr. In this next example, the input is a list of at least 2 `int32` tensors:

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");
The same syntax works with `"list(type)"` attrs:

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");
Inputs and Outputs
To summarize the above, an Op registration can have multiple inputs and outputs:
REGISTER_OP("MultipleInsAndOuts")
.Input("y: int32")
.Input("z: float")
.Output("a: string")
.Output("b: int32");
Each input or output spec is of the form:
<name>: <io-type-expr>
where `<name>` begins with a letter and can be composed of alphanumeric characters and underscores. `<io-type-expr>` is one of the following type expressions:

- `<type>`, where `<type>` is a supported input type (e.g. `float`, `int32`, `string`). This specifies a single tensor of the given type. See the list of supported Tensor types.

  REGISTER_OP("BuiltInTypesExample")
      .Input("integers: int32")
      .Input("complex_numbers: scomplex64");

- `<attr-type>`, where `<attr-type>` is the name of an Attr with type `type` or `list(type)` (with a possible type restriction). This syntax allows for polymorphic ops.

  REGISTER_OP("PolymorphicSingleInput")
      .Attr("T: type")
      .Input("in: T");

  REGISTER_OP("RestrictedPolymorphicSingleInput")
      .Attr("T: {int32, int64}")
      .Input("in: T");

  Referencing an attr of type `list(type)` allows you to accept a sequence of tensors.

  REGISTER_OP("ArbitraryTensorSequenceExample")
      .Attr("T: list(type)")
      .Input("in: T")
      .Output("out: T");

  REGISTER_OP("RestrictedTensorSequenceExample")
      .Attr("T: list({int32, int64})")
      .Input("in: T")
      .Output("out: T");

  Note that the number and types of tensors in the output `out` is the same as in the input `in`, since both are of type `T`.

- For a sequence of tensors with the same type: `<number> * <type>`, where `<number>` is the name of an Attr with type `int`. The `<type>` can either be a specific type like `int32` or `float`, or the name of an attr with type `type`. As an example of the first, this Op accepts a list of `int32` tensors:

  REGISTER_OP("Int32SequenceExample")
      .Attr("NumTensors: int")
      .Input("in: NumTensors * int32")

  Whereas this Op accepts a list of tensors of any type, as long as they are all the same:

  REGISTER_OP("SameTypeSequenceExample")
      .Attr("NumTensors: int")
      .Attr("T: type")
      .Input("in: NumTensors * T")

- For a reference to a tensor: `Ref(<type>)`, where `<type>` is one of the previous types (see the sketch after this list).
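For instance, a hypothetical op that updates a tensor in place, in the style of the built-in state ops such as `Assign`, could take and return its input by reference:

REGISTER_OP("RefInputExample")
    .Input("ref_in: Ref(int32)")
    .Output("ref_out: Ref(int32)");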
A note on naming: Any attr used in the type of an input will be inferred. By convention those inferred attrs use capital names (like `T` or `N`). Otherwise inputs, outputs, and attrs have names like function parameters (e.g. `num_outputs`). For more details, see the earlier note on naming.

For more details, see `tensorflow/core/framework/op_def_builder.h`.
Backwards compatibility
In general, changes to specifications must be backwards-compatible: changing the specification of an Op must not break prior serialized `GraphDef`s constructed from older specifications.
There are several ways to preserve backwards-compatibility.
Any new attrs added to an operation must have default values defined, and with that default value the Op must have the original behavior. To change an operation from not polymorphic to polymorphic, you must give a default value to the new type attr to preserve the original signature by default. For example, if your operation was:
REGISTER_OP("MyGeneralUnaryOp") .Input("in: float") .Output("out: float");
you can make it polymorphic in a backwards-compatible way using:
REGISTER_OP("MyGeneralUnaryOp") .Input("in: T") .Output("out: T") .Attr("T: numerictype = float");
- You can safely make a constraint on an attr less restrictive. For example, you can change from `{int32, int64}` to `{int32, int64, float}`, or from `{"apple", "orange"}` to `{"apple", "banana", "orange"}` (see the sketch after this list).
- Namespace any new Ops you create, by prefixing the Op names with something unique to your project. This avoids having your Op collide with any Ops that might be included in future versions of TensorFlow.
- Plan ahead! Try to anticipate future uses for the Op. Some signature changes can't be done in a compatible way (for example, adding an input, or making a single input into a list).
- If you cannot make your change to an operation backwards compatible, then create a new operation with a new name with the new semantics.
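For example, the hypothetical relaxation below is safe, because every graph that was valid under the old constraint remains valid under the new one:

// Before: only 32- and 64-bit integer inputs were accepted.
REGISTER_OP("MyConstrainedOp")
    .Attr("T: {int32, int64}")
    .Input("in: T");

// After: float inputs are also accepted; old graphs still validate.
REGISTER_OP("MyConstrainedOp")
    .Attr("T: {int32, int64, float}")
    .Input("in: T");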
GPU Support
You can implement different OpKernels and register one for CPU and another for GPU, just like you can register kernels for different types. There are several examples of kernels with GPU support in `tensorflow/core/kernels/`. Notice some kernels have a CPU version in a `.cc` file, a GPU version in a file ending in `_gpu.cu.cc`, and some code shared in common in a `.h` file.
For example, the `pad` Op has everything but the GPU kernel in `tensorflow/core/kernels/pad_op.cc`. The GPU kernel is in `tensorflow/core/kernels/pad_op_gpu.cu.cc`, and the shared code is a templated class defined in `tensorflow/core/kernels/pad_op.h`. One thing to note: even when the GPU kernel version of `pad` is used, it still needs its `"paddings"` input in CPU memory. To mark that inputs or outputs are kept on the CPU, add a `HostMemory()` call to the kernel registration, e.g.:
#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)
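As written, the macro above only defines the registration; it still has to be instantiated for each type you support. A minimal sketch of the instantiations (illustrative; the real `Pad` kernel registers a longer list of types):

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL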
Implement the gradient in Python
Given a graph of ops, TensorFlow uses automatic differentiation (backpropagation) to add new ops representing gradients with respect to the existing ops (see Gradient Computation). To make automatic differentiation work for new ops, you must register a gradient function which computes gradients with respect to the ops' inputs given gradients with respect to the ops' outputs.
Mathematically, if an op computes \(y = f(x)\), the registered gradient op converts gradients \(\partial / \partial y\) with respect to \(y\) into gradients \(\partial / \partial x\) with respect to \(x\) via the chain rule:
$$\frac{\partial}{\partial x} = \frac{\partial}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.$$
In the case of `ZeroOut`, only one entry in the input affects the output, so the gradient with respect to the input is a sparse "one hot" tensor. This is expressed as follows:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can
      use to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense(index, shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input
Details about registering gradient functions with `ops.RegisterGradient`:

- For an op with one output, the gradient function will take an `Operation` `op` and a `Tensor` `grad` and build new ops out of the tensors `op.inputs[i]`, `op.outputs[i]`, and `grad`. Information about any attrs can be found via `op.get_attr`.
- If the op has multiple outputs, the gradient function will take `op` and `grads`, where `grads` is a list of gradients with respect to each output. The result of the gradient function must be a list of `Tensor` objects representing the gradients with respect to each input.
- If there is no well-defined gradient for some input, such as for integer inputs used as indices, the corresponding returned gradient should be `None`. For example, for an op taking a floating point tensor `x` and an integer index `i`, the gradient function would `return [x_grad, None]`.
- If there is no meaningful gradient for the op at all, use `ops.NoGradient("OpName")` to disable automatic differentiation (see the example after this list).
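For instance, if `ZeroOut` were only ever used for inference and you did not want it differentiated, disabling its gradient would be a one-liner (a sketch; this conflicts with the gradient registered above, so you would do one or the other):

from tensorflow.python.framework import ops

ops.NoGradient("ZeroOut")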
Note that at the time the gradient function is called, only the data flow graph of ops is available, not the tensor data itself. Thus, all computation must be performed using other TensorFlow ops, to be run at graph execution time.
Implement a shape function in Python
The TensorFlow Python API has a feature called "shape inference" that provides information about the shapes of tensors without having to execute the graph. Shape inference is supported by "shape functions" that are registered for each Op type, and perform two roles: asserting that the shapes of the inputs are compatible, and specifying the shapes for the outputs. A shape function is a Python function that takes an `Operation` as input, and returns a list of `TensorShape` objects (one per output of the op). To register a shape function, apply the `tf.RegisterShape` decorator to it. For example, the `ZeroOut` Op defined above would have a shape function like the following:

@tf.RegisterShape("ZeroOut")
def _zero_out_shape(op):
  """Shape function for the ZeroOut op.

  This is the unconstrained version of ZeroOut, which produces an output
  with the same shape as its input.
  """
  return [op.inputs[0].get_shape()]
A shape function can also constrain the shape of an input. For the version of `ZeroOut` with a vector shape constraint, the shape function would be as follows:

@tf.RegisterShape("ZeroOut")
def _zero_out_shape(op):
  """Shape function for the ZeroOut op.

  This is the constrained version of ZeroOut, which requires the input to
  have rank 1 (a vector).
  """
  input_shape = op.inputs[0].get_shape().with_rank(1)
  return [input_shape]
If your op is polymorphic with multiple inputs, use the properties of the operation to determine the number of shapes to check:
@tf.RegisterShape("IntListInputExample")
def _int_list_input_example_shape(op):
"""Shape function for the "IntListInputExample" op.
All inputs and the output are matrices of the same size.
"""
output_shape = tf.TensorShape(None)
for input in op.inputs:
output_shape = output_shape.merge_with(input.get_shape().with_rank(2))
return [output_shape]
Since shape inference is an optional feature, and the shapes of tensors may vary dynamically, shape functions must be robust to incomplete shape information for any of the inputs. The `merge_with` method allows the caller to assert that two shapes are the same, even if either or both of them do not have complete information. Shape functions are defined for all of the standard Python Ops, and provide many different usage examples.
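To illustrate how `merge_with` behaves with partial shapes, here is a small sketch: merging an unknown dimension with a known one keeps the known value, while merging incompatible known dimensions raises an error.

import tensorflow as tf

a = tf.TensorShape([None, 2])
b = tf.TensorShape([3, None])
merged = a.merge_with(b)  # Fully defined: (3, 2)

# Incompatible known dimensions raise a ValueError:
# tf.TensorShape([4]).merge_with(tf.TensorShape([5]))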