Illustration of Knowledge Distillation

Mathematical Derivation

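First, let’s look at the variables used in the equations. z_i and v_i are the logits for class i produced by the distilled (student) model and the cumbersome (teacher) model; q_i and p_i are the corresponding softmax probabilities at temperature T (the p_i are the soft targets); N is the number of classes; and C is the cross-entropy between p and q.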

Now, let’s look at the equations introduced in Section 2.1 of the paper, “Matching logits is a special case of distillation”.

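The cross-entropy between the teacher’s soft targets p and the student’s softened outputs q is:

$$C = -\sum_k p_k \log q_k, \qquad q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}, \qquad p_i = \frac{\exp(v_i/T)}{\sum_j \exp(v_j/T)}$$

Its gradient with respect to the student logit z_i is given by equation (2) of the paper:

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$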

But how did we end up here? Let's derive it step by step.

Step 1: Take the log of q_i

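Since

$$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$$

we have:

$$\log q_i = \frac{z_i}{T} - \log \sum_j \exp(z_j/T)$$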

Step 2: Compute the partial derivatives of log q_k w.r.t. z_i

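For any class k:

$$\frac{\partial \log q_k}{\partial z_i} = \frac{\partial}{\partial z_i}\left(\frac{z_k}{T}\right) - \frac{\partial}{\partial z_i} \log \sum_j \exp(z_j/T)$$

Case 1: When k = i:

$$\frac{\partial \log q_i}{\partial z_i} = \frac{1}{T} - \frac{\partial}{\partial z_i} \log \sum_j \exp(z_j/T)$$

Case 2: When k ≠ i (z_k does not depend on z_i, so the first term vanishes):

$$\frac{\partial \log q_k}{\partial z_i} = -\frac{\partial}{\partial z_i} \log \sum_j \exp(z_j/T)$$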

Step 3: Compute the derivative of the log-partition function

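Let Z = ∑_j exp(z_j/T). Then:

$$\frac{\partial \log Z}{\partial z_i} = \frac{1}{Z} \cdot \frac{1}{T} \exp(z_i/T) = \frac{q_i}{T}$$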

Step 4: Substitute back

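For k = i:

$$\frac{\partial \log q_i}{\partial z_i} = \frac{1}{T} - \frac{q_i}{T} = \frac{1}{T}\left(1 - q_i\right)$$

For k ≠ i:

$$\frac{\partial \log q_k}{\partial z_i} = -\frac{q_i}{T}$$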
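The two cases can be summarised as ∂ log q_k / ∂ z_i = (δ_ki − q_i) / T, where δ_ki is 1 when k = i and 0 otherwise. The sketch below is one way to sanity-check that result with central finite differences. The logits, the temperature, and the helper names (`log_softmax`, `analytic_jacobian`, `numeric_jacobian`) are illustrative assumptions, not something taken from the paper.

```python
import math

def log_softmax(logits, T):
    """log q_k = z_k/T - log sum_j exp(z_j/T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]

def analytic_jacobian(logits, T):
    """Entry [k][i] is d log q_k / d z_i = (delta_ki - q_i) / T."""
    q = [math.exp(lq) for lq in log_softmax(logits, T)]
    n = len(logits)
    return [[((1.0 if k == i else 0.0) - q[i]) / T for i in range(n)]
            for k in range(n)]

def numeric_jacobian(logits, T, eps=1e-6):
    """Central finite-difference estimate of the same Jacobian."""
    n = len(logits)
    jac = [[0.0] * n for _ in range(n)]
    for i in range(n):
        plus = list(logits)
        minus = list(logits)
        plus[i] += eps
        minus[i] -= eps
        lp, lm = log_softmax(plus, T), log_softmax(minus, T)
        for k in range(n):
            jac[k][i] = (lp[k] - lm[k]) / (2 * eps)
    return jac

z = [1.0, -0.5, 2.0]  # hypothetical student logits
T = 2.0               # hypothetical temperature
analytic = analytic_jacobian(z, T)
numeric = numeric_jacobian(z, T)
max_diff = max(abs(analytic[k][i] - numeric[k][i])
               for k in range(len(z)) for i in range(len(z)))
print(max_diff)  # should be tiny if the derivation is right
```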

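Step 5: Compute the full gradient

$$\frac{\partial C}{\partial z_i} = -\sum_k p_k \frac{\partial \log q_k}{\partial z_i} = -\frac{p_i}{T}\left(1 - q_i\right) + \sum_{k \neq i} p_k \frac{q_i}{T} = \frac{1}{T}\left(q_i \sum_k p_k - p_i\right)$$

Since the teacher probabilities sum to 1, this gives us equation (2):

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right)$$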
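If you would rather check equation (2) numerically than trust the algebra, the following is a minimal sketch that compares (q_i − p_i) / T against a finite-difference gradient of the cross-entropy. The logit values, the temperature, and the function names are made up for illustration. The teacher targets are held fixed while the student logits are perturbed.

```python
import math

def softmax(logits, T):
    """Temperature-scaled softmax: exp(z_i/T) / sum_j exp(z_j/T)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(z, p, T):
    """C = -sum_k p_k log q_k, with q = softmax(z, T)."""
    q = softmax(z, T)
    return -sum(pk * math.log(qk) for pk, qk in zip(p, q))

def analytic_grad(z, v, T):
    """Equation (2): dC/dz_i = (q_i - p_i) / T."""
    q = softmax(z, T)
    p = softmax(v, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def numeric_grad(z, v, T, eps=1e-6):
    """Central finite differences of C with respect to each z_i."""
    p = softmax(v, T)  # teacher soft targets do not depend on z
    grads = []
    for i in range(len(z)):
        z_plus = list(z)
        z_minus = list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        diff = cross_entropy(z_plus, p, T) - cross_entropy(z_minus, p, T)
        grads.append(diff / (2 * eps))
    return grads

z = [1.0, -0.5, 2.0, 0.3]  # hypothetical student logits
v = [0.8, 0.1, 2.5, -1.0]  # hypothetical teacher logits
T = 3.0                    # hypothetical temperature
grad_analytic = analytic_grad(z, v, T)
grad_numeric = numeric_grad(z, v, T)
max_diff = max(abs(a - n) for a, n in zip(grad_analytic, grad_numeric))
print(max_diff)  # should be tiny if equation (2) holds
```

Because q and p each sum to 1, the components of the gradient also sum to zero, which is a quick extra check on any implementation.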
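Now let’s focus on equation (3) from the paper. It’s given by:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$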

I didn’t really know how to proceed with this part of the equation. With a little bit of digging, I understood how the authors reached equation (3) from equation (2). Turns out it was pretty simple.

When T is large (the high-temperature regime), we can approximate the exponentials with a first-order Taylor series. (This didn’t strike me at first glance. Sorry if you think this is obvious.)

Step 1: Taylor expansion of exponentials

For small x,

$$e^x = 1 + x + \frac{x^2}{2!} + \dots \approx 1 + x$$

When T is large, z_i/T is small, so we can keep only the first-order (linear) term:

$$\exp(z_i/T) \approx 1 + \frac{z_i}{T}$$

Step 2: Apply to softmax probabilities

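The denominator becomes:

$$\sum_j \exp(z_j/T) \approx \sum_j \left(1 + \frac{z_j}{T}\right) = N + \sum_j \frac{z_j}{T}$$

where N is the number of classes. Therefore:

$$q_i \approx \frac{1 + z_i/T}{N + \sum_j z_j/T}$$

Similarly, for the teacher model:

$$p_i \approx \frac{1 + v_i/T}{N + \sum_j v_j/T}$$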
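To get a feel for how good the first-order approximation of the softmax is, here is a small sketch that compares the exact probabilities with (1 + z_i/T) / (N + Σ_j z_j/T) at a few temperatures. The logits and the temperature grid are arbitrary choices for illustration, and `linearized_softmax` is a name I made up for it.

```python
import math

def softmax(logits, T):
    """Exact temperature-scaled softmax."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def linearized_softmax(logits, T):
    """First-order approximation: (1 + z_i/T) / (N + sum_j z_j/T)."""
    n = len(logits)
    denom = n + sum(logits) / T
    return [(1 + z / T) / denom for z in logits]

z = [1.0, -0.5, 2.0, 0.3]  # hypothetical logits
errors = {}
for T in (1.0, 10.0, 100.0):
    exact = softmax(z, T)
    approx = linearized_softmax(z, T)
    # Largest absolute gap between exact and approximate probabilities
    errors[T] = max(abs(a - b) for a, b in zip(exact, approx))
    print(T, errors[T])
```

The gap should shrink as T grows, which is the reason the approximation is only claimed for the high-temperature regime.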

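Step 3: Substitute into the gradient to get equation (3)

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$

Now for the final equation, (4). The paper assumes “the logits have been zero-meaned separately for each example”, so:

$$\sum_j z_j = \sum_j v_j = 0$$

Substituting this into equation (3):

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N} - \frac{1 + v_i/T}{N}\right)$$

which simplifies to equation (4):

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}\left(z_i - v_i\right)$$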
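Equation (4) is an approximation, so it is worth seeing how quickly it becomes accurate. The sketch below zero-means two sets of made-up logits and compares the exact gradient from equation (2) with (z_i − v_i) / (N T²) at increasing temperatures. All values and helper names here are assumptions for illustration only.

```python
import math

def softmax(logits, T):
    """Exact temperature-scaled softmax."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def zero_mean(x):
    """Subtract the mean so the logits sum to zero."""
    mu = sum(x) / len(x)
    return [xi - mu for xi in x]

def exact_grad(z, v, T):
    """Equation (2): (q_i - p_i) / T."""
    q = softmax(z, T)
    p = softmax(v, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def high_temp_grad(z, v, T):
    """Equation (4): (z_i - v_i) / (N * T^2), valid for zero-meaned logits."""
    n = len(z)
    return [(zi - vi) / (n * T * T) for zi, vi in zip(z, v)]

z = zero_mean([1.0, -0.5, 2.0, 0.3])  # hypothetical student logits
v = zero_mean([0.8, 0.1, 2.5, -1.0])  # hypothetical teacher logits
rel_errors = {}
for T in (1.0, 10.0, 100.0):
    g = exact_grad(z, v, T)
    a = high_temp_grad(z, v, T)
    # Largest absolute error, relative to the largest approximate component
    rel_errors[T] = (max(abs(gi - ai) for gi, ai in zip(g, a))
                     / max(abs(ai) for ai in a))
    print(T, rel_errors[T])
```

At low temperature the two gradients can differ a lot; the relative error should fall as T increases, consistent with the claim holding only in the high-temperature limit.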

One last thing, which the paper points out, is:

“So in the high temperature limit, distillation is equivalent to minimizing 1/2 (z_i − v_i)^2, provided the logits are zero-meaned separately for each example.”

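The L2 (squared) loss and its gradient are:

$$L = \frac{1}{2}\left(z_i - v_i\right)^2, \qquad \frac{\partial L}{\partial z_i} = z_i - v_i$$

Comparing with equation (4), we get:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2} \frac{\partial L}{\partial z_i}$$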

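This means that in the high-temperature limit with zero-meaned logits, the gradient of the cross-entropy distillation loss is the gradient of the L2 loss between the logits, scaled by the constant 1/(NT^2). Minimizing one is therefore equivalent to minimizing the other.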

If you’ve made it this far, I hope this walkthrough made things a little clearer and a little less intimidating. There is another part to the paper that involves KL divergence, which I will include later.

References:

  1. Distilling the Knowledge in a Neural Network https://arxiv.org/pdf/1503.02531.pdf