Commit

hopefully it's done now
AshishKumar4 committed Aug 2, 2024
1 parent 0fedc39 commit 198bfe3
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions training.py
@@ -373,8 +373,10 @@ def __init__(self,
         self.latest_epoch = latest_epoch
 
         if train_state == None:
-            self.init_state(optimizer, rngs, existing_state=old_state,
-                            existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
+            state, best_state = self.generate_states(
+                optimizer, rngs, old_state, old_best_state, model, param_transforms
+            )
+            self.init_state(state, best_state)
         else:
             self.state = train_state
             self.best_state = train_state
@@ -383,7 +385,7 @@ def __init__(self,
     def get_input_ones(self):
         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
 
-    def __init_fn(
+    def generate_states(
         self,
         optimizer: optax.GradientTransformation,
         rngs: jax.random.PRNGKey,
@@ -392,11 +394,14 @@ def __init_fn(
         model: nn.Module = None,
         param_transforms: Callable = None
     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+        print("Generating states for SimpleTrainer")
         rngs, subkey = jax.random.split(rngs)
 
         if existing_state == None:
             input_vars = self.get_input_ones()
             params = model.init(subkey, **input_vars)
+        else:
+            params = existing_state['params']
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
@@ -415,17 +420,9 @@
 
     def init_state(
         self,
-        optimizer: optax.GradientTransformation,
-        rngs: jax.random.PRNGKey,
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None
+        state: SimpleTrainState,
+        best_state: SimpleTrainState,
     ):
-
-        state, best_state = self.__init_fn(
-            optimizer, rngs, existing_state, existing_best_state, model, param_transforms
-        )
         self.best_loss = 1e9
 
         if self.distributed_training:
@@ -665,7 +662,7 @@ def __init__(self,
         self.model_output_transform = model_output_transform
         self.unconditional_prob = unconditional_prob
 
-    def __init_fn(
+    def generate_states(
         self,
         optimizer: optax.GradientTransformation,
         rngs: jax.random.PRNGKey,
@@ -674,6 +671,7 @@ def __init_fn(
         model: nn.Module = None,
         param_transforms: Callable = None
     ) -> Tuple[TrainState, TrainState]:
+        print("Generating states for DiffusionTrainer")
         rngs, subkey = jax.random.split(rngs)
 
         if existing_state == None:
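Taken together, the hunks split the old private __init_fn into a public generate_states, which builds and returns the (state, best_state) pair, and a slimmed-down init_state, which only installs that pair (and, per the surrounding code, handles distributed replication). A minimal sketch of the new two-step flow, assuming a Flax/optax setup; the model, optimizer, seed, and trainer instance below are illustrative placeholders, not part of the commit:

    import jax
    import optax

    # Placeholders for illustration only; any Flax nn.Module and optax
    # optimizer accepted by the trainer would work the same way.
    model = MyFlaxModule()          # hypothetical nn.Module
    optimizer = optax.adam(1e-3)
    rngs = jax.random.PRNGKey(0)

    # Step 1: build fresh train states. Passing an existing_state dict
    # with a 'params' entry instead of None resumes from a checkpoint.
    state, best_state = trainer.generate_states(
        optimizer, rngs,
        existing_state=None,
        existing_best_state=None,
        model=model,
        param_transforms=None,
    )

    # Step 2: install them; init_state no longer constructs anything itself.
    trainer.init_state(state, best_state)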
