Skip to content

Commit

Permalink
test when 1 rep
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard2926 committed Apr 15, 2024
1 parent acbc932 commit 311a873
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
6 changes: 3 additions & 3 deletions scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ config = ARGS[8]

# For scaling tests, use 4 modes, training use 25% modes

modesx = 4 # max(dimx÷32, 4)
modesy = 4 # max(dimy÷32, 4)
modesz = 4 # max(dimz÷32, 4)
modesx = 8 # max(dimx÷32, 4)
modesy = 8 # max(dimy÷32, 4)
modesz = 8 # max(dimz÷32, 4)
modest = 4 # max(dimt÷32, 4)

(gpus > 64) && (modesy = modesy * 2)
Expand Down
26 changes: 9 additions & 17 deletions src/models/DFNO_3D/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,15 @@ mutable struct Model
fourier_y = ParDFT(Complex{T}, config.ny)
fourier_z = ParDFT(Complex{T}, config.nz)
fourier_t = ParDFT(T, config.nt)

restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny])
restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:config.mz, config.nz-config.mz+1:config.nz])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt])

input_shape = (config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz))
weight_shape = (config.nc_lift, config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz))

# # Build restrictions to low-frequency modes
# restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx])
# restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny])
# restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz])
# restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt])

# input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz)
# weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz)

# Build restrictions to low-frequency modes
restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny])
restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt])

input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz)
weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz)

input_order = (1, 2, 3)
weight_order = (1, 4, 2, 3)
Expand Down

0 comments on commit 311a873

Please sign in to comment.