diff --git a/README.md b/README.md
index 611f9f3..6287c19 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ This repository features two parts:
 - Jax-cfd's `funcutils.trajectory` function supports tracking only a single field variable (vorticity or velocity); computing and tracking extra fields, such as time derivatives and the PDE residual $R(\boldsymbol{v}):=\boldsymbol{f}-\partial_t \boldsymbol{v}-(\boldsymbol{v}\cdot\nabla)\boldsymbol{v} + \nu \Delta \boldsymbol{v}$, is made easier here.
 - All ops take the batch dimension of tensors into consideration, rather than operating on a single trajectory.
 - Neural Operator-Assisted Navier-Stokes Equations solver.
-  - The **Spatiotempoeral Fourier Neural Operator** (SFNO) that is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), inspiration drawn from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator).
+  - The **Spatiotemporal Fourier Neural Operator** (SFNO), a spacetime tensor-to-tensor (or trajectory-to-trajectory) learner, available in the [`sfno` directory](sfno/). Inspiration is drawn from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator).
 - Data generation for the meta-example of the isotropic turbulence with energy spectra matching the inverse cascade of Kolmogorov flow in a periodic box. Ref: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43.
 - Pipelines for the *a posteriori* error estimation to fine-tune the SFNO to reach the scientific-computing level of accuracy ($\le 10^{-6}$) in the Bochner norm, using FLOPs on par with a single evaluation and only a fraction of the FLOPs of a single `.backward()`.
 - Example files will be added later after cleanup.
@@ -21,6 +21,7 @@ If one wants to play with the neural operator part, it is recommended to clone t
 ## Data
 
 The data are available at https://huggingface.co/datasets/scaomath/navier-stokes-dataset
+Data generation instructions are available in the [SFNO folder](/sfno/).
 
 ## Examples
diff --git a/fno/data/data_gen_FNO.py b/fno/data/data_gen_FNO.py
deleted file mode 100644
index 0980792..0000000
--- a/fno/data/data_gen_FNO.py
+++ /dev/null
@@ -1,194 +0,0 @@
-import argparse
-import math
-import os
-from functools import partial
-
-import torch
-import torch.fft as fft
-import torch.nn.functional as F
-from .grf import GRF2d
-from .solvers import *
-from .data_gen import *
-
-
-def main():
-    """
-    Generate the original FNO data
-    the right hand side is a fixed forcing
-    0.1*(torch.sin(2*math.pi*(x+y))+torch.cos(2*math.pi*(x+y)))
-
-    It stores data after each batch, and will resume using a fixed formula'd seed
-    when starting again.
-    The default values of the params for the Gaussian Random Field (GRF) are printed.
- - Sample usage: - - - Generate new data with 256 grid size with double and extra variables: - > python ./src/data_gen_FNO.py --sample-size 8 --batch-size 8 --grid-size 256 --double --extra-vars --time 40 --num-steps 100 --dt 1e-3 - - """ - current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm") - log_name = "".join(os.path.basename(__file__).split(".")[:-1]) - - log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log") - logger = get_logger(log_filename) - - args = get_args() - cuda = not args.no_cuda and torch.cuda.is_available() - device = torch.device("cuda" if cuda else "cpu") - logger.info(f"Using device: {device}") - logger.info(f"Using the following arguments: ") - all_args = {k: v for k, v in vars(args).items() if not callable(v)} - logger.info("\n".join(f"{k}={v}" for k, v in all_args.items())) - - n_grid_max = 2048 - n = args.grid_size # 256 - subsample = args.subsample # 4 - ns = n // subsample - diam = args.diam # 1.0 - diam = eval(diam) if isinstance(diam, str) else diam - if n > n_grid_max: - raise ValueError( - f"Grid size {n} is larger than the maximum allowed {n_grid_max}" - ) - visc = args.visc # 1e-3 - T = args.time # 50 - delta_t = args.dt # 1e-4 - alpha = args.alpha # 2.5 - tau = args.tau # 7 - f = args.forcing # FNO's default sin+cos - dtype = torch.float64 if args.double else torch.float32 - normalize = args.normalize - filename = args.filename - force_rerun = args.force_rerun - replicate_init = args.replicable_init - dealias = not args.no_dealias - torch.set_default_dtype(dtype) - - # Number of solutions to generate - N_samples = args.sample_size # 8 - - # Number of snapshots from solution - record_steps = args.num_steps - - # Batch size - bsz = args.batch_size # 8 - - extra = "_extra" if args.extra_vars else "" - dtype_str = "_fp64" if args.double else "" - if filename is None: - filename = ( - f"ns_data{extra}{dtype_str}" - + f"_N{N_samples}_n{ns}" - + f"_v{visc:.0e}_T{T}" - + f"_alpha{alpha:.1f}_tau{tau:.0f}.pt" - ) - args.filename = filename - filepath = os.path.join(DATA_PATH, filename) - - data_exist = os.path.exists(filepath) - if data_exist: - logger.info(f"\nFile {filename} exists with current data as follows:") - data = torch.load(filepath) - for key, v in data.items(): - if isinstance(v, torch.Tensor): - logger.info(f"{key:<12}", "\t", v.shape) - else: - logger.info(f"{key:<12}", "\t", v) - if len(data[key]) == N_samples: - return - if force_rerun: - logger.info(f"\nRegenerating data and saving in {filename}\n") - else: - logger.info(f"\nGenerating data and saving in {filename}\n") - - # Set up 2d GRF with covariance parameters - # Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha) - # Note that we need alpha > d/2 (here d= 2) - grf = GRF2d( - n=n, - alpha=alpha, - tau=tau, - normalize=normalize, - device=device, - dtype=dtype, - ) - - # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y))) - grid = torch.linspace(0, 1, n + 1, device=device) - grid = grid[0:-1] - - X, Y = torch.meshgrid(grid, grid, indexing="ij") - # FNO's original implementation - # fh = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y))) - fh = f(X, Y) - - if data_exist and not force_rerun: - w_init = [x for x in data["a"]] - w = [x for x in data["u"]] - w_t = [x for x in data["vort_t"]] - psi = [x for x in data["stream"]] - res = [x for x in data["residual"]] - seeds = [x for x in data["seeds"]] - N_existing = len(w0) - else: - w_init = [] - w = [] - w_t = [] - psi = [] - res = [] - seeds = [] - N_existing = 0 - - 
if N_existing >= N_samples: # No need to generate more data - return - - for i in range((N_samples - N_existing) // bsz): - # Sample random fields - seed = args.seed + N_existing + i * bsz - seeds.append(seed) - if replicate_init: - w0 = grf.sample(bsz, n_grid_max, random_state=seed) - w0 = F.interpolate(w0.unsqueeze(1), size=(n, n), mode="nearest") - w0 = w0.squeeze(1) - else: - w0 = grf.sample(bsz, n, random_state=seed) - - result = get_trajectory_imex_crank_nicolson( - w0, - fh, - visc=visc, - T=T, - delta_t=delta_t, - record_steps=record_steps, - diam=diam, - dealias=dealias, - subsample=subsample, - ) - - if not extra: - for key in ["vort_t", "stream", "res"]: - result[key] = torch.empty(0, device="cpu") - - w_init.append(w0) - w.append(result["vorticity"]) - w_t.append(result["vorticity_t"]) - psi.append(result["stream"]) - res.append(result["residual"]) - - results = { - "w0": torch.cat(w_init), - "w": torch.cat(w), - "dwdt": torch.cat(w_t), - "stream": torch.cat(psi), - "residual": torch.cat(res), - "t": result["t_steps"], - "f": fh.cpu(), - "seeds": seeds, - } - torch.save(results, filepath) - return - - -if __name__ == "__main__": - main() diff --git a/fno/README.md b/sfno/README.md similarity index 100% rename from fno/README.md rename to sfno/README.md diff --git a/fno/__init__.py b/sfno/__init__.py similarity index 100% rename from fno/__init__.py rename to sfno/__init__.py diff --git a/fno/data/__init__.py b/sfno/data/__init__.py similarity index 100% rename from fno/data/__init__.py rename to sfno/data/__init__.py diff --git a/fno/data/data_gen.py b/sfno/data/data_gen.py similarity index 71% rename from fno/data/data_gen.py rename to sfno/data/data_gen.py index 2c7078f..192cfaf 100644 --- a/fno/data/data_gen.py +++ b/sfno/data/data_gen.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import xarray from .solvers import * +from torch_cfd.equations import * import os from datetime import datetime @@ -29,6 +30,8 @@ feval = lambda s: eval("lambda x, y:" + s, globals()) +TQDM_ITERS = 200 + class TqdmLoggingHandler(logging.Handler): def __init__(self, level=logging.NOTSET): @@ -57,16 +60,95 @@ def get_logger(filename, tqdm=True): return logging.getLogger() +def interp2d(x, **kwargs): + expand_dims = [None] * (4 - x.ndim) + x = x[*expand_dims, ...] 
+    return F.interpolate(x, **kwargs).squeeze()
+
+
+def get_trajectory_rk4(
+    equation: ImplicitExplicitODE,
+    w0: Array,
+    dt: float,
+    num_steps: int = 1,
+    record_every_steps: int = 1,
+    pbar=False,
+    pbar_desc="generating trajectories using RK4",
+    require_grad=False,
+):
+    """
+    Vorticity is stacked in the time dimension;
+    all inputs and outputs are in the frequency domain.
+    input: w0 (*, n, n)
+    output:
+        vorticity (*, n_t, kx, ky)
+        psi (*, n_t, kx, ky)
+    velocity (*, 2, n_t, kx, ky) can be computed
+    from psi by calling spectral_rot_2d
+    """
+    w_all = []
+    dwdt_all = []
+    res_all = []
+    psi_all = []
+    w = w0
+    tqdm_iters = num_steps if TQDM_ITERS > num_steps else TQDM_ITERS
+    update_iters = num_steps // tqdm_iters
+    with tqdm(total=num_steps, disable=not pbar) as pb:
+        for t_step in range(num_steps):
+            w, dwdt = equation.forward(w, dt=dt)
+            w.requires_grad_(require_grad)
+            dwdt.requires_grad_(require_grad)
+
+            if t_step % update_iters == 0:
+                res = equation.residual(w, dwdt)
+                res_norm = torch.linalg.norm(res).item() / w0.size(-1)
+                res_desc = f" ||L(w) - f||_2/n: {res_norm:.4e}"
+                desc = (
+                    datetime.now().strftime("%d-%b-%Y %H:%M:%S")
+                    + " - "
+                    + pbar_desc
+                    + res_desc
+                )
+                pb.set_description(desc)
+                pb.update(update_iters)
+
+            if t_step % record_every_steps == 0:
+                _, psi = vorticity_to_velocity(equation.grid, w)
+                res = equation.residual(w, dwdt)
+
+                w_, dwdt_, psi, res = [
+                    var.detach().cpu().clone() for var in [w, dwdt, psi, res]
+                ]
+
+                w_all.append(w_)
+                psi_all.append(psi)
+                dwdt_all.append(dwdt_)
+                res_all.append(res)
+
+    result = {
+        var_name: torch.stack(var, dim=-3)
+        for var_name, var in zip(
+            ["vorticity", "stream", "vort_t", "residual"],
+            [w_all, psi_all, dwdt_all, res_all],
+        )
+    }
+    return result
+
+
 def get_trajectory_imex_crank_nicolson(
     w0,
     f,
-    visc,
-    T,
+    visc=1e-3,
+    T=1,
     delta_t=1e-3,
     record_steps=1,
     diam=1,
     dealias=True,
     subsample=1,
+    dtype=None,
+    pbar=True,
     **kwargs,
 ):
     """
@@ -83,11 +165,10 @@
     - vorticity, time derivative of vorticity, streamfunction, residual
     """
     # Grid size - must be power of 2
-    size, device, dtype = w0.size(), w0.device, w0.dtype
-    bsz, n = size[0], size[-1]
-    interp2d = partial(
-        F.interpolate, size=(n // subsample, n // subsample), mode="bilinear"
-    )
+    dtype = w0.dtype if dtype is None else dtype
+    device = w0.device
+    bsz, n = w0.size(0), w0.size(-1)
+    ns = n // subsample
 
     # Maximum frequency
     k_max = math.floor(n / 2.0)
@@ -124,7 +205,7 @@
     kx, ky, lap = kx[None, ...], ky[None, ...], lap[None, ...]
     # Dealiasing mask
-    dealiasing_filter = (
+    dealias_filter = (
         torch.unsqueeze(
             torch.logical_and(
                 torch.abs(kx) <= (2.0 / 3.0) * k_max,
                 torch.abs(ky) <= (2.0 / 3.0) * k_max,
             ),
             0,
         )
     )
 
     # Saving solution and time
-    size = bsz, record_steps, n // subsample, n // subsample
+    size = bsz, record_steps, ns, ns
     vort, vort_t, stream, residual = [
         torch.empty(*size, device="cpu") for _ in range(4)
     ]
@@ -156,14 +237,15 @@
     residualL2 = norm(res, dim=(-1, -2)).mean() / n
 
     desc = (
-        f"enstrophy w: {enstrophy:.4f} \ "
+        datetime.now().strftime("%d-%b-%Y %H:%M:%S")
+        + f" - enstrophy w: {enstrophy:.4f} \ "
         + f"||L(w, psi) - f||_L2: {residualL2:.4e} \ "
     )
 
-    with tqdm(total=total_steps, desc=desc) as pbar:
+    with tqdm(total=total_steps, desc=desc, disable=not pbar) as pb:
         for j in range(total_steps):
-            w_h, _, w_h_t, psi_h, res_h = imex_crank_nicolson_step(
+            w_h, w_h_t, _, psi_h, res_h = imex_crank_nicolson_step(
                 w_h,
                 f_h,
                 visc,
                 delta_t,
                 diam=diam,
                 rfftmesh=(kx, ky),
                 laplacian=lap,
-                dealias_filter=dealiasing_filter,
+                dealias_filter=dealias_filter,
+                dealias=dealias,
                 **kwargs,
             )
 
@@ -197,17 +280,17 @@
                     visc,
                     (kx, ky),
                     lap,
-                    dealiasing_filter,
+                    dealias_filter=dealias_filter,
                     dealias=dealias,
                 )
                 res = fft.irfft2(res_h, s=(n, n)).real
 
                 if subsample > 1:
                     w, w_t, psi, res = (
-                        interp2d(w),
-                        interp2d(w_t),
-                        interp2d(psi),
-                        interp2d(res),
+                        interp2d(w, size=(ns, ns), mode="bilinear"),
+                        interp2d(w_t, size=(ns, ns), mode="bilinear"),
+                        interp2d(psi, size=(ns, ns), mode="bilinear"),
+                        interp2d(res, size=(ns, ns), mode="bilinear"),
                     )
                 # Record solution and time
                 vort[:, c] = w.detach().cpu()
@@ -221,11 +304,12 @@
                 residualL2 = norm(res, dim=(-1, -2)).mean() / n
                 divider = {0: "|", 1: "/", 2: "-", 3: "\\"}
                 desc = (
-                    f"enstrophy w: {enstrophy:.4f} {divider[c%4]} "
+                    datetime.now().strftime("%d-%b-%Y %H:%M:%S")
+                    + f" - enstrophy w: {enstrophy:.4f} {divider[c%4]} "
                     + f" ||L(w, psi) - f||_2: {residualL2:.4e} {divider[c%4]} "
                 )
-                pbar.set_description(desc)
-                pbar.update()
+                pb.set_description(desc)
+                pb.update()
 
     return dict(
         vorticity=vort,
@@ -357,11 +441,25 @@
         metavar="v_max",
         help="the maximum speed in the init velocity field (default: 5)",
     )
+    parser.add_argument(
+        "--filepath",
+        type=str,
+        default=None,
+        metavar="file path",
+        help="path to save the data (default: None)",
+    )
+    parser.add_argument(
+        "--logpath",
+        type=str,
+        default=None,
+        metavar="log path",
+        help="path to save the logs (default: None)",
+    )
     parser.add_argument(
         "--filename",
         type=str,
         default=None,
-        metavar="filename",
+        metavar="file name",
         help="file name for Navier-Stokes data (default: None)",
     )
     parser.add_argument(
@@ -391,6 +489,18 @@
         default=False,
         help="Disable the dealias masking to the nonlinear convection term",
     )
+    parser.add_argument(
+        "--no-tqdm",
+        action="store_true",
+        default=False,
+        help="Disable progress bar for data generation",
+    )
+    parser.add_argument(
+        "--demo-plots",
+        action="store_true",
+        default=False,
+        help="plot several trajectories for the generated data",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -402,6 +512,23 @@
     return parser.parse_args()
 
 
+def save_pickle(data, save_path, append=True):
+    mode = "ab" if append else "wb"
+    with open(save_path, mode) as f:
+        dill.dump(data, f)
+
+
+def load_pickle(load_path, mode="rb"):
+    data = []
+    with open(load_path, mode=mode) as f:
+        try:
+            while True:
+                data.append(dill.load(f))
+        except EOFError:
+            pass
+    return data
+
+
 def pickle_to_pt(data_path, save_path=None):
     """
     convert serialized data from pickle to pytorch pt file
@@ -409,13 +536,14 @@
     https://stackoverflow.com/a/28745948/622119
     """
     save_path = data_path.replace(".pkl", ".pt") if save_path is None else save_path
-    result = []
-    with open(data_path, "rb") as f:
-        while True:
-            try:
-                result.append(dill.load(f))
-            except EOFError:
-                break
+    result = load_pickle(data_path)
 
     data = defaultdict(list)
     for _res in result:
@@ -423,7 +551,10 @@
         data[field].append(value)
 
     for field, value in data.items():
-        data[field] = torch.cat(value)
+        v = torch.cat(value)
+        if v.ndim == 1:  # time steps or seeds
+            v = torch.unique(v)
+        data[field] = v
 
     torch.save(data, data_path)
diff --git a/sfno/data/data_gen_FNO.py b/sfno/data/data_gen_FNO.py
new file mode 100644
index 0000000..4aca97e
--- /dev/null
+++ b/sfno/data/data_gen_FNO.py
@@ -0,0 +1,233 @@
+import argparse
+import math
+import os
+from functools import partial
+
+import torch
+import torch.fft as fft
+import torch.nn.functional as F
+
+from .grf import GRF2d
+from .solvers import *
+from .data_gen import *
+
+
+def main(args):
+    """
+    Generate the original FNO data;
+    the right-hand side is the fixed forcing
+    0.1*(torch.sin(2*math.pi*(x+y))+torch.cos(2*math.pi*(x+y)))
+
+    It stores data after each batch, and will resume using seeds derived
+    from a fixed formula when starting again.
+    The default values of the params for the Gaussian Random Field (GRF) are printed.
+
+    Sample usage:
+
+    - Training data for the paper:
+    >>> python data_gen_FNO.py --sample-size 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3
+
+    - Test data:
+    >>> python data_gen_FNO.py --sample-size 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --replicable-init --seed 42
+
+    - Test data on a finer grid:
+    >>> python data_gen_FNO.py --sample-size 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --replicable-init --seed 42
+
+    """
+
+    current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
+    log_name = "".join(os.path.basename(__file__).split(".")[:-1])
+    logpath = args.logpath if args.logpath is not None else LOG_PATH
+    log_filename = os.path.join(logpath, f"{current_time}_{log_name}.log")
+    logger = get_logger(log_filename)
+
+    cuda = not args.no_cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if cuda else "cpu")
+    logger.info(f"Using device: {device}")
+    logger.info("Using the following arguments:")
+    all_args = {k: v for k, v in vars(args).items() if not callable(v)}
+    logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
+
+    n_grid_max = 2048
+    n = args.grid_size  # 256
+    subsample = args.subsample  # 4
+    ns = n // subsample
+    diam = args.diam  # 1.0
+    diam = eval(diam) if isinstance(diam, str) else diam
+    if n > n_grid_max:
+        raise ValueError(
+            f"Grid size {n} is larger than the maximum allowed {n_grid_max}"
+        )
+    visc = args.visc  # 1e-3
+    T = args.time  # 50
+    T_warmup = args.time_warmup  # 30
+    T_new = T - T_warmup
+    delta_t = args.dt  # 1e-4
+
+    alpha = args.alpha  # 2.5
+    tau = args.tau  # 7
+    f = args.forcing  # FNO's default sin+cos
+    dtype = torch.float64 if args.double else torch.float32
+    normalize = args.normalize
+    filename = args.filename
+    force_rerun = args.force_rerun
+    replicate_init = args.replicable_init
+    dealias = not args.no_dealias
+    pbar = not args.no_tqdm
+    torch.set_default_dtype(dtype)
+
+    # Number of solutions to generate
+    N_samples = args.sample_size  # 8
+
+    # Number of snapshots from solution
+    record_steps = args.num_steps
+
+    # Batch size
+    batch_size = args.batch_size  # 8
+
+    solver_kws = dict(
+        visc=visc, delta_t=delta_t, diam=diam, dealias=dealias, dtype=torch.float64
+    )
+
+    extra = "_extra" if args.extra_vars else ""
+    dtype_str = "_fp64" if args.double else ""
+    if filename is None:
+        filename = (
+            f"fnodata{extra}{dtype_str}_{ns}x{ns}_N{N_samples}"
+            + f"_v{visc:.0e}_T{int(T)}_steps{record_steps}_alpha{alpha:.1f}_tau{tau:.0f}.pt"
+        ).replace("e-0", "e-")
+    args.filename = filename
+
+    filepath = args.filepath if args.filepath is not None else DATA_PATH
+    if not os.path.exists(filepath):
+        os.makedirs(filepath)
+        logger.info(f"Created directory {filepath}")
+    data_filepath = os.path.join(filepath, filename)
+
+    N_existing = 0
+    data_exist = os.path.exists(data_filepath)
+    if data_exist and not force_rerun:
+        logger.info(f"File {filename} exists with current data as follows:")
+        data = torch.load(data_filepath)
+
+        for key, v in data.items():
+            if isinstance(v, torch.Tensor):
+                logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
+            else:
+                logger.info(f"{key:<12} | {v}")
+        N_existing = len(next(iter(data.values())))
+        if N_existing >= N_samples:
+            return
+        N_samples -= N_existing  # generate only the remaining samples
+    else:
+        logger.info(f"Generating data and saving in {filename}")
+
+    # Set up 2d GRF with covariance parameters
+    # Parameters of covariance C = tau^(2*alpha-2) * (-Laplacian + tau^2 I)^(-alpha)
+    # Note that we need alpha > d/2 (here d = 2)
+    grf = GRF2d(
+        n=n,
+        alpha=alpha,
+        tau=tau,
+        normalize=normalize,
+        device=device,
+        dtype=torch.float64,
+    )
+
+    # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
+    grid = torch.linspace(0, 1, n + 1, device=device)
+    grid = grid[0:-1]
+
+    X, Y = torch.meshgrid(grid, grid, indexing="ij")
+    # FNO's original implementation
+    # fh = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))
+    fh = f(X, Y)
+
+    if data_exist and force_rerun:
+        logger.info(f"Force rerun and save data to {data_filepath}")
+        os.remove(data_filepath)
+    elif data_exist:
+        logger.info(f"Appending {N_samples} more samples to {data_filepath}")
+    else:
+        logger.info(f"Save data to {data_filepath}")
+
+    for i, idx in enumerate(range(0, N_samples, batch_size)):
+        logger.info(f"Generate trajectories for the {i+1}-th batch of {N_samples} samples")
+
+        # Sample random fields; seeds follow a fixed formula so that
+        # an interrupted run resumes with fresh, reproducible batches
+        seeds = [args.seed + N_existing + idx + k for k in range(batch_size)]
+        logger.info(f"random states: {seeds[0]} to {seeds[-1]}")
+        n0 = n_grid_max if replicate_init else n
+        w0 = [grf.sample(1, n0, random_state=s) for s in seeds]
+        w0 = torch.stack(w0)
+        if n != n0:
+            w0 = F.interpolate(w0, size=(n, n), mode="nearest")
+        w0 = w0.squeeze(1)
+
+        logger.info(f"initial condition {w0.shape}")
+
+        if T_warmup > 0:
+            logger.info(f"warm up until {T_warmup}")
+            tmp = get_trajectory_imex_crank_nicolson(
+                w0,
+                fh,
+                T=T_warmup,
+                record_steps=record_steps,
+                subsample=1,
+                pbar=pbar,
+                **solver_kws,
+            )
+            w0 = tmp["vorticity"][:, -1].to(device)
+            del tmp
+            logger.info(f"warmup initial condition {w0.shape}")
+
+        logger.info(f"generate data from {T_warmup} to {T}")
+        result = get_trajectory_imex_crank_nicolson(
+            w0,
+            fh,
+            T=T_new,
+            record_steps=record_steps,
+            subsample=subsample,
+            pbar=pbar,
+            **solver_kws,
+        )
+
+        for field, value in result.items():
+            if subsample > 1 and value.ndim == 4:
+                value = F.interpolate(value, size=(ns, ns), mode="bilinear")
+            result[field] = value.cpu().to(dtype)
+            logger.info(f"{field:<15} | {value.shape} | {value.dtype}")
+
+        if not extra:
+            for key in ["vort_t", "stream", "residual"]:
+                result[key] = torch.empty(0, device="cpu")
+        result["random_states"] = torch.as_tensor(seeds, dtype=torch.int32)
+
+        logger.info(f"Saving {i+1}-th batch to {data_filepath}")
+        save_pickle(result, data_filepath)
+
+    pickle_to_pt(data_filepath)
+    logger.info("Done converting to pt.")
+    if args.demo_plots:
+        try:
+            verify_trajectories(
+                data_filepath,
+                dt=T_new / record_steps,
+                T_warmup=T_warmup,
+                n_samples=1,
+            )
+        except Exception as e:
+            logger.error(f"Error in plotting: {e}")
+    return
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/fno/data/data_gen_McWilliams2d.py b/sfno/data/data_gen_McWilliams2d.py
similarity index 78%
rename from fno/data/data_gen_McWilliams2d.py
rename to sfno/data/data_gen_McWilliams2d.py
index 7481c52..056acda 100644
--- a/fno/data/data_gen_McWilliams2d.py
+++ b/sfno/data/data_gen_McWilliams2d.py
@@ -32,8 +32,12 @@ def main(args):
     [1]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. Journal of Fluid Mechanics, 146, 21-43.
 
     Training dataset:
-    python3 data_gen_McWilliams2d.py --sample-size 1280 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2 * torch.pi" --double
+    >>> python3 data_gen_McWilliams2d.py --sample-size 1280 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
+
+    Testing dataset for plotting the enstrophy spectrum:
+    >>> python3 data_gen_McWilliams2d.py --sample-size 32 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
     """
+
 
     current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
     log_name = "".join(os.path.basename(__file__).split(".")[:-1])
 
@@ -72,10 +76,8 @@ def main(args):
     dtype_str = "_fp64" if args.double else ""
     filename = args.filename
     if filename is None:
-        filename = (
-            f"McWilliams2d{dtype_str}_{ns}x{ns}"
-            + f"_N{total_samples}_v{viscosity:.0e}"
-            + f"_T{num_snapshots}.pt".replace("e-0", "e-")
+        filename = f"McWilliams2d{dtype_str}_{ns}x{ns}_N{total_samples}_v{viscosity:.0e}_T{num_snapshots}.pt".replace(
+            "e-0", "e-"
         )
     args.filename = filename
     data_filepath = os.path.join(DATA_PATH, filename)
@@ -89,6 +91,7 @@ def main(args):
         logger.info(f"Save data to {data_filepath}")
 
     cuda = not args.no_cuda and torch.cuda.is_available()
+    no_tqdm = args.no_tqdm
    device = torch.device("cuda:0" if cuda else "cpu")
     torch.set_default_dtype(dtype)
 
@@ -106,7 +109,9 @@ def main(args):
     ).to(device)
 
     for i, idx in enumerate(range(0, total_samples, batch_size)):
-        logger.info(f"Generate trajectory for {i+1}-th batch of {total_samples}")
+        logger.info(
+            f"Generate trajectory for {i+1}-th batch of {total_samples} samples"
+        )
         logger.info(
            f"random state: {random_state + idx} to {random_state + idx + batch_size-1}"
         )
@@ -121,19 +126,21 @@ def main(args):
         )
         vort_hat = fft.rfft2(vort_init).to(device)
 
-        with tqdm(total=warmup_steps) as pbar:
+        with tqdm(total=warmup_steps, disable=no_tqdm) as pbar:
             for j in range(warmup_steps):
                 vort_hat, _ = ns2d.step(vort_hat, dt)
                 if j % 100 == 0:
+                    desc = datetime.now().strftime("%d-%b-%Y %H:%M:%S") + " - Warmup"
+                    pbar.set_description(desc)
                     pbar.update(100)
 
-        result = get_trajectory(
+        result = get_trajectory_rk4(
             ns2d,
             vort_hat,
             dt,
             num_steps=total_steps,
             record_every_steps=record_every_iters,
-            pbar=True,
+            pbar=not no_tqdm,
         )
 
         for field, value in result.items():
@@ -149,17 +156,23 @@ def main(args):
         result["random_states"] = torch.tensor(
             [random_state + idx + k for k in range(batch_size)], dtype=torch.int32
         )
-        logger.info(f"Save {i}-th batch to {data_filepath}")
-        with open(data_filepath, "ab") as f:
-            dill.dump(result, f)
+        logger.info(f"Save {i+1}-th batch to {data_filepath}")
+        save_pickle(result, data_filepath)
 
     pickle_to_pt(data_filepath)
-
-    verify_trajectories(
-        data_filepath, dt=record_every_iters * dt, T_warmup=T_warmup, n_samples=1
-    )
+    logger.info("Done saving.")
+    if args.demo_plots:
+        try:
+            verify_trajectories(
+                data_filepath,
+                dt=record_every_iters * dt,
+                T_warmup=T_warmup,
+                n_samples=1,
+            )
+        except Exception as e:
+            logger.error(f"Error in plotting: {e}")
 
 
 if __name__ == "__main__":
     args = get_args()
-    main(args)
+    main(args)
\ No newline at end of file
diff --git a/fno/data/grf.py b/sfno/data/grf.py
similarity index 100%
rename from fno/data/grf.py
rename to sfno/data/grf.py
diff --git a/fno/data/solvers.py b/sfno/data/solvers.py
similarity index 100%
rename from fno/data/solvers.py
rename to sfno/data/solvers.py
diff --git a/fno/datasets.py b/sfno/datasets.py
similarity index 100%
rename from fno/datasets.py
rename to sfno/datasets.py
diff --git a/fno/fno3d.py b/sfno/fno3d.py
similarity index 100%
rename from fno/fno3d.py
rename to sfno/fno3d.py
diff --git a/fno/losses.py b/sfno/losses.py
similarity index 100%
rename from fno/losses.py
rename to sfno/losses.py
diff --git a/fno/pipeline.py b/sfno/pipeline.py
similarity index 100%
rename from fno/pipeline.py
rename to sfno/pipeline.py
diff --git a/fno/sfno.py b/sfno/sfno.py
similarity index 100%
rename from fno/sfno.py
rename to sfno/sfno.py
diff --git a/fno/utils.py b/sfno/utils.py
similarity index 100%
rename from fno/utils.py
rename to sfno/utils.py
diff --git a/fno/visualizations.py b/sfno/visualizations.py
similarity index 100%
rename from fno/visualizations.py
rename to sfno/visualizations.py
diff --git a/torch_cfd/equations.py b/torch_cfd/equations.py
index 5740579..ca81637 100644
--- a/torch_cfd/equations.py
+++ b/torch_cfd/equations.py
@@ -24,14 +24,12 @@
 from . import grids
 
-TQDM_ITERS = 500
-
 Array = torch.Tensor
 Grid = grids.Grid
 
 
-def spectral_laplacian_2d(rfft_mesh):
-    kx, ky = rfft_mesh
+def spectral_laplacian_2d(fft_mesh):
+    kx, ky = fft_mesh
     # (2 * torch.pi * 1j)**2
     lap = -4 * (torch.pi) ** 2 * (abs(kx) ** 2 + abs(ky) ** 2)
     lap[..., 0, 0] = 1
@@ -48,6 +46,15 @@
     return 2j * torch.pi * (vhat * kx - uhat * ky)
 
 
+def spectral_div_2d(vhat, rfft_mesh):
+    r"""
+    Computes the 2D divergence in the Fourier basis.
+    """
+    uhat, vhat = vhat
+    kx, ky = rfft_mesh
+    return 2j * torch.pi * (uhat * kx + vhat * ky)
+
+
 def spectral_grad_2d(vhat, rfft_mesh):
     kx, ky = rfft_mesh
     return 2j * torch.pi * kx * vhat, 2j * torch.pi * ky * vhat
@@ -164,6 +171,14 @@
         """Solves `u - step_size * implicit_terms(u) = f` for u."""
         raise NotImplementedError
 
+    def residual(
+        self,
+        u: Array,
+        u_t: Array,
+    ):
+        """Computes the residual of the PDE."""
+        raise NotImplementedError
+
 
 def backward_forward_euler(
     u: torch.Tensor,
@@ -442,65 +457,3 @@
         vort_hat = self.solver(vort_hat, dt, self)
         dvortdt_hat = 1 / (steps * dt) * (vort_hat - vort_old)
         return vort_hat, dvortdt_hat
-
-
-def get_trajectory(
-    equation: ImplicitExplicitODE,
-    w0: Array,
-    dt: float,
-    num_steps: int = 1,
-    record_every_steps: int = 1,
-    pbar=False,
-    pbar_desc="",
-    require_grad=False,
-):
-    """
-    vorticity stacked in the time dimension
-    all inputs and outputs are in the frequency domain
-    input: w0 (*, n, n)
-    output:
-
-    vorticity (*, n_t, kx, ky)
-    psi: (*, n_t, kx, ky)
-
-    velocity can be computed from psi
-    (*, 2, n_t, kx, ky) by calling spectral_rot_2d
-    """
-    w_all = []
-    dwdt_all = []
-    res_all = []
-    psi_all = []
-    w = w0
-    tqdm_iters = num_steps if TQDM_ITERS > num_steps else TQDM_ITERS
-    update_iters = num_steps // tqdm_iters
-    with tqdm(total=num_steps) as pbar:
-        for t_step in range(num_steps):
-            w, dwdt = equation.forward(w, dt=dt)
-            w.requires_grad_(require_grad)
-            dwdt.requires_grad_(require_grad)
-
-            if t_step % update_iters == 0:
-                pbar.set_description(pbar_desc)
-                pbar.update(update_iters)
-
-            if t_step % record_every_steps == 0:
-                _, psi = vorticity_to_velocity(equation.grid, w)
-                res = equation.residual(w, dwdt)
-
-                w_, dwdt_, psi, res = [
-                    var.detach().cpu().clone() for var in [w, dwdt, psi, res]
-                ]
-
-                w_all.append(w_)
-                psi_all.append(psi)
-                dwdt_all.append(dwdt_)
-                res_all.append(res)
-
-    result = {
-        var_name: torch.stack(var, dim=-3)
-        for var_name, var in zip(
-            ["vorticity", "stream", "vort_t", "residual"],
-            [w_all, psi_all, dwdt_all, res_all],
-        )
-    }
-    return result