Computational graphs and automatic differentiation
Understanding automatic differentiation through computational graphs
ML
Author
Vannsh Jani
Published
September 26, 2023
Computational graphs
A computational graph is a way to represent a complex mathematical expression in a simpler form by introducing intermediate variables. Given input values, the computational graph can calculate the final value of the expression by traversing the graph in the forward direction. It also provides a systematic way to compute derivatives of the final output with respect to the inputs and the intermediate variables by traversing the graph in the reverse direction. Hence, computational graphs are useful for calculating the gradient of a function and also help us understand automatic differentiation.
Let’s better understand computational graphs with an example. Suppose we want to evaluate the expression \(\text f(\text x,\text y)=\log(\text x^2 +\sin(3\text x \text y))=\text J\). Here we are considering the natural logarithm. We can create the intermediate variables as follows.
\[\begin{equation}\begin{split} \text x,\text y &= \hspace{0.1cm} \text{inputs} \\ \text u &= \hspace{0.1cm} 3\text x\text y \\ \text v &= \hspace{0.1cm} \sin u \\ \text z &= \hspace{0.1cm} \text v+\text x^2 \\ \text J &= \hspace{0.1cm} \log z \\ \text J &= \hspace{0.1cm} \text{output} \end{split}\end{equation}\]
Following is the computational graph of the above function \(\text f\).
Let’s take an example and compute the output by traversing the graph in the forward direction. We will take \(\text x =\hspace{0.1cm} 2 \hspace{0.1cm} \text{and} \hspace{0.1cm} \text y = \hspace{0.1cm}3\).
So \(\text f(2,3)= 1.178\), rounded off to three decimal places. We can verify the same by traversing the computational graph.
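As a quick check, here is a minimal sketch of the same forward traversal in Python using NumPy; the variable names mirror the intermediate variables above.

import numpy as np

# Forward traversal of the graph at x = 2, y = 3
x, y = 2.0, 3.0
u = 3 * x * y        # 18.0
v = np.sin(u)        # sin(18) is approximately -0.751
z = v + x ** 2       # approximately 3.249
J = np.log(z)        # approximately 1.178

print(round(J, 3))   # 1.178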
Computational graphs can be used to calculate derivatives of the final output with respect to the inputs by combining, via the chain rule, the derivatives of the output with respect to the intermediate variables and the derivatives of the intermediate variables with respect to the inputs. This is how the gradient of the output is computed in the backpropagation algorithm used to train neural networks.
We can interpret \(\frac{\partial \text f}{\partial \text x}\) as how much \(\text f\) changes relative to a small change in \(\text x\), where \(\text f\) is a function of \(\text x\) and possibly more variables. For our example, let’s take the small change in \(\text x\) to be \(\Delta \text x=0.00001\).
So, for example, if we change \(\text x\) by \(\Delta \text x\) and \(\text f\) changes by, say, \(0.00004\), then \(\text f\) has changed \(4\) times as much as \(\text x\), and hence the derivative of \(\text f\) with respect to \(\text x\) is \(4\).
For our above example, let’s try to compute the gradient of \(\text f\), i.e. \(\nabla \text f = \hspace{0.1cm} \begin{bmatrix} \frac{\partial \text f}{\partial \text x} \\ \frac{\partial \text f}{\partial \text y} \end{bmatrix}\).
Let’s change \(\text z\) by the same small amount \(0.00001\), so \(\text z\) becomes \(3.24901\). Then \(\text J\) changes by about \(3.07\times10^{-6}\), and hence the ratio of the change in \(\text J\) to the change in \(\text z\) is approximately \(0.307\), which is the derivative of \(\text J\) with respect to \(\text z\). Up to rounding, this value equals \(\frac{1}{\text z}\).
Similarly, we can calculate \(\frac{\partial \text z}{\partial \text v}\), which comes out to be \(1\) since \(\text z\) increases by the same amount as \(\text v\), and \(\frac{\partial \text v}{\partial \text u}\), which comes out to be \(0.66\), which is \(\cos u\).
In the above graph, the red labels indicate the derivative of the variable in the right box with respect to the variable in the left box. Hence, we know that \(\frac{\partial \text J}{\partial \text z}\approx 0.307\), \(\frac{\partial \text z}{\partial \text v}= 1\) and \(\frac{\partial \text v}{\partial \text u}\approx 0.66\).
If we change \(\text y\) by a small amount \(0.00001\), i.e. from \(3\) to \(3.00001\), then \(\text u\) changes by \(0.00006\), and hence the derivative of \(\text u\) with respect to \(\text y\) is equal to \(6\), which is \(3\text x\). Similarly, the derivative of \(\text u\) with respect to \(\text x\) is \(9\), which is \(3\text y\). The local derivative of \(\text z\) with respect to \(\text x\) (keeping \(\text v\) fixed, since \(\text z=\text v+\text x^2\)) is equal to \(4\): if \(\text x\) changes by \(0.00001\) then \(\text z\) changes by \(0.00004\). These values are rounded off to two or three decimal places.
Combining these local derivatives with the chain rule, \(\frac{\partial \text J}{\partial \text y}=\frac{\partial \text J}{\partial \text z}\frac{\partial \text z}{\partial \text v}\frac{\partial \text v}{\partial \text u}\frac{\partial \text u}{\partial \text y}\) and \(\frac{\partial \text J}{\partial \text x}=\frac{\partial \text J}{\partial \text z}\left(\frac{\partial \text z}{\partial \text v}\frac{\partial \text v}{\partial \text u}\frac{\partial \text u}{\partial \text x}+\frac{\partial \text z}{\partial \text x}\right)\). Hence, if \(\text x=2\) and \(\text y=3\), then \(\frac{\partial \text J}{\partial \text y}\approx 1.22\) and \(\frac{\partial \text J}{\partial \text x}\approx 3.06\), rounded off to two decimal places (multiplying the rounded intermediate values above gives \(3.05\); the small difference is a rounding artifact).
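As a rough numerical sanity check, here is a minimal finite-difference sketch of the same two partial derivatives, using the same small step \(0.00001\); the printed values may differ slightly in the last digit because of the finite step size.

import numpy as np

def f(x, y):
    return np.log(x ** 2 + np.sin(3 * x * y))

x, y, dx = 2.0, 3.0, 1e-5

# Forward-difference approximations of the partial derivatives at (2, 3)
df_dx = (f(x + dx, y) - f(x, y)) / dx
df_dy = (f(x, y + dx) - f(x, y)) / dx

print(round(df_dx, 2), round(df_dy, 2))   # approximately 3.06 and 1.22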
Automatic differentiation in Python
When we train neural networks for practical applications, they sometimes have millions of parameters, and we need to calculate the derivative of the cost function with respect to each of these parameters with the help of intermediate variables and the chain rule, as seen above. Automatic differentiation lets a computer program calculate the gradient of the cost function efficiently.
There are two main modes of automatic differentiation, namely:
Forward mode
Reverse mode
Forward mode
In forward mode we calculate the values of the intermediate variables, also called the primals, as we traverse the graph, while simultaneously calculating the derivatives (tangents) of the primals with respect to one input variable. We have to run a separate forward pass for each input variable.
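Concretely, for our example a single forward pass seeded with \(\dot{\text x}=1,\hspace{0.1cm}\dot{\text y}=0\) (differentiating with respect to \(\text x\)) propagates the tangents (written with a dot) alongside the primals:

\[\begin{equation}\begin{split} \dot{\text u} &= \hspace{0.1cm} 3(\dot{\text x}\text y+\text x\dot{\text y}) \\ \dot{\text v} &= \hspace{0.1cm} \dot{\text u}\cos\text u \\ \dot{\text z} &= \hspace{0.1cm} \dot{\text v}+2\text x\dot{\text x} \\ \dot{\text J} &= \hspace{0.1cm} \frac{\dot{\text z}}{\text z} \end{split}\end{equation}\]

At \(\text x=2,\hspace{0.1cm}\text y=3\) this gives \(\dot{\text J}\approx 3.06\); a second pass seeded with \(\dot{\text x}=0,\hspace{0.1cm}\dot{\text y}=1\) gives \(\dot{\text J}\approx 1.22\).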
Let’s perform forward-mode automatic differentiation in Python from scratch.
Code
import numpy as np

class Pair:
    # A Pair carries a primal value and its derivative (tangent) with respect
    # to the input variable named in `var`.
    def __init__(self, val, der, var):
        self.val = val
        self.der = der
        self.var = var

    def __add__(self, other):
        if isinstance(other, Pair):
            return Pair(self.val + other.val, self.der + other.der, self.var)
        else:
            return Pair(self.val + other, self.der, self.var)

    def __sub__(self, other):
        if isinstance(other, Pair):
            return Pair(self.val - other.val, self.der - other.der, self.var)
        else:
            return Pair(self.val - other, self.der, self.var)

    def __mul__(self, other):
        if isinstance(other, Pair):
            # product rule
            return Pair(self.val * other.val,
                        self.der * other.val + self.val * other.der, self.var)
        else:
            return Pair(self.val * other, self.der * other, self.var)

    def __truediv__(self, other):
        if isinstance(other, Pair):
            # quotient rule
            return Pair(self.val / other.val,
                        (self.der * other.val - self.val * other.der) / (other.val ** 2),
                        self.var)
        else:
            return Pair(self.val / other, self.der / other, self.var)

    def sin(self):
        return Pair(np.sin(self.val), self.der * np.cos(self.val), self.var)

    def log(self):
        return Pair(np.log(self.val), self.der / self.val, self.var)

    def __pow__(self, power):
        return Pair(self.val ** power, power * self.val ** (power - 1) * self.der, self.var)

    def __repr__(self):
        return (f"Pair of function value and derivative with respect to "
                f"{self.var} is ({round(self.val, 3)}, {round(self.der, 2)})")

pairx = Pair(2, 1, "x")  # seed dx/dx = 1
pairy = Pair(3, 1, "y")  # seed dy/dy = 1

def forward_pass_x(pairx, fv):
    # fv is the fixed value of y while differentiating with respect to x
    u = (pairx * fv) * 3
    v = u.sin()
    z = v + pairx ** 2
    J = z.log()
    print(J)

def forward_pass_y(pairy, fv):
    # fv is the fixed value of x while differentiating with respect to y
    u = (pairy * fv) * 3
    v = u.sin()
    z = v + fv ** 2
    J = z.log()
    print(J)

forward_pass_x(pairx, 3)
forward_pass_y(pairy, 2)
Pair of function value and derivative with respect to x is (1.178, 3.06)
Pair of function value and derivative with respect to y is (1.178, 1.22)
Reverse mode
Reverse mode autodiff is the same process as what we did in the computational-graph example above. We forward pass through the computational graph calculating the final output, and then we reverse pass through the graph, calculating the derivative of the final output with respect to the intermediate variables and then using the chain rule to calculate the derivative of the final output with respect to the input variables. Unlike forward mode autodiff, we can calculate the gradient of the output function in just one reverse pass through the computational graph.
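For our example, the reverse sweep propagates adjoints (derivatives of \(\text J\) with respect to each variable, written with a bar) from the output back to the inputs:

\[\begin{equation}\begin{split} \bar{\text z} &= \hspace{0.1cm} \frac{1}{\text z} \\ \bar{\text v} &= \hspace{0.1cm} \bar{\text z} \\ \bar{\text u} &= \hspace{0.1cm} \bar{\text v}\cos\text u \\ \bar{\text x} &= \hspace{0.1cm} 3\text y\bar{\text u}+2\text x\bar{\text z} \\ \bar{\text y} &= \hspace{0.1cm} 3\text x\bar{\text u} \end{split}\end{equation}\]

At \(\text x=2,\hspace{0.1cm}\text y=3\) this single sweep yields \(\bar{\text x}\approx 3.06\) and \(\bar{\text y}\approx 1.22\), matching the values computed earlier.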
Let’s implement reverse mode automatic differentiation from scratch and then verify it using PyTorch.
Code
import numpy as np

class Scalar:
    # Each Scalar stores its value, an accumulated gradient, and a backward
    # closure that pushes its gradient to the nodes it was computed from.
    def __init__(self, val):
        self.val = val
        self.grad = 0.
        self.backward = lambda: None

    def __repr__(self):
        return f"Value: {self.val}, Gradient: {self.grad}"

    def __add__(self, other):
        if isinstance(other, Scalar):
            result = Scalar(self.val + other.val)
            def backward():
                self.grad += result.grad
                other.grad += result.grad
                self.backward()
                other.backward()
            result.backward = backward
        else:
            result = Scalar(self.val + other)
            def backward():
                self.grad += result.grad
                self.backward()
            result.backward = backward
        return result

    def __mul__(self, other):
        if isinstance(other, Scalar):
            result = Scalar(self.val * other.val)
            def backward():
                self.grad += other.val * result.grad
                other.grad += self.val * result.grad
                self.backward()
                other.backward()
            result.backward = backward
        else:
            result = Scalar(self.val * other)
            def backward():
                self.grad += other * result.grad
                self.backward()
            result.backward = backward
        return result

    def sin(self):
        result = Scalar(np.sin(self.val))
        def backward():
            self.grad += np.cos(self.val) * result.grad
            self.backward()
        result.backward = backward
        return result

    def log(self):
        result = Scalar(np.log(self.val))
        def backward():
            self.grad += (1 / self.val) * result.grad
            self.backward()
        result.backward = backward
        return result

# Forward pass: build the graph
x = Scalar(2.0)
y = Scalar(3.0)
u = x * y * 3
v = u.sin()
z = v + (x * x)
J = z.log()

# Reverse pass: seed the output gradient and propagate it backwards
J.grad = 1.
J.backward()
print("Rounding off the derivatives to two decimal places")
print("Derivative of J wrt x", round(x.grad, 2))
print("Derivative of J wrt y", round(y.grad, 2))
Rounding off the derivatives to two decimal places
Derivative of J wrt x 3.06
Derivative of J wrt y 1.22
Let’s now use PyTorch.
Code
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

def J(x, y):
    u = 3 * x * y
    v = torch.sin(u)
    z = v + x ** 2
    return torch.log(z), u, v, z

output, u, v, z = J(x, y)

# Derivatives of the output with respect to the intermediate variables
print("derivative of output wrt u is", torch.autograd.grad(output, u, retain_graph=True)[0].item())
print("derivative of output wrt v is", torch.autograd.grad(output, v, retain_graph=True)[0].item())
print("derivative of output wrt z is", torch.autograd.grad(output, z, retain_graph=True)[0].item())
print()
print("output is", round(output.item(), 3), "rounded off to three decimal places")
print("derivative of output wrt x is", round(torch.autograd.grad(output, x, retain_graph=True)[0].item(), 2), "rounded off to two decimal places")
print("derivative of output wrt y is", round(torch.autograd.grad(output, y)[0].item(), 2), "rounded off to two decimal places")
derivative of output wrt u is 0.20323611795902252
derivative of output wrt v is 0.3077858090400696
derivative of output wrt z is 0.3077858090400696
output is 1.178 rounded off to three decimal places
derivative of output wrt x is 3.06 rounded off to two decimal places
derivative of output wrt y is 1.22 rounded off to two decimal places
Automatic differentiation achieves accuracy comparable to symbolic differentiation but is typically faster, because its sole purpose is to compute the numerical value of the derivative by composing primitive operations whose derivatives we already know, for example the derivative of \(\sin x\) is \(\cos x\). We break the function into intermediate variables consisting only of such primitive operations. In symbolic differentiation, on the other hand, the same subexpressions may be evaluated repeatedly, which lowers its efficiency.
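For instance, the symbolic partial derivatives of our example each contain the subexpressions \(3\text x\text y\) and \(\text x^2+\sin(3\text x\text y)\) that already appear in \(\text f\) itself:

\[\frac{\partial \text f}{\partial \text x}=\frac{2\text x+3\text y\cos(3\text x\text y)}{\text x^2+\sin(3\text x\text y)}, \hspace{0.5cm} \frac{\partial \text f}{\partial \text y}=\frac{3\text x\cos(3\text x\text y)}{\text x^2+\sin(3\text x\text y)}\]

A symbolic system that does not share these subexpressions evaluates them again for every partial derivative, whereas autodiff computes each intermediate variable (\(\text u,\text v,\text z\)) once and reuses it.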