Torch DE

Torch_DE is a framework for solving differential equations (DEs) using Physics Informed Neural Networks (PINNs)!

Why?

Barrier to Entry

Existing PINN frameworks like NVIDIA Modulus and DeepXDE are very powerful, but it can be difficult to understand what is going on underneath and to make changes to the training loop. Torch_DE is designed so that researchers and beginners can start developing and playing around with PINNs quickly.

From the ground up approach

Torch_DE's approach is to keep the different components highly modular while staying as close as possible to standard PyTorch syntax and workflow. Torch_DE tries to minimise abstraction, working from the ground up rather than top down.

The advantage of this is that users can more easily implement their own custom solutions and combine them into their workflow. For example, because Torch_DE sticks to PyTorch syntax, incorporating tools such as TensorBoard or Weights & Biases (WandB) is significantly easier. You can also reduce boilerplate via PyTorch Lightning!
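
Because the training loop is plain PyTorch, standard logging drops straight in. Here is a minimal TensorBoard sketch; the network, batch, and loss are stand-ins for illustration, not Torch_DE code:

import torch
from torch.utils.tensorboard import SummaryWriter

net = torch.nn.Linear(2, 1)
optimiser = torch.optim.Adam(net.parameters(), lr=1e-3)
writer = SummaryWriter(log_dir='runs/pinn_demo')  # hypothetical run name

for step in range(100):
    x = torch.rand(32, 2)            # stand-in for a Torch_DE batch
    loss = net(x).pow(2).mean()      # stand-in for the PINN loss
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    writer.add_scalar('loss/total', loss.item(), step)  # ordinary PyTorch logging
writer.close()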

Notebook and Report friendly

Torch_DE is currently focused on the creation of PINNs on 2D spatial geometries (with an optional time dimension). 2D geometries are the focus because they can easily be visualised in a notebook or with pyplot, while still having geometric complexity that 1D problems lack (such as holes and sharp corners).

TLDR Features

Installation

Currently, to use Torch_DE, clone the repository:

git clone https://github.com/JohnCSu/torch_DE.git
cd torch_DE
# Install Requirements
python3 -m pip install -r requirements.txt
# Install in editable mode (adds torch_DE to your Python path)
python3 -m pip install -e .

A packaged version will be available in the future once I get things up to standard!

Components

Torch_DE follows a workflow very similar to standard PyTorch:

import torch
from torch_DE import ...
# Define variables and derivatives needed:
input_vars = ['x','t']
output_vars =['u']
derivatives = ['u_x','u_xx','u_xy']

# Define Points, Dataset and Dataloader
points = torch.linspace(0,1,10_000)
dataset = PINN_dataset()
dataset.add_group('col_points',points,batch_size = 2000)
DL = PINN_Dataloader(dataset)

#Define Loss function:
losses = ...

# Define Network PINN and optimizer
net = MLP(...)
PINN = DE_Getter(net = net, derivatives = derivatives)
optimiser = torch.optim.Adam(net.parameters())
optimiser.zero_grad()

for x in DL:
    x = x.to('cuda')
    output = PINN(x)
    loss = losses(output).sum()
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()

Geometry

Torch_DE's geometry module allows you to create simple 2D domains for training PINNs. We pick 2D as it can easily be displayed in a notebook compared to 3D (and 1D is just a line!):

from torch_DE.geometry.shapes import Circle,Rectangle,Domain2D
(xmin,xmax),(ymin,ymax) = (0,1), (0,0.41)
domain = Rectangle(((xmin,ymin),(xmax,ymax) ),'corners')
domain = Domain2D(base = domain)

hole = Circle((0.2,0.2),r = 0.05,num_points= 512)
domain.remove(hole,names= ['Cylinder'])

domain.plot()

Defining Derivatives

Torch_DE uses single-letter strings to define input and output variables:

input_vars = ['x','t']
output_vars =['u']

Only single-letter variable names are currently supported. For derivatives, we use subscript notation, e.g. $u_x$:

derivatives = ['u_x','u_xx','u_xy']
  • u_x is $u_x = \frac{\partial u}{\partial x}$
  • u_xx is $u_{xx} = \frac{\partial^2 u}{\partial x^2}$
  • u_xy is $u_{xy} = \frac{\partial^2 u}{\partial x \partial y}$

And so on!
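
As a plain-Python illustration of this convention (not Torch_DE's internal parser), a derivative string is just the output variable, an underscore, then one character per differentiation:

name = 'u_xy'
output_var, wrt = name.split('_')
print(output_var)  # 'u'
print(list(wrt))   # ['x', 'y'] -> differentiate once w.r.t. x, then once w.r.t. y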

Continuous

Torch_DE currently operates on the classic continuous mesh-free PINN of Raissi et al. At the heart of this project is the DE_Getter object, which turns extracting derivatives from networks into looking up dictionary strings rather than juggling indices! Give DE_Getter a network, the input and output variables, and the derivatives you want, and it will do the rest! Here's a simple 1D spring example:

import torch
import torch.nn as nn
from torch_DE.continuous import DE_Getter

# Solving the Spring Equation u_tt = -u with u(0) = 0 and u_t(0) = 1

net = nn.Sequential(nn.Linear(1,200),nn.Tanh(),nn.Linear(200,1))
PINN = DE_Getter(net = net)
PINN.set_vars(input_vars= ['t'], output_vars= ['u'])
PINN.set_derivatives(derivatives=['u_t','u_tt'])
#Torch DE assumes first dimension is the batch dimension
t = torch.linspace(0,2*torch.pi,100).unsqueeze(dim =-1)

optimizer = torch.optim.Adam(params = net.parameters(), lr = 1e-3)
# Training Loop
for epoch in range(5000):
    #Calculate Derivatives
    out = PINN.calculate(t)
    #Output for DE_Getter is a dict of dictionaries containing tensors
    out = out['all']
    #Spring Equation is u_tt + u = 0
    residual = (out['u_tt'] + out['u']).pow(2).mean()

    #Data Fitting term
    data = out['u'][0].pow(2).mean() + (out['u_t'][0] - 1).pow(2).mean()

    loss = data + residual
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
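
Since the exact solution of $u_{tt} = -u$ with $u(0) = 0$ and $u_t(0) = 1$ is $u(t) = \sin(t)$, a quick sanity check after training (assuming the loop above has converged) is:

with torch.no_grad():
    u_pred = net(t)
    max_error = (u_pred - torch.sin(t)).abs().max()
    print(f'Max abs error vs sin(t): {max_error:.3e}')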

Combined with the other components, PINN training becomes much easier to maintain and more efficient.

Currently, you can extract derivatives using autodiff or via finite differences!
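
As a rough illustration of the finite-difference idea (generic central differences on a uniform grid, not Torch_DE's actual API):

import torch

x = torch.linspace(0, 1, 101)
h = x[1] - x[0]
u = torch.sin(2 * torch.pi * x)               # some field sampled on the grid
u_x = (u[2:] - u[:-2]) / (2 * h)              # central difference at interior points
u_xx = (u[2:] - 2 * u[1:-1] + u[:-2]) / h**2  # second derivative at interior points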

Backend

Torch_DE uses the torch.func module and vmap to extract derivatives, so multi-GPU setups and models with buffers (such as batch norm) may not currently work as intended. Using torch.func has the following advantages (a sketch of the approach is shown after the list below):

  • Functorch obtains all the derivatives of a given order (i.e. the Jacobian) at once
  • Any derivative can be obtained
  • No need to mess with torch.autograd, which means no:
    • Setting the input data's requires_grad_ = True
    • Setting retain_graph=True and/or create_graph=True
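
For readers unfamiliar with torch.func, the sketch below shows the general idea of getting full per-sample Jacobians with jacrev and vmap. This is illustrative only and not Torch_DE's internal code:

import torch
from torch.func import functional_call, jacrev, vmap

net = torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
params = dict(net.named_parameters())

def f(x):
    # x is a single sample of shape (1,); functional_call keeps the computation functional
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

x = torch.linspace(0, 1, 100).unsqueeze(-1)  # (100, 1) batch
u = vmap(f)(x)                               # (100, 1) outputs
u_x = vmap(jacrev(f))(x)                     # (100, 1, 1) first derivatives
u_xx = vmap(jacrev(jacrev(f)))(x)            # (100, 1, 1, 1) second derivatives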

Visualization

Plotting is always an annoying task, so a lot of work is currently being done on visualization in Torch_DE. Stay tuned for updates!

Licence

MIT
