Equivariant Neural Networks

In this post, we will talk about the mathematical concept of equivariance and its use in building neural networks that respect symmetry.

What is equivariance?

Given a group $G$ that acts on sets $E$ and $F$, a function $f: E \to F$ is $G$-equivariant if

$$ f(g\cdot e) = g\cdot f(e)\quad \text{for all}\quad g\in G,\ e\in E. \tag{1} $$

Informally, the function $f$ is $G$-equivariant when it preserves the symmetry imposed by the group $G$: acting by $g$ on the input and then applying $f$ gives the same result as applying $f$ first and then acting by $g$ on the output.
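As a concrete sanity check (my own minimal PyTorch sketch, not from any paper), take $G = \mathbb{Z}/8$ acting on $\mathbb{R}^8$ by circular shifts. Circular convolution is then a $G$-equivariant map:

```python
import torch

# G = Z/8 acts on R^8 by circular shifts; circular convolution commutes with it.
f = torch.randn(8)
h = torch.randn(8)

def conv_circ(f, h):
    # Circular convolution computed via the discrete Fourier transform.
    return torch.fft.ifft(torch.fft.fft(f) * torch.fft.fft(h)).real

shift = lambda v: torch.roll(v, 1)  # the action of the generator of Z/8

# f(g.e) = g.f(e): convolving the shifted signal equals shifting the convolution.
assert torch.allclose(conv_circ(shift(f), h), shift(conv_circ(f, h)), atol=1e-5)
```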

Why do we need equivariant neural networks?

A feedforward neural network transforms input data into output data; that is, it is a function from the input space to the output space. We want such a network to be equivariant when the data carry symmetry information that should be preserved through the transformation.

A standard example is image classification. The labels we assign to images are usually rotationally invariant, so if the classifier network is equivariant to rotation, one can potentially save a significant amount of computation and make the model generalize better.

Equivariant kernels for CNNs

To build an equivariant network, we can take the framework of the convolutional neural network and make the kernels/filters equivariant.

Suppose our data are vectors indexed by a set $E$ (or, equivalently, functions from $E$ to $\mathbb{C}$) on which a group $G$ acts. Then $G$ also acts on the input data space $\mathbb{C}^E$ via

$$ (g\cdot f)(x) = f(g\cdot x). \tag{2} $$

To transform the data $G$-equivariantly, we fix a kernel function $h: E\to \mathbb{C}$ together with a pole $\eta\in E$, and define the group convolution

$$ (f *_{G,\eta} h)(x) := \int_G f(g\cdot \eta)\,h(g^{-1}\cdot x)\,d\mu(g), \tag{3} $$

where $\mu$ is the Haar measure of the group $G$.

One can check that the operation

$$ \text{Cov}_{h,\eta}: f\mapsto f *_{G,\eta} h \tag{4} $$

is $G$-equivariant.
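As a quick verification (a one-line computation using the substitution $u = g_0 g$ and the left-invariance of the Haar measure):

$$
\begin{aligned}
\bigl((g_0\cdot f) *_{G,\eta} h\bigr)(x)
&= \int_G f\bigl((g_0 g)\cdot\eta\bigr)\,h(g^{-1}\cdot x)\,d\mu(g)\\
&= \int_G f(u\cdot\eta)\,h\bigl(u^{-1}\cdot(g_0\cdot x)\bigr)\,d\mu(u)
= \bigl(g_0\cdot(f *_{G,\eta} h)\bigr)(x).
\end{aligned}
$$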

$SO(3)$-equivariance with spherical harmonics

Now let's look at an implementation of an $SO(3)$-equivariant network from this paper. Here, we consider spherical functions in $L^2(\mathbb{S}^2)$ as our data. Under a suitable integrability assumption, we can perform the spherical Fourier transform (SFT)

$$
\begin{aligned}
f &= \sum_{\ell=0}^\infty\sum_{m=-\ell}^\ell f^m_\ell Y^m_\ell,\\
f^m_\ell &= \langle f, Y^m_\ell\rangle_{L^2(\mathbb{S}^2)} = \int_{\mathbb{S}^2} f(x)\,\overline{Y^m_\ell}(x)\,d\mathbb{S}^2.
\end{aligned}\tag{5}
$$

Here, $\{Y^m_\ell\}$ are the spherical harmonics, where $\ell$ is the degree of the corresponding homogeneous polynomial and $m$ is the order. Given a kernel function $h\in L^2(\mathbb{S}^2)$ and the north pole $N = (0,0,1)\in \mathbb{S}^2$, we can write (3) as

$$ (f * h)(x) = \int_{SO(3)} f(g\cdot N)\,h(g^{-1}\cdot x)\,d\mu(g). \tag{6} $$

Using the fact that the Haar measure $\mu$ is essentially a multiple of the uniform measure on $\mathbb{S}^2$ (after pushing forward by $g\mapsto g\cdot N$), one can derive

$$ (f*h)^m_\ell = \sqrt{\frac{16\pi^3}{2\ell+1}}\,f^m_\ell\,h^0_\ell. \tag{7} $$

The key observation here is that the only information we need about the kernel $h$ is its order-0 coefficients $\{h^0_\ell\}_{\ell=0}^\infty$.

Implementation of convolution

Suppose that $f$ has bandwidth $b>0$; that is, $f^m_\ell = 0$ for all $\ell\geq b$. Then by (7), $f*h$ also has bandwidth $b$, and thus it suffices to keep track of $(h^0_0,\dots,h^0_{b-1})$.

```python
import math

import torch
import torch.nn as nn

class SO3_Conv(nn.Module):
    def __init__(self, bandwidth):
        super().__init__()
        self.bandwidth = bandwidth
        # Learnable order-0 kernel coefficients (h^0_0, ..., h^0_{b-1}).
        self.h0 = nn.Parameter(torch.empty(bandwidth, 1))
        nn.init.kaiming_uniform_(self.h0)

    def forward(self, x):
        # The SFT/ISFT helpers are assumed to be defined elsewhere; the SFT
        # returns a (b, 2b-1) array of coefficients f^m_ell, degree ell in rows.
        x = spherical_fourier_transform(x, self.bandwidth)
        # Per-degree factor sqrt(16 pi^3 / (2 ell + 1)) from (7).
        two_ell_plus_one = torch.arange(
            1, 2 * self.bandwidth, 2, device=self.h0.device
        ).unsqueeze(1)
        weights = torch.sqrt(16 * math.pi ** 3 / two_ell_plus_one) * self.h0
        return inverse_spherical_fourier_transform(x * weights)
```
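As a usage sketch (with hypothetical shapes: I assume `spherical_fourier_transform` consumes samples on the $2b\times 2b$ equiangular grid of (8) below):

```python
conv = SO3_Conv(bandwidth=16)
f_grid = torch.randn(32, 32)  # samples f(x_{j,k}) on the 2b x 2b grid, b = 16
out = conv(f_grid)            # equivariant up to the sampling error discussed below
```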

Due to the finite bandwidth, we can also compute the Fourier coefficients from only $(2b)^2$ equiangular sample points on $\mathbb{S}^2$:

$$
\begin{aligned}
f^m_\ell &= \frac{\sqrt{2\pi}}{2b}\sum_{j=0}^{2b-1}\sum_{k=0}^{2b-1} w^{(b)}_j\,f(x_{j,k})\,\overline{Y^m_\ell}(x_{j,k}),\\
x_{j,k} &= \bigl(\sin(j\pi/b)\cos(k\pi/b),\ \sin(j\pi/b)\sin(k\pi/b),\ \cos(j\pi/b)\bigr),
\end{aligned}\tag{8}
$$

where $w^{(b)}_j$ are predetermined quadrature weights on $\{x_{j,k}\}$. Hence, to implement $\text{Cov}_{h,N}$, we first compute the coefficients $f^m_\ell$ with (8), and then apply the pointwise product (7) to get the coefficients of $\text{Cov}_{h,N}(f)$.
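For concreteness, here is a minimal NumPy/SciPy sketch of (8) (my own illustration, not the paper's code; it assumes the weights $w^{(b)}_j$ are given and stores $f^m_\ell$ at index $(\ell, m+\ell)$):

```python
import numpy as np
from scipy.special import sph_harm  # sph_harm(m, ell, azimuthal, polar)

def sft(f_grid, w, b):
    # f_grid: (2b, 2b) samples f(x_{j,k}); w: (2b,) quadrature weights w_j.
    j = np.arange(2 * b)
    polar = j * np.pi / b  # theta_j, rows of the grid in (8)
    azim = j * np.pi / b   # phi_k, columns of the grid in (8)
    coeffs = np.zeros((b, 2 * b - 1), dtype=complex)
    for ell in range(b):
        for m in range(-ell, ell + 1):
            Y = sph_harm(m, ell, azim[None, :], polar[:, None])  # (2b, 2b)
            coeffs[ell, m + ell] = (np.sqrt(2 * np.pi) / (2 * b)) * np.sum(
                w[:, None] * f_grid * np.conj(Y)
            )
    return coeffs
```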

Non-linearity

In practice, the nonlinear layer is the standard pointwise operation:

$$ \text{NL}_\sigma: f \mapsto \sigma \circ f. \tag{9} $$

One can easily check that $\text{NL}_\sigma$ is equivariant: $\sigma\circ(g\cdot f) = g\cdot(\sigma\circ f)$, since both sides send $x$ to $\sigma(f(g\cdot x))$.

Warning!
The operation $\text{NL}_\sigma$ does not preserve the bandwidth of the data. In fact, $\text{NL}_\sigma(f)$ can have infinite bandwidth regardless of the bandwidth of $f$. Therefore, computing the Fourier coefficients with (8) after a non-linearity will introduce errors (see the equivariance error analysis below).

Spectral pooling

Here we introduce a pooling layer that acts as a low-pass filter with cutoff frequency $b/2$. In practice, we simply set $f^m_\ell$ to zero for all $b/2<\ell < b$.
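In coefficient form this is a one-liner; a sketch using the same $(b, 2b-1)$ layout as the earlier snippets:

```python
import torch

def spectral_pool(coeffs):
    # coeffs[ell, m + ell] = f^m_ell, shape (b, 2b-1).
    b = coeffs.shape[0]
    out = coeffs.clone()
    out[b // 2 + 1 :] = 0  # zero out all degrees with b/2 < ell < b
    return out
```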

Invariant descriptor

In tasks such as image classification, the output is invariant to $SO(3)$ actions (equivalently, $SO(3)$ acts trivially on the output space). Therefore, we would like the output to be $SO(3)$-invariant. One way to achieve this is to use the following operation to produce an output vector:

$$
\begin{aligned}
\text{Des}: f &\mapsto (\|\mathbf{f}^0\|,\dots,\|\mathbf{f}^{b-1}\|),\\
\mathbf{f}^\ell &= (f^{-\ell}_\ell, f^{-\ell+1}_\ell, \dots, f^\ell_\ell).
\end{aligned}\tag{10}
$$

The fact that each $\|\mathbf{f}^\ell\|$ is $SO(3)$-invariant follows from the fact that the action of $SO(3)$ on $\mathbf{Y}^\ell := \text{span}\{Y^m_\ell \mid |m|\leq \ell\}$ is represented by the Wigner D-matrices, which are unitary.
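A sketch of $\text{Des}$ in the same coefficient layout (row $\ell$ holds $f^{-\ell}_\ell,\dots,f^\ell_\ell$ in its first $2\ell+1$ entries):

```python
import torch

def descriptor(coeffs):
    # Rotation-invariant descriptor (||f^0||, ..., ||f^{b-1}||).
    b = coeffs.shape[0]
    return torch.stack([coeffs[ell, : 2 * ell + 1].norm() for ell in range(b)])
```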

Equivariance error analysis

The non-linearity layers are the only ones that introduce equivariance errors. To see this, we define the distribution

$$ s = \frac{\sqrt{2\pi}}{2b} \sum_{j=0}^{2b-1}\sum_{k=0}^{2b-1} w^{(b)}_j\,\delta_{x_{j,k}}, \tag{11} $$

given the equiangular grid $\{x_{j,k}\}$ from (8).

Note that, given a function (or a nice enough distribution) $f$ in $L^2(\mathbb{S}^2)$, using the sampling formula (8) to approximate its Fourier coefficients implicitly implements an orthogonal projection of the atomic distribution

$$ fs = \frac{\sqrt{2\pi}}{2b} \sum_{j=0}^{2b-1}\sum_{k=0}^{2b-1} w^{(b)}_j\,f(x_{j,k})\,\delta_{x_{j,k}} \tag{12} $$

onto the subspace $\mathbf{Y}^{[b]}:=\text{span}\{Y^m_\ell \mid |m| \leq \ell < b\}$. We can see that the operation $f\mapsto fs$ is not $SO(3)$-equivariant, as

$$
\begin{aligned}
(g\cdot f)s &= \frac{\sqrt{2\pi}}{2b} \sum_{j=0}^{2b-1}\sum_{k=0}^{2b-1} w^{(b)}_j\,f(g\cdot x_{j,k})\,\delta_{x_{j,k}},\\
g\cdot (fs) &= \frac{\sqrt{2\pi}}{2b} \sum_{j=0}^{2b-1}\sum_{k=0}^{2b-1} w^{(b)}_j\,f(x_{j,k})\,\delta_{g^{-1}\cdot x_{j,k}}.
\end{aligned}\tag{13}
$$

If we write $f = f_b + f_r$, where $f_b\in \mathbf{Y}^{[b]}$ and $f_b \perp f_r$, we can see that the operation $\text{Proj}_{\mathbf{Y}^{[b]}}(\,\cdot\, s)$ is only $SO(3)$-equivariant on the set

$$ \{f\in L^2(\mathbb{S}^2) \mid (g\cdot f_r)s - g\cdot (f_r s) \perp \mathbf{Y}^{[b]}\quad \text{for all}\quad g\in SO(3)\}. \tag{14} $$
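To see why only the residual $f_r$ matters: the sampling formula (8) is exact on $\mathbf{Y}^{[b]}$, so $\text{Proj}_{\mathbf{Y}^{[b]}}(f_b s) = f_b$, and the same holds for $g\cdot f_b$, which is again bandlimited. Since the unitary action preserves $\mathbf{Y}^{[b]}$ and hence commutes with the projection,

$$
\text{Proj}_{\mathbf{Y}^{[b]}}\bigl((g\cdot f)s\bigr) - g\cdot\text{Proj}_{\mathbf{Y}^{[b]}}(fs)
= \text{Proj}_{\mathbf{Y}^{[b]}}\bigl((g\cdot f_r)s - g\cdot(f_r s)\bigr),
$$

which vanishes exactly on the set above.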