Skip to content

Commit

Permalink
allow changing the noise amount to be added when stuck (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
cagrikymk committed Mar 31, 2024
1 parent 28280c2 commit 11941cb
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
14 changes: 11 additions & 3 deletions jaxreaxff/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
parse_and_save_force_field)
import math
from functools import partial
from jaxreaxff.helper import build_float_range_checker

def main():
# create parser for command-line arguments
Expand Down Expand Up @@ -133,6 +134,13 @@ def main():
help='R|Max number of clusters that can be used\n' +
'High number of clusters lowers the memory cost\n' +
'However, it increases compilation time,especially for cpus')
parser.add_argument('--perc_noise_when_stuck', metavar='percentage',
type=build_float_range_checker(0.0, 0.1),
default=0.04,
help='R|Percentage of the noise that will be added to the parameters\n' +
'when the optimizer is stuck.\n' +
'param_noise_i = (param_min_i, param_max_i) * perc_noise_when_stuck\n' +
'Allowed range: [0.0, 0.1]')
parser.add_argument('--seed', metavar='seed',
type=int,
default=0,
Expand All @@ -149,9 +157,9 @@ def main():
print("To use the GPU version, jaxlib with CUDA support needs to installed!")

# advanced options
advanced_opts = {"perc_err_change_thr":0.01, # if change in error is less than this threshold, add noise
"perc_noise_when_stuck":0.04, # noise percantage (wrt param range) to add when stuck
"perc_width_rest_search":0.15, # width of the restricted parameter search after iteration > rest_search_start
advanced_opts = {"perc_err_change_thr":0.01, # if change in error is less than this threshold, add noise
"perc_noise_when_stuck":args.perc_noise_when_stuck, # noise percantage (wrt param range) to add when stuck
"perc_width_rest_search":0.15, # width of the restricted parameter search after iteration > rest_search_start
}

onp.random.seed(args.seed)
Expand Down
17 changes: 17 additions & 0 deletions jaxreaxff/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,28 @@
from jaxreaxff.inter_list_counter import pool_handler_for_inter_list_count
from jax_md import dataclasses
from jax_md.reaxff.reaxff_forcefield import ForceField
import argparse

# Since we shouldnt access the private API (jaxlib), create a dummy jax array
# and get the type information from the array.
#from jaxlib.xla_extension import ArrayImpl as JaxArrayType
JaxArrayType = type(jnp.zeros(1))

def build_float_range_checker(min_v, max_v):
'''
Returns a function that can be used to validate fiven FP value
withing the allowed range ([min_v, max_v])
'''
def range_checker(arg):
try:
val = float(arg)
except ValueError:
raise argparse.ArgumentTypeError("Value must be a floating point number")
if val < min_v or val > max_v:
raise argparse.ArgumentTypeError("Value must be in range [" + str(min_v) + ", " + str(max_v)+"]")
return val
return range_checker

def get_params(force_field, params_list):
'''
Get the selected parameters from the force field
Expand Down

0 comments on commit 11941cb

Please sign in to comment.