Computational graphs and automatic differentiation

Understanding automatic differentiation through computational graphs
ML
Author

Vannsh Jani

Published

September 26, 2023

Computational graphs

A computational graph is a way of representing a complex mathematical expression in simpler form by introducing intermediate variables that break the expression down. Given input values, a computational graph can compute the final value of the expression by traversing the graph in the forward direction. It also provides a systematic way to compute derivatives of the final output with respect to the inputs and the intermediate variables by traversing the graph in the reverse direction. Hence, computational graphs are useful for calculating the gradient of a function and also help us understand automatic differentiation.

Let’s understand computational graphs better with an example. Suppose we want to evaluate the expression \(\text f(\text x,\text y)=\log(\text x^2 +\sin(3\text x \text y))=\text J\), where \(\log\) denotes the natural logarithm. We can create the intermediate variables as follows.

\[\begin{equation}\begin{split} \text x,\text y &= \hspace{0.1cm} \text{inputs} \\ \text u &= \hspace{0.1cm} 3\text x\text y \\ \text v &= \hspace{0.1cm} \sin u \\ \text z &= \hspace{0.1cm} \text v+\text x^2 \\ \text J &= \hspace{0.1cm} \log z \\ \text J &= \hspace{0.1cm} \text{output} \end{split}\end{equation}\]

Following is the computational graph of the above function \(\text f\).

Code
import graphviz

graph = graphviz.Graph(graph_attr={'rankdir': 'LR'})

with graph.subgraph() as s:
    s.node('x')
    s.node('u=3xy',shape = 'rectangle')
    s.node('v=sin(u)',shape = 'rectangle')
    s.node('z=v+(x^2)',shape = 'rectangle')
    s.node('J=log(z)',shape = 'rectangle')
    s.node('output')


graph.node('y')

graph.edge('x', 'u=3xy', dir='forward')
graph.edge('y', 'u=3xy', dir='forward')
graph.edge('u=3xy', 'v=sin(u)', dir='forward')
graph.edge('v=sin(u)', 'z=v+(x^2)', dir='forward')
graph.edge('z=v+(x^2)', 'J=log(z)', dir='forward')
graph.edge('J=log(z)','output',dir='forward')
graph.edge('x', 'z=v+(x^2)', dir='forward')


graph

Let’s take an example and compute the output by traversing the graph in the forward direction. We will take \(\text x =\hspace{0.1cm} 2 \hspace{0.1cm} \text{and} \hspace{0.1cm} \text y = \hspace{0.1cm}3\).
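As a quick sanity check, here is a minimal sketch (plain NumPy, separate from the graph-drawing code) that evaluates the intermediate variables for these inputs.

Code
import numpy as np

x, y = 2.0, 3.0
u = 3 * x * y        # 18
v = np.sin(u)        # approximately -0.751
z = v + x**2         # approximately 3.249
J = np.log(z)        # approximately 1.178
print(u, round(v, 3), round(z, 3), round(J, 3))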

So \(\text f(2,3)\approx\hspace{0.1cm} 1.178\), rounded off to 3 decimal places. We can verify the same using the computational graph.

Code
graph = graphviz.Graph(graph_attr={'rankdir': 'LR'})

with graph.subgraph() as s:
    s.node('x')
    s.node('u=3xy',shape = 'rectangle')
    s.node('v=sin(u)',shape = 'rectangle')
    s.node('z=v+(x^2)',shape = 'rectangle')
    s.node('J=log(z)',shape = 'rectangle')
    s.node('output')


graph.node('y')

graph.edge('x', 'u=3xy', dir='forward',label='2',fontcolor='green')
graph.edge('y', 'u=3xy', dir='forward',label='3',fontcolor='green')
graph.edge('u=3xy', 'v=sin(u)', dir='forward',label='18',fontcolor='green')
graph.edge('v=sin(u)', 'z=v+(x^2)', dir='forward',label='-0.75',fontcolor='green')
graph.edge('x', 'z=v+(x^2)', dir='forward',label='2',fontcolor='green')
graph.edge('z=v+(x^2)', 'J=log(z)', dir='forward',label='3.249',fontcolor='green')
graph.edge('J=log(z)','output',dir='forward',label='1.178',fontcolor='green')



graph

Computational graphs can also be used to calculate derivatives of the final output with respect to the inputs: by the chain rule, we combine the derivatives of the output with respect to the intermediate variables with the derivatives of the intermediate variables with respect to the inputs. This is exactly what is needed to compute the gradient of the output in the backpropagation algorithm used to train neural networks.

We can interpret \(\frac{\partial \text f}{\partial \text x}\) as the ratio of the change in \(\text f\) to a small change in \(\text x\), where \(\text f\) is a function of \(\text x\) and possibly other variables. For our example let’s take the small change in \(\text x\) to be \(\Delta \text x=\hspace{0.1cm}0.00001\).

So, for example, if we change \(\text x\) by \(\Delta \text x\) and \(\text f\) changes by, say, \(0.00004\), then \(\text f\) has changed four times as much as \(\text x\), and hence the derivative of \(\text f\) with respect to \(\text x\) is \(4\).
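For instance, here is a minimal sketch of this ratio-of-changes idea using an illustrative function of my own, \(g(\text x)=\text x^2\), whose derivative at \(\text x=2\) is exactly \(4\).

Code
def g(x):
    return x**2                      # illustrative function, not part of the running example

dx = 0.00001
ratio = (g(2 + dx) - g(2)) / dx      # change in g divided by change in x
print(round(ratio, 3))               # approximately 4.0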

For our above example let’s try to compute the gradient of \(\text f\) i.e \(\nabla \text f = \hspace{0.1cm} \begin{bmatrix} \frac{\partial \text f}{\partial \text x} \\ \frac{\partial \text f}{\partial \text y} \end{bmatrix}\).

Let’s change \(\text z\) by the same small amount, so \(\text z\) becomes \(3.24901\). Then \(\text J\) changes by about \(3.07\times10^{-6}\), so the ratio of the change in \(\text J\) to the change in \(\text z\) is approximately \(0.307\); this is the derivative of \(\text J\) with respect to \(\text z\). Ignoring the round-offs, this value equals \(\frac{1}{\text z}\).

Hence \(\frac{\partial \text J}{\partial \text z}=\hspace{0.1cm} \frac{1}{\text z}\approx\hspace{0.1cm}{0.307}\)
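We can reproduce this numerically with a small sketch: perturb \(\text z\) by the same small amount and take the ratio of the change in \(\text J\) to the change in \(\text z\) (the exact ratio is about \(0.308\); the graphs below truncate it to \(0.307\)).

Code
import numpy as np

z = 3.249
dz = 0.00001
dJ = np.log(z + dz) - np.log(z)
print(round(dJ / dz, 3))   # approximately 0.308, i.e. 1/z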

Code
graph = graphviz.Graph(graph_attr={'rankdir': 'LR'})

with graph.subgraph() as s:
    s.node('x')
    s.node('u=3xy',shape = 'rectangle')
    s.node('v=sin(u)',shape = 'rectangle')
    s.node('z=v+(x^2)',shape = 'rectangle')
    s.node('J=log(z)',shape = 'rectangle')
    s.node('output')


graph.node('y')

graph.edge('x', 'u=3xy', dir='forward',label='2',fontcolor='green')
graph.edge('y', 'u=3xy', dir='forward',label='3',fontcolor='green')
graph.edge('u=3xy', 'v=sin(u)', dir='forward',label='18',fontcolor='green')
graph.edge('v=sin(u)', 'z=v+(x^2)', dir='forward',label='-0.75',fontcolor='green')
graph.edge('x', 'z=v+(x^2)', dir='forward',label='2',fontcolor='green')
graph.edge('z=v+(x^2)', 'J=log(z)', dir='forward',label='3.249',fontcolor='green')
graph.edge('J=log(z)', 'z=v+(x^2)', dir='forward',label='0.307',fontcolor='red')
graph.edge('J=log(z)','output',dir='forward',label='1.178',fontcolor='green')



graph

Similarly, we can calculate \(\frac{\partial \text z}{\partial \text v}\), which comes out to be \(1\) since \(\text z\) increases by the same amount as \(\text v\), and \(\frac{\partial \text v}{\partial \text u}\), which comes out to be \(0.66\), i.e. \(\cos u\).
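The same perturbation trick recovers these two local derivatives; here is a small sketch reusing the forward-pass values \(\text u=18\) and \(\text v\approx-0.751\) from above.

Code
import numpy as np

u, v, x = 18.0, np.sin(18.0), 2.0
dh = 0.00001

dz_dv = ((v + dh + x**2) - (v + x**2)) / dh   # z = v + x^2, so this ratio is 1
dv_du = (np.sin(u + dh) - np.sin(u)) / dh     # approximately cos(18) = 0.66
print(round(dz_dv, 3), round(dv_du, 3))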

Code
graph = graphviz.Graph(graph_attr={'rankdir': 'LR'})
with graph.subgraph() as s:
    s.node('x')
    s.node('u=3xy',shape = 'rectangle')
    s.node('v=sin(u)',shape = 'rectangle')
    s.node('z=v+(x^2)',shape = 'rectangle')
    s.node('J=log(z)',shape = 'rectangle')
    s.node('output')


graph.node('y')

graph.edge('x', 'u=3xy', dir='forward',label='2',fontcolor='green')
graph.edge('y', 'u=3xy', dir='forward',label='3',fontcolor='green')
graph.edge('u=3xy', 'v=sin(u)', dir='forward',label='18',fontcolor='green')
graph.edge('v=sin(u)', 'z=v+(x^2)', dir='forward',label='-0.75',fontcolor='green')
graph.edge('z=v+(x^2)','v=sin(u)', dir='forward',label='1',fontcolor='red')
graph.edge('x', 'z=v+(x^2)', dir='forward',label='2',fontcolor='green')
# graph.edge('z=v+(x^2)','x', dir='forward',label='25.0074',fontcolor='red')
graph.edge('z=v+(x^2)', 'J=log(z)', dir='forward',label='3.249',fontcolor='green')
graph.edge('J=log(z)', 'z=v+(x^2)', dir='forward',label='0.307',fontcolor='red')
graph.edge('J=log(z)','output',dir='forward',label='1.178',fontcolor='green')
graph.edge('v=sin(u)','u=3xy',dir='forward',label='0.66',fontcolor='red')



graph

In the above graph, the red labels indicate the derivative of the variable in the right box with respect to the variable in the left box. Hence, we know that,

\[\begin{equation}\begin{split} \frac{\partial \text J}{\partial \text z}&=\hspace{0.1cm}0.307 \\ \frac{\partial \text z}{\partial \text v}&=\hspace{0.1cm}1 \\ \frac{\partial \text v}{\partial \text u}&=\hspace{0.1cm}0.66 \end{split}\end{equation}\]

From chain rule, we can say that,

\[\begin{equation}\begin{split}\frac{\partial \text J}{\partial \text v}&=\hspace{0.1cm}\frac{\partial \text J}{\partial \text z}.\frac{\partial \text z}{\partial \text v} \\ \frac{\partial \text J}{\partial \text v} &\approx\hspace{0.1cm} 0.307 \\ \frac{\partial \text J}{\partial \text u}&=\hspace{0.1cm}\frac{\partial \text J}{\partial \text v}.\frac{\partial \text v}{\partial \text u} \\ \frac{\partial \text J}{\partial \text u} &\approx\hspace{0.1cm} 0.203 \end{split}\end{equation} \]
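The same chain-rule arithmetic takes only a couple of lines (a sketch reusing the local derivatives from above; small differences in the last decimal come from rounding).

Code
import numpy as np

z, u = 3.249, 18.0
dJ_dz = 1 / z            # approximately 0.308
dz_dv = 1.0
dv_du = np.cos(u)        # approximately 0.66

dJ_dv = dJ_dz * dz_dv    # approximately 0.308
dJ_du = dJ_dv * dv_du    # approximately 0.203
print(round(dJ_dv, 3), round(dJ_du, 3))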

If we change \(\text y\) by a small amount \(0.00001\), i.e. from \(3\) to \(3.00001\), then \(\text u\) changes by \(0.00006\), so the derivative of \(\text u\) with respect to \(\text y\) equals \(6\), which is \(3\text x\). Similarly, the derivative of \(\text u\) with respect to \(\text x\) is \(9\), which is \(3\text y\). The derivative of \(\text z\) with respect to \(\text x\) (holding \(\text v\) fixed) equals \(4\): if \(\text x\) changes by \(0.00001\), then \(\text z\) changes by \(0.00004\). These values are rounded off to two or three decimal places.

Hence,

Code
graph = graphviz.Graph(graph_attr={'rankdir': 'LR'})
graph.node('x')
graph.node('y')
graph.node('u=3xy',shape = 'rectangle')
graph.node('v=sin(u)',shape = 'rectangle')
graph.node('z=v+(x^2)',shape = 'rectangle')
graph.node('J=log(z)',shape = 'rectangle')
graph.node('output')

graph.edge('z=v+(x^2)','x', dir='forward',label='4',fontcolor='red')
graph.edge('y','u=3xy',dir='forward',label='3',fontcolor='green')
graph.edge('u=3xy','y',dir='forward',label='6',fontcolor='red',constraint='false')
graph.edge('x','u=3xy',dir='forward',label='2',fontcolor='green')
graph.edge('u=3xy','x',dir='forward',label='9',fontcolor='red',constraint='false')
graph.edge('u=3xy','v=sin(u)',dir='forward',label='18',fontcolor='green')
graph.edge('v=sin(u)','z=v+(x^2)',dir='forward',label='-0.75',fontcolor='green')
graph.edge('x','z=v+(x^2)',dir='forward',label='2',fontcolor='green')
graph.edge('z=v+(x^2)','J=log(z)',dir='forward',label='3.249',fontcolor='green')
graph.edge('J=log(z)','output',dir='forward',label='1.178',fontcolor='green')
graph.edge('z=v+(x^2)','v=sin(u)', dir='forward',label='1',fontcolor='red')
graph.edge('J=log(z)', 'z=v+(x^2)', dir='forward',label='0.307',fontcolor='red')
graph.edge('v=sin(u)','u=3xy',dir='forward',label='0.66',fontcolor='red')



graph

\[ \begin{equation}\begin{split} \frac{\partial \text J}{\partial \text y}&=\hspace{0.1cm}\frac{\partial \text J}{\partial \text u}.\frac{\partial \text u}{\partial \text y} \\ \frac{\partial \text J}{\partial \text y}&\approx\hspace{0.1cm}1.22 \\ \frac{\partial \text J}{\partial \text x}&=\hspace{0.1cm}\frac{\partial \text u}{\partial \text x}.\frac{\partial \text J}{\partial \text u} + \frac{\partial \text J}{\partial \text z}.\frac{\partial \text z}{\partial \text x} \\ \frac{\partial \text J}{\partial \text x}&\approx\hspace{0.1cm}3.06 \end{split}\end{equation}\]

Hence, if \(\text x=\hspace{0.1cm}2\) and \(\text y=\hspace{0.1cm}3\), then \(\frac{\partial \text J}{\partial \text y}\approx\hspace{0.1cm}1.22 \hspace{0.2cm} \text{and}\hspace{0.1cm}\frac{\partial \text J}{\partial \text x}\approx\hspace{0.1cm}3.06\), rounded off to two decimal places.
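As a cross-check, here is a short sketch that evaluates the closed-form gradient obtained by differentiating \(\log(\text x^2+\sin(3\text x\text y))\) by hand; it agrees with the values computed through the graph.

Code
import numpy as np

x, y = 2.0, 3.0
z = x**2 + np.sin(3 * x * y)

dJ_dx = (2 * x + 3 * y * np.cos(3 * x * y)) / z   # approximately 3.06
dJ_dy = (3 * x * np.cos(3 * x * y)) / z           # approximately 1.22
print(round(dJ_dx, 2), round(dJ_dy, 2))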

Automatic differentiation in Python

When we train neural networks for practical applications, they often have millions of parameters, and we need to calculate the derivative of the cost function with respect to each of these parameters with the help of intermediate variables and the chain rule, as seen above. Automatic differentiation lets a computer program calculate this gradient efficiently.

There are mainly two types of automatic differentiations, namely:

  1. Forward mode
  2. Reverse mode

Forward mode

In forward mode we calculate the values of the intermediate variables, also called the primals, as we traverse the graph, and simultaneously we calculate the values of their derivatives (tangents) with respect to one input variable. We have to run a separate forward pass for each input variable.

Let’s implement forward-mode automatic differentiation in Python from scratch.

Code
import numpy as np

# A Pair carries the value of an expression (the primal) together with its
# derivative (the tangent) with respect to one chosen input variable.
# Binary operations assume both Pair operands track the same variable.
class Pair:
    def __init__(self, val, der, var):
        self.val = val
        self.der = der
        self.var = var

    def __add__(self, other):
        if isinstance(other, Pair):
            return Pair(self.val + other.val, self.der + other.der, self.var)
        else:
            return Pair(self.val + other, self.der, self.var)

    def __sub__(self, other):
        if isinstance(other, Pair):
            return Pair(self.val - other.val, self.der - other.der, self.var)
        else:
            return Pair(self.val - other, self.der, self.var)

    def __mul__(self, other):
        if isinstance(other, Pair):
            # product rule: (fg)' = f'g + fg'
            return Pair(self.val * other.val, self.der * other.val + self.val * other.der, self.var)
        else:
            return Pair(self.val * other, self.der * other, self.var)

    def __truediv__(self, other):
        if isinstance(other, Pair):
            # quotient rule
            return Pair(self.val / other.val, (self.der * other.val - self.val * other.der) / (other.val ** 2), self.var)
        else:
            return Pair(self.val / other, self.der / other, self.var)

    def sin(self):
        return Pair(np.sin(self.val), self.der * np.cos(self.val), self.var)

    def log(self):
        return Pair(np.log(self.val), self.der / self.val, self.var)

    def __pow__(self, power):
        return Pair(self.val ** power, power * self.val ** (power - 1) * self.der, self.var)

    def __repr__(self):
        return f"Pair of function value and derivative with respect to {self.var} is ({round(self.val,3)}, {round(self.der,2)})"

# seeds: the derivative of x with respect to x (and of y with respect to y) is 1
pairx = Pair(2, 1, "x")
pairy = Pair(3, 1, "y")

# fv is the fixed value of the other input, treated as a constant in this pass
def forward_pass_x(pairx, fv):
    u = (pairx * fv) * 3
    v = u.sin()
    z = v + pairx**2
    J = z.log()
    print(J)

def forward_pass_y(pairy, fv):
    u = (pairy * fv) * 3
    v = u.sin()
    z = v + fv**2          # x is a constant here, so x^2 contributes no derivative
    J = z.log()
    print(J)

forward_pass_x(pairx, 3)
forward_pass_y(pairy, 2)
Pair of function value and derivative with respect to x is (1.178, 3.06)
Pair of function value and derivative with respect to y is (1.178, 1.22)
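Note that we needed two separate forward passes, one seeded with \(\frac{\partial \text x}{\partial \text x}=1\) and one with \(\frac{\partial \text y}{\partial \text y}=1\), to obtain the full gradient. Reverse mode, described next, recovers both derivatives in a single backward pass.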

Reverse mode

Reverse-mode autodiff is the same process we followed in the computational-graph example above. We first make a forward pass through the graph to compute the final output, and then a reverse pass to compute the derivative of the output with respect to the intermediate variables, applying the chain rule to obtain the derivatives with respect to the input variables. Unlike forward-mode autodiff, we can calculate the entire gradient of the output in just one reverse pass through the computational graph.

Let’s implement reverse-mode automatic differentiation from scratch and then verify it using PyTorch.

Code
import numpy as np
class Scalar:
    # Each Scalar node stores a value, an accumulated gradient, and a
    # `backward` closure that pushes result.grad back to the nodes it was
    # built from (the reverse pass through the computational graph).
    def __init__(self, val):
        self.val = val
        self.grad = 0.
        self.backward = lambda: None   # leaf nodes have nothing to propagate
        
    def __repr__(self):
        return f"Value: {self.val}, Gradient: {self.grad}"
    
    def __add__(self, other):
        if isinstance(other,Scalar):
            result = Scalar(self.val + other.val)
            def backward():
                self.grad += result.grad
                other.grad += result.grad
                self.backward()
                other.backward()
            result.backward = backward
        else:
            result = Scalar(self.val+other)
            def backward():
                self.grad += result.grad
                self.backward()
            result.backward = backward
        return result

    def __mul__(self, other):
        if isinstance(other,Scalar):
            result = Scalar(self.val * other.val)
            def backward():
                self.grad += other.val * result.grad
                other.grad += self.val * result.grad
                self.backward()
                other.backward()
            result.backward = backward
        else:
            result = Scalar(self.val*other)
            def backward():
                self.grad += other * result.grad
                self.backward()
            result.backward = backward
        return result

    def sin(self):
        result = Scalar(np.sin(self.val))
        def backward():
            self.grad += np.cos(self.val)*result.grad
            self.backward()
        result.backward = backward
        return result
    def log(self):
        result = Scalar(np.log(self.val))
        def backward():
            self.grad += (1/(self.val))*result.grad
            self.backward()
        result.backward = backward
        return result
x = Scalar(2.0)
y = Scalar(3.0)
u = x*y*3
v = u.sin()
z = v + (x*x)
J = z.log()
J.grad = 1.
J.backward()
print("Rounding off the derivatives to two decimal places")
print("Derivative of J wrt x",round(x.grad,2))
print("Derivative of J wrt y",round(y.grad,2))
Rounding off the derivatives to two decimal places
Derivative of J wrt x 3.06
Derivative of J wrt y 1.22

Let’s now use PyTorch.

Code
import torch

x = torch.tensor(2.0,requires_grad=True)
y = torch.tensor(3.0,requires_grad=True)

def J(x,y):
    u = 3*x*y
    v = torch.sin(u)
    z = v + x**2
    return torch.log(z),u,v,z
output,u,v,z = J(x,y)
print("derivative of output wrt u is",torch.autograd.grad(output,u,retain_graph=True)[0].item())
print("derivative of output wrt v is",torch.autograd.grad(output,v,retain_graph=True)[0].item())
print("derivative of output wrt z is",torch.autograd.grad(output,z,retain_graph=True)[0].item())
print()
print("output is",round(output.item(),3),"rounded off to three decimal places")
print("derivative of output wrt x is",round(torch.autograd.grad(output,x,retain_graph=True)[0].item(),2),"rounded off to two decimal places")
print("derivative of output wrt y is",round(torch.autograd.grad(output,y)[0].item(),2),"rounded off to two decimal places")
derivative of output wrt u is 0.20323611795902252
derivative of output wrt v is 0.3077858090400696
derivative of output wrt z is 0.3077858090400696

output is 1.178 rounded off to three decimal places
derivative of output wrt x is 3.06 rounded off to two decimal places
derivative of output wrt y is 1.22 rounded off to two decimal places

The reason automatic differentiation is faster than, say, symbolic differentiation, while having similar accuracy, is that the sole purpose of autodiff is to calculate the numerical value of the derivative by composing primitive operations whose derivatives we already know, e.g. the derivative of \(\sin x\) is \(\cos x\). We decompose the function into intermediate variables consisting only of such primitive operations. In symbolic differentiation, by contrast, the same sub-expressions may end up being differentiated and evaluated repeatedly, which lowers its efficiency.
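To see the contrast, here is a small sketch using SymPy (which is not used elsewhere in this post) to differentiate the same function symbolically. Symbolic differentiation builds the full expression for the derivative, whereas autodiff only ever propagates numbers through primitive operations.

Code
import sympy as sp

x, y = sp.symbols('x y')
J = sp.log(x**2 + sp.sin(3 * x * y))

dJ_dx = sp.diff(J, x)                              # full symbolic expression for the derivative
print(dJ_dx)
print(round(float(dJ_dx.subs({x: 2, y: 3})), 2))   # approximately 3.06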