
I have a custom tf.keras.layers.Layer which does some bit unpacking (converting integers into boolean values, 0 or 1, as floats) using only TF operators.

import tensorflow as tf
from tensorflow import keras

class CharUnpack(keras.layers.Layer):

    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        # Range [7, 6, ..., 0] to bit-shift integers
        self._shifting_range = tf.reshape(
            tf.dtypes.cast(
                tf.range(7, -1, -1, name='shifter_range'),
                tf.uint8,
                name='shifter_cast'),
            (1, 1, 8),
            name='shifter_reshape')
        # Constant value 0b00000001 to use as bitwise and operator
        self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

    def call(self, inputs):
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        try:
            if len(input_shape) > 1:
                output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
            else:
                output_shape = tf.TensorShape((input_shape[0] * 8,))
        except TypeError:
            output_shape = input_shape
        return output_shape

    def compute_output_signature(self, input_signature):
        return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)
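
For context, a quick sanity check of what the layer produces (this snippet is just for illustration): each uint8 byte is unpacked into 8 float bits, most significant bit first.

layer = CharUnpack()
x = tf.constant([[5]], dtype=tf.uint8)  # 5 == 0b00000101
print(layer(x))
# expected: tf.Tensor([[0. 0. 0. 0. 0. 1. 0. 1.]], shape=(1, 8), dtype=float32)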

I tried to benchmark this layer to improve its time performance, as shown in this TF guide.

import timeit

inputs = tf.zeros([64, 400], dtype=tf.uint8)

eager = CharUnpack()

@tf.function
def fun(x):
    eager(x)

# Warm-up
eager(inputs)
fun(inputs)

print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
Function: 0.01062483999885444
Eager: 0.12658399900101358

As you can see, I can get a 10x speed-up! So I added the @tf.function decorator to my CharUnpack.call method:

+    @tf.function
     def call(self, inputs):
         return tf.dtypes.cast(

Now I expect both the eager and the fun calls to take a similar amount of time, but I see no improvement.

Function: 0.009667591999459546
Eager: 0.10346330100037449

Moreover, section 2.1 of this SO answer states that Models are graph-compiled by default (which would make sense), but this does not seem to be the case...
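
For reference, here is a small probe one could run to check where call() executes in graph mode (the Probe layer is purely illustrative, not part of my code):

class Probe(keras.layers.Layer):
    def call(self, inputs):
        # A Python print only fires at trace time when running inside a graph
        print("executing eagerly:", tf.executing_eagerly())
        return inputs

probe = Probe()
probe(tf.zeros([1, 1]))          # direct call: prints "executing eagerly: True"
model = keras.Sequential([probe])
model.predict(tf.zeros([1, 1]))  # through predict(): the traced call prints "executing eagerly: False"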

How do I properly use the @tf.function decorator to make sure my layer is always graph-compiled?

AlexisBRENON
  • Hi! Have you solved your issue? Does tf function improve performance of custom layer?? – cappadavide Nov 22 '20 at 09:43
  • @cappadavide I did not play with TF since then and did not investigate any further. For sure, using `tf.function` can improve the performance, but I don't know the canonical way to use it… Feel free to do some experiments with newer versions of TF and please, share your insights ! – AlexisBRENON Nov 22 '20 at 12:19

1 Answer


First, tf.function does not need to be nested: you can wrap only your custom train_step() (which contains the forward and backward propagation). In that case there is no need to wrap an inner layer's or sub-model's call() function, since they are already traced as part of your train_step. Nested usage may even lead to unexpected performance degradation.
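
For example, here is a minimal sketch of that pattern (the Dense head, loss and optimizer are only for illustration): only the outer training step is decorated, while CharUnpack.call stays undecorated and is traced as part of that single graph.

unpack = CharUnpack()                 # custom layer, call() left undecorated
head = tf.keras.layers.Dense(1)       # illustrative trainable head
optimizer = tf.keras.optimizers.SGD()

@tf.function                          # one outer graph; inner layer calls are traced into it
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = head(unpack(x))
        loss = tf.reduce_mean(tf.square(y - y_pred))
    grads = tape.gradient(loss, head.trainable_variables)
    optimizer.apply_gradients(zip(grads, head.trainable_variables))
    return loss

# e.g. train_step(tf.zeros([64, 400], dtype=tf.uint8), tf.zeros([64, 1]))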

Second, any computational acceleration comes at a cost: tf.function trades space for time and needs an initial call to build the graph. So, when benchmarking, we should re-run the same function several times; subsequent calls of a tf.function do not pay the graph-building cost again, as long as tracing does not change (e.g. the input signature stays the same).

for _ in range(5):
    print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
for _ in range(5):
    print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
# Function: 0.02040819999820087
# Function: 0.020311099986429326
# Function: 0.020155799997155555
# Function: 0.02004839999426622
# Function: 0.019999900003313087
# Eager: 0.035980800006655045
# Eager: 0.035652499995194376
# Eager: 0.035596200003055856
# Eager: 0.03490520000923425
# Eager: 0.03762050000659656
Little Train