Everyone's AI

Chapter 07

Chain Rule: Unraveling Composite Functions, the Heart of Backprop

When you differentiate a function nested inside another, multiply the outer derivative by the inner derivative. That is the core of backpropagation.


A nested function is a chain $x$ → inner → outer → $y$. Multiply the outer derivative by the inner derivative to get the total derivative.

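[Figure: two graphs. Left (inner derivative): inner function $u = g(x) = 2x+1$ with $g'(x) = 2$, passing through $(1, 3)$. Right (outer derivative): outer function $y = f(u) = u^2$ with $f'(u) = 2u$, passing through $(3, 9)$.]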

Example: calculation order

1. Example: as in the graphs above, $u = g(x) = 2x+1$ and $y = f(u) = u^2$, so $y = (2x+1)^2$. Differentiate with respect to $x$.
2. ① Inner derivative (left graph): $u = g(x) = 2x+1$ → derivative w.r.t. $x$ is $2$.
3. ② Outer derivative (right graph): $y = f(u) = u^2$ → derivative w.r.t. $u$ is $2u = 2(2x+1)$.
4. ③ Multiply: $2 \times 2(2x+1) = 4(2x+1)$ → answer.
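The three steps above can be written as a minimal Python sketch. The helper names `g`, `g_prime`, `f_prime` and `dy_dx` are hypothetical, chosen only for this illustration:

```python
# Chain rule for y = (2x + 1)^2, following steps ① to ③.

def g(x):
    """Inner function u = g(x) = 2x + 1."""
    return 2 * x + 1

def g_prime(x):
    """① Inner derivative du/dx = 2 (constant)."""
    return 2

def f_prime(u):
    """② Outer derivative dy/du = 2u, for y = f(u) = u^2."""
    return 2 * u

def dy_dx(x):
    u = g(x)                         # evaluate the inner function first
    return f_prime(u) * g_prime(x)   # ③ outer derivative (at u) × inner derivative

print(dy_dx(1))  # 12, matching 4(2·1 + 1)
```

Other points work the same way: `dy_dx(0)` gives `4`, matching $4(2 \cdot 0 + 1)$.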

Moving along the chain, the rates multiply at each link. Backprop works the same way: it multiplies the local derivative at each step.

What is the chain rule?

The chain rule is the rule for differentiating composite functions, that is, functions nested inside other functions. It works like peeling an onion: differentiate the outer function ($f'$, evaluated at the inner function) and multiply by the derivative of the inner function ($g'$). In symbols: $\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$. It is like finding the overall gear ratio of meshed gears, where the ratio of each pair multiplies into the total.
Intuition: You ($x$) push a friend ($u$), and the friend pushes a cart ($y$). If each unit of your push moves your friend 2 units, and each unit of your friend's push moves the cart 3 units, then the cart moves $2 \times 3 = 6$ units per unit of your push. The chain rule is this multiplication of rates along the chain.
One-line summary: To differentiate a composite function with respect to $x$, multiply the outer derivative by the inner derivative. See the table below for the steps.
| Step | Task | Example: $y=(2x+1)^2$ |
|---|---|---|
| 1 | Identify inner and outer | Inner $u=2x+1$, outer $y=u^2$ |
| 2 | Outer derivative: differentiate the outer function (keep $u$ as is) | $u^2 \to 2u$ |
| 3 | Inner derivative: differentiate the inner function w.r.t. $x$ | $2x+1 \to 2$ |
| 4 | Multiply | $2u \times 2 = 2(2x+1) \times 2 = 4(2x+1)$ |
Main formula: $\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$, or equivalently $(f \circ g)'(x) = f'(g(x)) \cdot g'(x)$. As in the diagram, the chain runs $x$ → inner → outer → $y$, so multiply the derivative on each segment. If the inner function is itself composite, apply the same rule again: outer derivative × inner derivative, and repeat.
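The "repeat" part of the rule can be sketched in Python as a loop over layers. This is a minimal sketch assuming each layer is given as a hypothetical `(function, derivative)` pair, listed from the innermost function outward; each derivative is evaluated at that layer's own input, and the local rates are multiplied together.

```python
import math

def chain_derivative(layers, x):
    """Derivative of layers[-1](...layers[0](x)...) with respect to x.

    `layers` is a list of (function, derivative) pairs, innermost first.
    Each derivative is evaluated at that layer's own input, and the
    local rates are multiplied: outer derivative × inner derivative, repeated.
    """
    total = 1.0
    value = x
    for fn, fn_prime in layers:
        total *= fn_prime(value)  # local rate of this segment of the chain
        value = fn(value)         # this output is the next layer's input
    return total

# y = (2x + 1)^2 as two layers: u = 2x + 1, then y = u^2
two_layers = [(lambda x: 2 * x + 1, lambda x: 2),
              (lambda u: u ** 2, lambda u: 2 * u)]
print(chain_derivative(two_layers, 1.0))  # 12.0

# Three layers, y = sin((2x + 1)^2): the same rule simply repeats
three_layers = two_layers + [(math.sin, math.cos)]
print(chain_derivative(three_layers, 0.0))  # 4·cos(1) ≈ 2.1612
```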
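Why multiply instead of add? Because these are rates, and rates along a chain combine by multiplication: their units cancel like fractions, as in $\frac{dy}{du} \cdot \frac{du}{dx}$. Adding rates with different units, such as a car's speed of 100 km/h ($v$) and an exchange rate of 1300 won per dollar ($r$), is meaningless. To compute how a change is amplified or damped along the chain, you must multiply.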
Check with numbers: For $y=(2x+1)^2$ at $x=1$, the formula gives $4(2 \cdot 1+1)=12$. If $x$ goes from 1 to 1.01 (change 0.01), $y$ goes from 9 to $3.02^2 = 9.1204$ (change 0.1204). The ratio $0.1204/0.01 = 12.04$ is close to 12, which matches.
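The same check can be run in Python. This sketch compares the chain-rule value with two finite-difference estimates; the step `h = 0.01` mirrors the change used above. The forward difference lands near 12.04 because $\frac{(2(x+h)+1)^2 - (2x+1)^2}{h} = 4(2x+1) + 4h$, so the extra $0.04$ is exactly $4h$; the symmetric (central) difference cancels that term for this quadratic.

```python
def y(x):
    return (2 * x + 1) ** 2

x0, h = 1.0, 0.01
forward = (y(x0 + h) - y(x0)) / h            # slope of the secant from 1 to 1.01
central = (y(x0 + h) - y(x0 - h)) / (2 * h)  # symmetric difference around x0
exact = 4 * (2 * x0 + 1)                     # chain-rule formula

print(forward, central, exact)  # ≈ 12.04, ≈ 12.0, 12.0
```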
Deep learning models are huge composite functions: dozens or hundreds of functions stacked as $y = f_n(\dots f_2(f_1(x))\dots)$. We need to know how the final loss ($L$) changes when we change the input or a weight ($w$) somewhere in the middle. That requires the chain rule.
Backpropagation is exactly the chain rule in action. When we propagate the error from the output layer backward toward the input, we multiply by the derivative at each layer. Without this, training deep networks would be impractical.
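A scalar sketch of the backward pass, under simplifying assumptions: each hypothetical "layer" here is a single-input function with a known derivative, not a real network layer with weights and matrices. The forward pass stores each layer's input; the backward pass starts from $\frac{dL}{dL} = 1$ at the output and multiplies the local derivatives from the last layer back to the first.

```python
import math

# Hypothetical 3-layer chain: x -> u1 = 2x + 1 -> u2 = u1^2 -> L = sin(u2)
layers = [
    (lambda x: 2 * x + 1, lambda x: 2.0),
    (lambda u: u ** 2, lambda u: 2 * u),
    (math.sin, math.cos),
]

def forward(x):
    """Run the layers in order, saving every layer's input for later."""
    inputs = []
    for fn, _ in layers:
        inputs.append(x)
        x = fn(x)
    return x, inputs

def backward(inputs):
    """Start from dL/dL = 1 and multiply local derivatives,
    moving from the output layer back toward the input."""
    grad = 1.0
    for (_, fn_prime), layer_input in zip(reversed(layers), reversed(inputs)):
        grad *= fn_prime(layer_input)
    return grad

loss, saved = forward(0.5)
print(backward(saved))  # dL/dx at x = 0.5, which equals 8·cos(4)
```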
So training a network comes down to passing derivative values backward and multiplying them (the chain rule). The deeper the network, the more times this multiplication repeats. Multiplying numbers less than 1 (e.g. 0.5) many times drives the product toward 0: $0.5^{10} \approx 0.001$. This vanishing gradient problem was one reason deep networks were hard to train. Techniques like ReLU activations and skip connections help mitigate it.
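The shrinking product is easy to see numerically. In this sketch each layer is assumed to contribute the same local derivative: 0.5 (the example above), 0.25 (the largest value the sigmoid's derivative can take), and 1.0 (ReLU's derivative for positive inputs). The function name `chain_product` is hypothetical.

```python
def chain_product(local_derivative, depth):
    """Product of `depth` identical local derivatives along a chain."""
    product = 1.0
    for _ in range(depth):
        product *= local_derivative
    return product

# Columns: depth, 0.5 per layer, 0.25 per layer (sigmoid's maximum), 1.0 per layer (ReLU, positive input)
for depth in (1, 5, 10, 20):
    print(depth,
          chain_product(0.5, depth),
          chain_product(0.25, depth),
          chain_product(1.0, depth))
```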
The chain rule is also used to analyze complex cause-and-effect chains. If A affects B and B affects C, the effect of A on C is found by multiplying the effect at each step.
| Situation | What we find | Chain rule (total rate) |
|---|---|---|
| Cost → output → time | Effect of time on cost | (cost/output) × (output/time) |
| Volume → radius → time | How fast the volume changes when blowing up a balloon | (volume/radius) × (radius/time) |
| Error → output → weight | AI learning: how much to update the weight | (error/output) × (output/weight) |
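The last two rows can be sketched in Python. The numbers (a 5 cm radius growing at 0.1 cm/s, and a one-weight model `output = w * x` with squared error) are illustrative assumptions, not values from the text, and the function names are hypothetical.

```python
import math

# Balloon row: V = (4/3)·π·r³, so dV/dr = 4·π·r².
# Assumed numbers: radius 5 cm, growing at 0.1 cm/s.
def balloon_volume_rate(r, dr_dt):
    dV_dr = 4 * math.pi * r ** 2   # volume/radius
    return dV_dr * dr_dt           # (volume/radius) × (radius/time)

print(balloon_volume_rate(5.0, 0.1))  # 10π ≈ 31.42 cm³/s

# Weight row: assumed toy model output = w·x with error E = (output - target)²
def weight_gradient(w, x, target):
    output = w * x
    dE_doutput = 2 * (output - target)  # error/output
    doutput_dw = x                      # output/weight
    return dE_doutput * doutput_dw      # (error/output) × (output/weight)

print(weight_gradient(w=0.5, x=2.0, target=3.0))  # 2·(1 - 3)·2 = -8.0
```

The weight gradient here is negative, so a gradient-descent step $w \leftarrow w - \eta \cdot \frac{dE}{dw}$ (with a positive learning rate $\eta$) would increase $w$, moving the output from 1 toward the target 3.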
Automatic differentiation: Frameworks like PyTorch or TensorFlow compute derivatives automatically; in PyTorch this happens when we call `loss.backward()`. Internally, they record a computation graph and apply the chain rule at each node, multiplying local gradients as they move backward through the graph.
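To make the "graph plus chain rule at each node" idea concrete, here is a toy reverse-mode autodiff in plain Python. It is a simplified sketch of the general technique, not how PyTorch or TensorFlow are actually implemented; the `Value` class and its methods are hypothetical. Each operation records its inputs and a small function that multiplies the upstream gradient by the local derivative, and `backward()` visits the nodes in reverse topological order.

```python
class Value:
    """A scalar node in a computation graph (toy reverse-mode autodiff)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # nodes this value was computed from
        self._backward = lambda: None    # pushes self.grad to the parents

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def _backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1, times the upstream gradient
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a, times the upstream gradient
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    __rmul__ = __mul__

    def __pow__(self, n):
        # n is assumed to be a plain number, not a Value
        out = Value(self.data ** n, (self,))
        def _backward():
            # chain rule: d(u^n) = n·u^(n-1) · du
            self.grad += n * self.data ** (n - 1) * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topological order, so each node runs after every node that depends on it
        order, seen = [], set()
        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0                  # d(output)/d(output) = 1
        for node in reversed(order):
            node._backward()

x = Value(1.0)
y = (2 * x + 1) ** 2
y.backward()
print(y.data, x.grad)  # 9.0 12.0
```

In PyTorch, the analogous check would create `x = torch.tensor(1.0, requires_grad=True)`, compute `y = (2 * x + 1) ** 2`, call `y.backward()`, and read `x.grad`, which should also be 12.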
For a nested function, treat the inner part as one block, then multiply the derivative of the outer function (with respect to that block) by the derivative of the inner part. If the inner part is itself nested, repeat. Tip: Set $u$ = the inner expression, differentiate the outer function in terms of $u$ only, then multiply by the derivative of $u$ with respect to $x$.
Simplest example: $y=(3x)^2$. Inner $u=3x$ → derivative $3$. Outer $u^2$ → derivative $2u=2\cdot 3x$. Product: $3 \times 2\cdot 3x = 18x$. At $x=2$ the slope is $36$.
The table below lists examples, from easy to more varied. In each row, multiply the inner and outer derivatives to get the answer.
| Level | Problem | Solution |
|---|---|---|
| Easy | $y=(3x)^2$ | Inner $u=3x$ → inner deriv $3$; outer $u^2$ → outer deriv $2u$; product $2\cdot 3x\cdot 3=18x$ |
| Easy | $y=\sqrt{x+1}$ | Inner $u=x+1$ → inner deriv $1$; outer $\sqrt{u}$ → outer deriv $1/(2\sqrt{u})$; product $1/(2\sqrt{x+1})$ |
| Example | $y=(2x+1)^5$ | Inner deriv $2$, outer deriv $5(2x+1)^4$ → product $10(2x+1)^4$ |
| Example | $y=e^{x^2}$ | Inner deriv $2x$, outer deriv $e^{x^2}$ → product $2x\,e^{x^2}$ |
| Example | $y=\sin(2x)$ | Inner $u=2x$ → inner deriv $2$; outer $\sin u$ → outer deriv $\cos u$; product $2\cos(2x)$ |
| Example | $y=e^{3x}$ | Inner deriv $3$, outer deriv $e^{3x}$ → product $3e^{3x}$ |
| Example | $y=\ln(\sin x)$ | Inner deriv $\cos x$, outer deriv $1/\sin x$ → product $\cos x/\sin x=\cot x$ |
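Every row of the table can be checked numerically. This sketch compares each table answer with a central-difference estimate at a sample point; the point $x = 0.7$ and the step `h = 1e-5` are arbitrary choices, with $x$ picked so that $x+1>0$ and $\sin x>0$.

```python
import math

# (label, function, derivative from the table)
table_rows = [
    ("(3x)^2",    lambda x: (3 * x) ** 2,           lambda x: 18 * x),
    ("sqrt(x+1)", lambda x: math.sqrt(x + 1),       lambda x: 1 / (2 * math.sqrt(x + 1))),
    ("(2x+1)^5",  lambda x: (2 * x + 1) ** 5,       lambda x: 10 * (2 * x + 1) ** 4),
    ("e^(x^2)",   lambda x: math.exp(x ** 2),       lambda x: 2 * x * math.exp(x ** 2)),
    ("sin(2x)",   lambda x: math.sin(2 * x),        lambda x: 2 * math.cos(2 * x)),
    ("e^(3x)",    lambda x: math.exp(3 * x),        lambda x: 3 * math.exp(3 * x)),
    ("ln(sin x)", lambda x: math.log(math.sin(x)),  lambda x: math.cos(x) / math.sin(x)),
]

def central_difference(f, x, h=1e-5):
    """Numerical slope of f at x from a small symmetric step."""
    return (f(x + h) - f(x - h)) / (2 * h)

x0 = 0.7  # inside every domain: x + 1 > 0 and sin x > 0
for label, f, f_prime in table_rows:
    numeric = central_difference(f, x0)
    print(f"{label:10s} table: {f_prime(x0):.6f}  numeric: {numeric:.6f}")
```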
Problem types and how to solve them
| Type | Form | How to differentiate |
|---|---|---|
| Power | $(g(x))^n$ | Outer deriv $nu^{n-1}$ × inner deriv $g'(x)$ → $n(g(x))^{n-1} \cdot g'(x)$ |
| Exponential | $e^{g(x)}$ | Outer deriv $e^u$ × inner deriv → $e^{g(x)} \cdot g'(x)$ |
| Trig | $\sin(g(x))$, $\cos(g(x))$ | Outer deriv ($\cos u$ or $-\sin u$) × inner deriv → $\cos(g(x)) \cdot g'(x)$ or $-\sin(g(x)) \cdot g'(x)$ |
| Root | $\sqrt{g(x)}$ | Outer deriv $1/(2\sqrt{u})$ × inner deriv → $g'(x)/(2\sqrt{g(x)})$ |
| Log | $\ln(g(x))$ | Outer deriv $1/u$ × inner deriv → $g'(x)/g(x)$ |
| Quadratic inside | $(ax^2+bx+c)^n$, etc. | Inner deriv $2ax+b$; multiply by the outer deriv → $n(ax^2+bx+c)^{n-1}(2ax+b)$ |
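The table's rules can be sketched as Python helpers that each take an inner function $g$ and its derivative $g'$ and return the composite's derivative as a new function. The helper names and the quadratic coefficients $a=1$, $b=2$, $c=3$, $n=3$ are hypothetical choices for illustration.

```python
import math

# Each helper takes the inner function g and its derivative g_prime and
# returns the derivative of the composite as a function of x.

def power_rule(n, g, g_prime):
    return lambda x: n * g(x) ** (n - 1) * g_prime(x)

def exp_rule(g, g_prime):
    return lambda x: math.exp(g(x)) * g_prime(x)

def sin_rule(g, g_prime):
    return lambda x: math.cos(g(x)) * g_prime(x)

def cos_rule(g, g_prime):
    return lambda x: -math.sin(g(x)) * g_prime(x)

def sqrt_rule(g, g_prime):
    return lambda x: g_prime(x) / (2 * math.sqrt(g(x)))

def log_rule(g, g_prime):
    return lambda x: g_prime(x) / g(x)

# "Quadratic inside": (a x^2 + b x + c)^n with hypothetical a=1, b=2, c=3, n=3
a, b, c, n = 1, 2, 3, 3

def quad(x):
    return a * x ** 2 + b * x + c

def quad_prime(x):
    return 2 * a * x + b          # inner derivative 2ax + b

d_power = power_rule(n, quad, quad_prime)
print(d_power(1))  # 3 · 6² · 4 = 432
```

The same helpers reproduce the worked examples that follow, e.g. `power_rule(2, lambda x: 3 * x, lambda x: 3)(2)` gives `36`.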

Example (power)
For $y=(3x)^2$, find the derivative at $x=2$.
Solution
$y'=2\cdot 3x \cdot 3=18x$. At $x=2$ → $36$. → Answer 36

Example (exponential)
For $y=e^{3x}$, find the derivative at $x=0$.
Solution
$y'=3e^{3x}$. At $x=0$ → $3e^0=3$. → Answer 3

Example (trig)
For $y=\sin(2x)$, find the derivative at $x=0$.
Solution
$y'=2\cos(2x)$. At $x=0$ → $2\cos 0=2$. → Answer 2

Example (log)
For $y=\ln(\sin x)$, find the derivative at $x=\pi/2$.
Solution
$y'=\frac{\cos x}{\sin x}=\cot x$. At $x=\pi/2$, $\cos(\pi/2)=0$, so $y'=0$. → Answer 0