First, let’s look at the variables used in the equations.
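Following the paper’s notation:
zi: a logit of the distilled (student) model
vi: the corresponding logit of the cumbersome (teacher) model
qi: the soft probability the distilled model assigns to class i
pi: the soft target probability produced by the teacher for class i
T: the temperature used in the softmax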
Now, let’s look at the equations introduced in Section 2.1, “Matching logits is a special case of distillation”.
The cross-entropy loss used for distillation is as follows:
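Writing C for the cross-entropy, as the paper does:
C = −∑i pi log qi
where qi and pi are computed from the logits zi and vi with a softmax at temperature T.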
The gradient with respect to each logit zi of the distilled model is given by equation (2) of the paper:
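∂C/∂zi = (1/T) (qi − pi) = (1/T) ( exp(zi/T) / ∑j exp(zj/T) − exp(vi/T) / ∑j exp(vj/T) )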
But how did we end up here? Let's derive it step by step.
Step 1: Take the log of qi
Since
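qi = exp(zi/T) / ∑j exp(zj/T)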
we have:
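log qi = zi/T − log ∑j exp(zj/T)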
Step 2: Compute the partial derivative of log qk w.r.t. zi
For any class k:
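∂ log qk/∂zi = ∂(zk/T)/∂zi − ∂/∂zi [ log ∑j exp(zj/T) ]
The first term depends on whether k equals i; the second term is the same for every k and is handled in Step 3.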
Case 1: When k = i:
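∂(zi/T)/∂zi = 1/T, so ∂ log qi/∂zi = 1/T − ∂/∂zi [ log ∑j exp(zj/T) ]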
Case 2: When k ≠ i:
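∂(zk/T)/∂zi = 0 (zk does not depend on zi), so ∂ log qk/∂zi = − ∂/∂zi [ log ∑j exp(zj/T) ]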
Step 3: Compute the derivative of the log-partition function
Let Z = ∑j exp(zj/T). Then:
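∂ log Z/∂zi = (1/Z) · ∂Z/∂zi = (1/Z) · (1/T) exp(zi/T) = qi/T
since exp(zi/T)/Z = qi by definition.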
Step 4: Substitute back
For k = i:
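∂ log qi/∂zi = 1/T − qi/T = (1 − qi)/T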
For k ≠ i:
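∂ log qk/∂zi = 0 − qi/T = −qi/T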
Step 5: Compute the full gradient
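Using C = −∑k pk log qk and the results above:
∂C/∂zi = −∑k pk · ∂ log qk/∂zi
= −pi · (1 − qi)/T − ∑k≠i pk · (−qi/T)
= (1/T) ( −pi + pi qi + qi ∑k≠i pk )
= (1/T) ( qi ∑k pk − pi )
= (1/T) (qi − pi)
where the last line uses ∑k pk = 1.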
This gives us equation (2):
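∂C/∂zi = (1/T) ( exp(zi/T) / ∑j exp(zj/T) − exp(vi/T) / ∑j exp(vj/T) )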
Let’s focus on equation (3) from the paper. It’s given by:
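∂C/∂zi ≈ (1/T) ( (1 + zi/T) / (N + ∑j zj/T) − (1 + vi/T) / (N + ∑j vj/T) )
where N is the number of classes.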
I didn’t really know how to proceed with this part at first. With a little bit of digging, I understood how the authors got from equation (2) to equation (3). It turns out to be pretty simple.
When T is large (the high-temperature regime), we can approximate the exponentials using a Taylor series. (This didn’t strike me at first glance, so sorry if it seems obvious.)
Step 1: Taylor expansion of exponentials
For small x,
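exp(x) = 1 + x + x^2/2! + x^3/3! + …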
When T is large, zi/T is small. Hence, we can keep only the first-order (linear) term, so:
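exp(zi/T) ≈ 1 + zi/T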
Step 2: Apply to softmax probabilities
The denominator becomes:
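∑j exp(zj/T) ≈ ∑j (1 + zj/T) = N + ∑j zj/T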
Therefore:
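qi = exp(zi/T) / ∑j exp(zj/T) ≈ (1 + zi/T) / (N + ∑j zj/T)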
Similarly, for the teacher model:
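pi = exp(vi/T) / ∑j exp(vj/T) ≈ (1 + vi/T) / (N + ∑j vj/T)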
Step 3: Substitute into the gradient to get equation (3):
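∂C/∂zi = (1/T) (qi − pi) ≈ (1/T) ( (1 + zi/T) / (N + ∑j zj/T) − (1 + vi/T) / (N + ∑j vj/T) )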
Now, the last and final equation, (4). The paper assumes that “the logits have been zero-meaned separately for each example”, so:
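∑j zj = 0 and ∑j vj = 0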
Substituting the above into equation (3):
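∂C/∂zi ≈ (1/T) ( (1 + zi/T) / N − (1 + vi/T) / N ) = (1/(N T)) ( zi/T − vi/T )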
This is equation (4):
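∂C/∂zi ≈ (1/(N T^2)) (zi − vi)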
One last thing, which is mentioned in the paper, is:
So in the high temperature limit, distillation is equivalent to minimizing 1/2 (zi − vi)^2, provided the logits are zero-meaned separately for each example.
The L2 (squared) loss on the logits and its gradient are:
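Calling the squared loss L here (just a label for this comparison):
L = 1/2 (zi − vi)^2
∂L/∂zi = zi − vi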
Comparing with equation (4), we get:
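∂C/∂zi ≈ (1/(N T^2)) (zi − vi) = (1/(N T^2)) · ∂L/∂zi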
This means that in the high-temperature limit, with logits zero-meaned separately for each example, minimizing the cross-entropy distillation loss is equivalent, up to the constant factor 1/(N T^2), to minimizing the L2 loss between the student and teacher logits.
If you’ve made it this far, I hope this walkthrough made things a little clearer and a little less intimidating. There is another part of the paper that involves KL divergence, which I will cover later.
References: