
Fix a bug in the robust gradient for div. Add robust gradient for real_div.

PiperOrigin-RevId: 238719379
tensorflower-gardener committed Mar 15, 2019
1 parent c0886e7 commit 7ea3243faca6d183ab6eceebdd578df25e6ab002
Showing with 53 additions and 35 deletions.
  1. +1 −0 tensorflow/python/kernel_tests/cwise_ops_test.py
  2. +52 −35 tensorflow/python/ops/math_grad.py
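The `real_div` entry added to `testGradientAtSingularity` exercises the reworked gradients below. Their common pattern is to combine the potentially singular partial derivatives with the incoming gradient through `div_no_nan` / `mul_no_nan`, so that a zero upstream gradient produces a zero downstream gradient even at the pole `y == 0`, instead of NaN. The snippet below is an illustrative sketch of that property, not the test code itself; it assumes a TensorFlow build in which the `compat.forward_compatible(2019, 4, 7)` window has passed, so the robust `RealDiv` gradient is active (`tf.math.truediv` on float inputs lowers to `RealDiv`).

```python
# Sketch only -- not the code in cwise_ops_test.py.
import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(0.0)  # singular point of x / y

with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = tf.math.truediv(x, y)  # float inputs dispatch to RealDiv

# Push a zero upstream gradient through the singularity.  With the robust
# gradient both results are 0.0; a plain realdiv-based gradient would give
# nan instead (0 / 0 and 0 * inf).
dx, dy = tape.gradient(z, [x, y], output_gradients=tf.constant(0.0))
print(dx.numpy(), dy.numpy())  # 0.0 0.0
```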
tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1139,6 +1139,7 @@ def testGradientAtSingularity(self):
         (gen_math_ops.acos, (1.,)),
         (gen_math_ops.atan2, (0., 0.)),
         (gen_math_ops.div, (1., 0.)),
+        (gen_math_ops.real_div, (1., 0.)),
         (math_ops.pow, (0., -1.)),
     ]
     for op, singularity in ops_and_singularity:
tensorflow/python/ops/math_grad.py
@@ -177,8 +177,8 @@ def _ProdGrad(op, grad):
   left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
   right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
   # For complex inputs, the gradient is in the conjugate direction.
-  y = array_ops.reshape(math_ops.conj(left) * math_ops.conj(right),
-                        permuted_shape)
+  y = array_ops.reshape(
+      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

   # Invert the transpose and reshape operations.
   # Make sure to set the statically known shape information through a reshape.
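The reformatted line above is the core of `_ProdGrad`: the derivative of `prod(x)` with respect to `x[i]` is the product of everything before `i` times the product of everything after `i`, computed as an exclusive cumprod times a reversed exclusive cumprod (conjugated for complex inputs). That form stays well defined when `x` contains zeros, which dividing the total product by `x[i]` would not. A minimal sketch of the identity:

```python
# Sketch only -- not library code.
import tensorflow as tf

x = tf.constant([2.0, 0.0, 5.0])

left = tf.math.cumprod(x, exclusive=True)                 # [1., 2., 0.]
right = tf.math.cumprod(x, exclusive=True, reverse=True)  # [0., 5., 1.]
manual_grad = left * right                                # [0., 10., 0.]

with tf.GradientTape() as tape:
  tape.watch(x)
  p = tf.reduce_prod(x)

print(manual_grad.numpy())          # [ 0. 10.  0.]
print(tape.gradient(p, x).numpy())  # [ 0. 10.  0.]
```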
@@ -261,8 +261,8 @@ def _SegmentMinOrMaxGrad(op, grad):
   # Get the number of selected (minimum or maximum) elements in each segment.
   gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
   is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
-  num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
-                                      op.inputs[1])
+  num_selected = math_ops.segment_sum(
+      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
   # Compute the gradient for each segment. The gradient for the ith segment is
   # divided evenly among the selected elements in that segment.
   weighted_grads = math_ops.divide(grad, num_selected)
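The two comments above describe the tie rule: `num_selected` counts how many elements attain the segment's minimum (or maximum), and the segment's incoming gradient is split evenly across those elements. A small sketch of that behavior through the public API:

```python
# Sketch only -- not library code.
import tensorflow as tf

data = tf.constant([1.0, 1.0, 3.0, 2.0])
segment_ids = tf.constant([0, 0, 0, 1])

with tf.GradientTape() as tape:
  tape.watch(data)
  mins = tf.math.segment_min(data, segment_ids)  # [1., 2.]

# Segment 0 has two tied minima, so its unit gradient is split in half.
print(tape.gradient(mins, data).numpy())  # [0.5 0.5 0.  1. ]
```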
@@ -282,9 +282,13 @@ def _SegmentMaxGrad(op, grad):
   return _SegmentMinOrMaxGrad(op, grad)


-def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
+def _GatherDropNegatives(params,
+                         ids,
+                         zero_clipped_indices=None,
                          is_positive=None):
-  """ Helper function for unsorted segment ops. Gathers params for
+  """ Helper function for unsorted segment ops.
+
+  Gathers params for
   positive segment ids and gathers 0 for inputs with negative segment id.
   Also returns the clipped indices and a boolean mask with the same shape
   as ids where a positive id is masked as true. With this, the latter two
@@ -300,8 +304,8 @@ def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
     # todo(philjd): remove this if tf.where supports broadcasting (#9284)
     for _ in range(gathered.shape.ndims - is_positive.shape.ndims):
       is_positive = array_ops.expand_dims(is_positive, -1)
-    is_positive = (is_positive &
-                   array_ops.ones_like(gathered, dtype=dtypes.bool))
+    is_positive = (
+        is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
   # replace gathered params of negative indices with 0
   zero_slice = array_ops.zeros_like(gathered)
   return (array_ops.where(is_positive, gathered, zero_slice),
@@ -321,8 +325,7 @@ def _UnsortedSegmentMinOrMaxGrad(op, grad):
   # divided evenly among the selected elements in that segment.
   weighted_grads = math_ops.divide(grad, num_selected)
   gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
-                                              zero_clipped_indices,
-                                              is_positive)
+                                              zero_clipped_indices, is_positive)
   zeros = array_ops.zeros_like(gathered_grads)
   return array_ops.where(is_selected, gathered_grads, zeros), None, None

@@ -348,6 +351,7 @@ def _UnsortedSegmentMinGrad(op, grad):
 @ops.RegisterGradient("UnsortedSegmentProd")
 def _UnsortedSegmentProdGrad(op, grad):
   """ Gradient for UnsortedSegmentProd.
+
   The gradient can be expressed for each segment by dividing the segment's
   product by each element of the segment input tensor, but this approach can't
   deal with zeros in the input.
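The docstring goes on (beyond the context shown here) to handle segments case by case according to how many zeros they contain; the "case 3" comment in the next hunk refers to segments with more than one zero, whose gradient is set to 0. A brief illustration of the resulting behavior, sketch only:

```python
# Sketch only -- not library code.
import tensorflow as tf

data = tf.constant([3.0, 0.0, 4.0, 2.0, 5.0])
segment_ids = tf.constant([0, 0, 0, 1, 1])

with tf.GradientTape() as tape:
  tape.watch(data)
  prods = tf.math.unsorted_segment_prod(data, segment_ids, num_segments=2)

# Segment 0 contains exactly one zero, so only that element gets a nonzero
# gradient (3 * 4 = 12); segment 1 has no zeros, so each element gets
# segment_prod / element.  Dividing by the data directly would give 0 / 0.
print(tape.gradient(prods, data).numpy())  # [ 0. 12.  0.  5.  2.]
```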
@@ -368,19 +372,18 @@ def _UnsortedSegmentProdGrad(op, grad):
       math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
   # handle case 3 and set the gradient to 0 for segments with more than one
   # 0 as input
-  grad = array_ops.where(math_ops.greater(num_zeros, 1),
-                         array_ops.zeros_like(grad), grad)
+  grad = array_ops.where(
+      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
   # replace all zeros with ones and compute the unsorted_segment_prod
   non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]),
                                   op.inputs[0])
-  non_zero_prod = gen_math_ops.unsorted_segment_prod(
-      non_zero_data, op.inputs[1], op.inputs[2])
+  non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
+                                                     op.inputs[1], op.inputs[2])
   # clip the indices for gather to be positive
   zero_clipped_indices = math_ops.maximum(op.inputs[1],
                                           array_ops.zeros_like(op.inputs[1]))
   gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
-  gathered_non_zero_prod = array_ops.gather(non_zero_prod,
-                                            zero_clipped_indices)
+  gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
   prod_divided_by_el = gathered_prod / op.inputs[0]  # May contain nan/inf.
   # Now fetch the individual results for segments containing 0 and those that
   # don't. is_zero will also fetch results for entries with negative index
@@ -714,8 +717,8 @@ def _IgammaGrad(op, grad):
   partial_a = gen_math_ops.igamma_grad_a(a, x)
   # Perform operations in log space before summing, because Gamma(a)
   # and Gamma'(a) can grow large.
-  partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x)
-                           - math_ops.lgamma(a))
+  partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
+                           math_ops.lgamma(a))
   return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
           array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
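The comment above is about numerical range: `partial_x` is the incomplete-gamma integrand `x**(a - 1) * exp(-x) / Gamma(a)`, and for even moderately large `a` both `x**(a - 1)` and `Gamma(a)` overflow on their own, while `-x + (a - 1) * log(x) - lgamma(a)` stays small. A quick illustration (sketch, not library code):

```python
import tensorflow as tf

a = tf.constant(200.0, dtype=tf.float64)
x = tf.constant(180.0, dtype=tf.float64)

# Naive form: the numerator and Gamma(a) each overflow float64, so the
# ratio degenerates to inf / inf = nan.
naive = tf.pow(x, a - 1.0) * tf.exp(-x) / tf.exp(tf.math.lgamma(a))

# Log-space form used above: the large terms cancel before the exp().
stable = tf.exp(-x + (a - 1.0) * tf.math.log(x) - tf.math.lgamma(a))

print(naive.numpy(), stable.numpy())  # nan vs. a small finite value (~0.01)
```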

@@ -983,17 +986,15 @@ def _MulNoNanGrad(op, grad):
   y = op.inputs[1]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
-    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(
-        x, grad)
+    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   return (array_ops.reshape(
       math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
           array_ops.reshape(
-              math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry),
-              sy))
+              math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))


 @ops.RegisterGradient("Div")
@@ -1007,13 +1008,19 @@ def _DivGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   if compat.forward_compatible(2019, 4, 7):
-    div_op = math_ops.div_no_nan
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    math_ops.mul_no_nan(
+                        math_ops.divide(math_ops.divide(-x, y), y), grad), ry),
+                sy))
   else:
-    div_op = math_ops.divide
-  return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
-              sy))
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy))


 @ops.RegisterGradient("FloorDiv")
@@ -1053,11 +1060,21 @@ def _RealDivGrad(op, grad):
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(
-                  grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
+  if compat.forward_compatible(2019, 4, 7):
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    math_ops.mul_no_nan(
+                        math_ops.realdiv(math_ops.realdiv(-x, y), y), grad),
+                    ry), sy))
+  else:
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry),
+                sy))


 @ops.RegisterGradient("DivNoNan")
@@ -1359,8 +1376,8 @@ def _ComplexAbsGrad(op, grad):
   """Returns the gradient of ComplexAbs."""
   # TODO(b/27786104): The cast to complex could be removed once arithmetic
   # supports mixtures of complex64 and real values.
-  return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign(
-      op.inputs[0]))
+  return (math_ops.complex(grad, array_ops.zeros_like(grad)) *
+          math_ops.sign(op.inputs[0]))


 @ops.RegisterGradient("Cast")
