Skip to content
Permalink
Browse files

Raise error if validation steps exists without legitimate

validation data.

PiperOrigin-RevId: 238737301
  • Loading branch information...
tanzhenyu authored and tensorflower-gardener committed Mar 16, 2019
1 parent 82dfc66 commit 216932e8972f3bd9c017d900a87b287d040adabe
Showing with 18 additions and 12 deletions.
  1. +3 −4 tensorflow/python/keras/engine/training.py
  2. +15 −8 tensorflow/python/keras/engine/training_test.py
@@ -829,11 +829,10 @@ def _worker_fn(_):
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
sample_weights, val_sample_weights = (slice_arrays(
sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
elif validation_steps:
val_x = []
val_y = []
val_sample_weights = []
else:
if validation_steps:
raise ValueError('`validation_steps` should not be specified if '
'`validation_data` is None.')
val_x = None
val_y = None
val_sample_weights = None
@@ -973,6 +973,18 @@ def on_test_begin(self, logs=None):
callbacks=[val_counter])
self.assertEqual(val_counter.val_runs, expected_runs)

@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_validation_steps_without_data(self):
x, y = np.ones((10, 10)), np.ones((10, 1))
model = testing_utils.get_small_mlp(2, 1, 10)
model.compile('sgd', 'mse')

with self.assertRaisesRegexp(
ValueError, '`validation_steps` should not be specified if '
'`validation_data` is None.'):
model.fit(x, y, epochs=4, validation_data=None, validation_steps=3)

@keras_parameterized.run_all_keras_modes
def test_add_loss_correctness(self):
class Bias(keras.layers.Layer):
@@ -2085,9 +2097,9 @@ def test_model_with_external_loss(self):
out = model.fit(None, None, epochs=1,
steps_per_epoch=None,
validation_steps=2)
out = model.fit(None, None, epochs=1,
steps_per_epoch=2,
validation_steps=2)
out = model.fit(None, None, epochs=1,
steps_per_epoch=2,
validation_steps=2)

# test evaluate
with self.assertRaises(ValueError):
@@ -2122,11 +2134,6 @@ def test_model_with_external_loss(self):
out = model.fit(None, None, epochs=1, batch_size=10)
out = model.fit(None, None, epochs=1, steps_per_epoch=1)

# test fit with validation data
out = model.fit(None, None, epochs=1,
steps_per_epoch=2,
validation_steps=2)

# test evaluate
with self.assertRaises(ValueError):
out = model.evaluate(None, None, batch_size=10)

0 comments on commit 216932e

Please sign in to comment.
You can’t perform that action at this time.