Skip to content
Permalink
Browse files

Always wrap non-sequence types like maps in a list when tracing Model…

… call

Avoids input_signature and unwrapping issues.

Fixes #26591 (again).

PiperOrigin-RevId: 238725001
  • Loading branch information...
allenlavoie authored and tensorflower-gardener committed Mar 15, 2019
1 parent fed677e commit 4f086f4c0fc5547bf91d0036ccecf33d69c1303a
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import nest
@@ -82,8 +84,13 @@ def trace_model_call(model, input_signature=None):
input_specs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_input_specs)
# The input signature of the call function is a list with one element, since
# all tensor inputs must be passed in as the first argument.
input_signature = [input_specs] if len(input_specs) > 1 else input_specs
# all tensor inputs must be passed in as the first argument. Single-element
# dictionaries and other non-sequence types must also be wrapped.
if (len(input_specs) > 1
or not isinstance(input_specs, collections.Sequence)):
input_signature = [input_specs]
else:
input_signature = input_specs

# TODO(mdan): Should the model's call be autographed by default?
@def_function.function(input_signature=input_signature, autograph=False)
@@ -28,12 +28,15 @@
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -130,6 +133,27 @@ def test_trace_multi_io_model_outputs(self):

self._assert_all_close(expected_outputs, signature_outputs)

@keras_parameterized.run_all_keras_modes
def test_trace_features_layer(self):
columns = [feature_column_v2.numeric_column('x')]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns)])
model_input = {'x': constant_op.constant([[1.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))

columns = [feature_column_v2.numeric_column('x'),
feature_column_v2.numeric_column('y')]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns)])
model_input = {'x': constant_op.constant([[1.]]),
'y': constant_op.constant([[2.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1., 2.]]},
fn({'x': [[1.]], 'y': [[2.]]}))

@keras_parameterized.run_all_keras_modes
def test_specify_input_signature(self):
model = testing_utils.get_small_sequential_mlp(10, 3, None)
@@ -1186,6 +1186,18 @@ def test_dense_features_layer(self, cycles):
**model_input).values()
self.assertAllClose([[1., 2.]], signature_output)

def test_dense_features_layer_fit(self, cycles):
columns = [feature_column_v2.numeric_column("x")]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns),
core.Dense(1)])
model_input = {"x": constant_op.constant([[1.]])}
model.compile(optimizer="adam", loss="mse")
model.fit(model_input, constant_op.constant([[3.]]))
loaded = self.cycle(model, cycles)
loaded._default_save_signature(model_input)
loaded.signatures["serving_default"](**model_input)


class SingleCycleTests(test.TestCase, parameterized.TestCase):

0 comments on commit 4f086f4

Please sign in to comment.
You can’t perform that action at this time.