An easy guide to gradients of softmax cross-entropy loss

Softmax function

The softmax function is widely used as the output layer of neural networks for classification. You can think of softmax as outputting a probability for each candidate category. For example, in a classification task with three categories, softmax outputs three probabilities based on the relative sizes of its inputs, and these probabilities sum to 1. The inputs to the softmax function are usually called logits, written z_i. The function has the following form.

$$
a_i = \frac{e^{z_i}}{\sum_k e^{z_k}}
$$
Every output a_i depends on all of the inputs z_k, so this function is a multiple-input, multiple-output mapping.
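As a minimal sketch of the formula above, here is softmax in plain Python. The function name `softmax` and the max-subtraction step are choices made for this illustration; subtracting the largest logit leaves the result unchanged mathematically and only guards against overflow in `exp`.

```python
import math


def softmax(logits):
    """Map logits z_i to probabilities a_i = exp(z_i) / sum_k exp(z_k)."""
    # Shifting by the maximum cancels in the ratio but keeps exp() from overflowing.
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print(probs)       # three probabilities, ordered like the logits
print(sum(probs))  # approximately 1.0
```

Changing any single logit changes every entry of `probs`, which is the multiple-input, multiple-output behaviour described above.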

Cross-entropy loss function

The loss function can take many forms. Cross-entropy is used here mainly because its derivative is simple and cheap to compute, and because it avoids the learning slowdown that affects certain other loss functions. The cross-entropy function is
$$
L(z, y) = -\sum_i y_i \ln a_i
$$
where the y_i are the labels, which encode the true category each input sample falls into. The loss L is a multivariate function, so gradients would flow into both the logits and the labels. Since labels are usually not trainable variables, it is practical in TensorFlow to pass the label tensors through tf.stop_gradient before feeding them to this function, which prevents backpropagation into the labels.
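The loss can be sketched in a few lines of plain Python. The helper names `softmax` and `cross_entropy` are made up for this example, and skipping the terms with y_i = 0 is a choice of this sketch, not something the text prescribes.

```python
import math


def softmax(logits):
    shift = max(logits)  # shift for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, labels):
    """L = -sum_i y_i * ln(a_i), where a = softmax(logits)."""
    probs = softmax(logits)
    # Terms with y_i == 0 contribute nothing, so skip them before calling log().
    return -sum(y * math.log(a) for y, a in zip(labels, probs) if y != 0.0)


# Equal logits give a uniform prediction, so the loss is ln(3) for three classes.
loss_uniform = cross_entropy([0.0, 0.0, 0.0], [0.0, 1.0, 0.0])
print(loss_uniform)
```

With a one-hot label only one term of the sum survives, so the loss is simply the negative log-probability assigned to the true category.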

Chain Rule

To obtain the gradient of the loss w.r.t. z_i, apply the chain rule, summing over every output a_j because each of them depends on z_i:
$$
\frac{\partial L}{\partial z_i} = \sum_j \frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}
$$

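After the necessary derivations (ref), we arrive at
$$
\frac{\partial L}{\partial a_j}=-\frac{y_j}{a_j}, \quad \frac{\partial a_j}{\partial z_i}=a_j(\delta_{ij}-a_i)
$$
Plugging these back into the gradient of the loss w.r.t. z_i, and using \sum_j y_j = 1 because the labels form a probability distribution,
$$
\frac{\partial L}{\partial z_i}=-\sum_j \frac{y_j}{a_j}\,a_j(\delta_{ij}-a_i)=-y_i+a_i\sum_j y_j=a_i-y_i
$$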
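One way to convince yourself of this result is a finite-difference check. The sketch below is plain Python; the helper names, the step size `eps`, and the sample logits and labels are all arbitrary choices for illustration.

```python
import math


def softmax(logits):
    shift = max(logits)  # shift for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, labels):
    probs = softmax(logits)
    return -sum(y * math.log(a) for y, a in zip(labels, probs) if y != 0.0)


def analytic_grad(logits, labels):
    """Closed form derived above: dL/dz_i = a_i - y_i."""
    return [a - y for a, y in zip(softmax(logits), labels)]


def numeric_grad(logits, labels, eps=1e-6):
    """Central differences: perturb one logit at a time."""
    grads = []
    for i in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[i] += eps
        down[i] -= eps
        grads.append((cross_entropy(up, labels) - cross_entropy(down, labels)) / (2 * eps))
    return grads


logits = [2.0, -1.0, 0.5]
labels = [0.0, 0.0, 1.0]  # one-hot label, sums to 1
max_gap = max(abs(a - n) for a, n in zip(analytic_grad(logits, labels),
                                         numeric_grad(logits, labels)))
print(max_gap)  # should be tiny if the closed form is right
```

The same check should also pass for soft labels such as [0.2, 0.3, 0.5], since the derivation only relies on the labels summing to 1.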

Story not finished

In TensorFlow, the gradient computation for the op tf.nn.softmax_cross_entropy_with_logits is implemented here.

def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))

Why does it add another layer of complexity to what we discussed above? The answer lies in the chain rule again.

In the forward pass, the gradient of the loss w.r.t. the logits is already computed and stored as the op's second output, op.outputs[1]. So the op actually takes two inputs (logits and labels) and returns two outputs (the loss and the gradient of the loss w.r.t. the logits). Gradients arriving at either output therefore need to be backpropagated into the logits. Writing \cdot for whatever downstream quantity is being differentiated, grad_loss and grad_grad can be understood as
$$
GL=\frac{\partial \cdot}{\partial L}, GG =\frac{\partial \cdot}{\partial \frac{\partial L}{\partial z_i}}
$$

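To obtain the gradient of that downstream quantity w.r.t. the logits, the chain rule has to account for both outputs. Writing GG_j for the component of GG that belongs to \partial L / \partial z_j,
$$
\frac{\partial \cdot}{\partial z_i} = GL\cdot\frac{\partial L}{\partial z_i}+\sum_j GG_j \cdot \frac{\partial ^2 L}{\partial z_j \partial z_i}, \quad \frac{\partial ^2 L}{\partial z_j \partial z_i} = \frac{\partial a_j}{\partial z_i}=a_j(\delta_{ij}-a_i)
$$
For i = j the second derivative reduces to (1-a_i)a_i. The sum simplifies to a_i(GG_i - \sum_j GG_j a_j), which is exactly the additional term appended by grad += ....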
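To connect the formula to the `grad +=` line in the quoted code, the sketch below compares an explicit vector-Jacobian product with the closed form that the code appears to compute for a single example, namely `(grad_grad - <grad_grad, softmax>) * softmax`. It is plain Python rather than TensorFlow, works on one example instead of a batch, and every helper name and sample value is hypothetical.

```python
import math


def softmax(logits):
    shift = max(logits)  # shift for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian(probs):
    """jac[j][i] = d a_j / d z_i = a_j * (delta_ij - a_i)."""
    n = len(probs)
    return [[probs[j] * ((1.0 if i == j else 0.0) - probs[i]) for i in range(n)]
            for j in range(n)]


def vjp_explicit(grad_grad, probs):
    """sum_j GG_j * d a_j / d z_i, using the full Jacobian."""
    jac = softmax_jacobian(probs)
    n = len(probs)
    return [sum(grad_grad[j] * jac[j][i] for j in range(n)) for i in range(n)]


def vjp_closed_form(grad_grad, probs):
    """Per-example analogue of (grad_grad - matmul(grad_grad, softmax)) * softmax."""
    dot = sum(g * a for g, a in zip(grad_grad, probs))
    return [(g - dot) * a for g, a in zip(grad_grad, probs)]


probs = softmax([1.0, 2.0, 3.0])
grad_grad = [0.3, -0.7, 1.1]  # arbitrary incoming gradient for the second output
explicit = vjp_explicit(grad_grad, probs)
closed = vjp_closed_form(grad_grad, probs)
max_gap = max(abs(e - c) for e, c in zip(explicit, closed))
print(max_gap)  # should be at rounding-error level
```

The closed form avoids building the full Jacobian, which is presumably why the TensorFlow code uses a single matmul followed by an elementwise product.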

Conclusion

In real applications, grad_grad may not even be used or passed in, but it is worth knowing the details so that no case is overlooked.

  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.
  • Copyrights © 2021 Shysie