Skip to content

Commit

Permalink
from XFitting to PCAFit, PreFit & FullFit
Browse files Browse the repository at this point in the history
  • Loading branch information
MilagrosMarin committed Mar 20, 2024
1 parent d201bc4 commit 142ad44
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
3 changes: 2 additions & 1 deletion element_moseq/moseq_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
)

from element_interface.utils import find_full_path
from .readers.kpms_reader import load_kpms_dj_config, generate_kpms_dj_config
from .readers.kpms_reader import load_kpms_dj_config

from keypoint_moseq import (
load_checkpoint,
load_keypoints,
Expand Down
31 changes: 16 additions & 15 deletions element_moseq/moseq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from element_interface.utils import find_full_path
from .readers.kpms_reader import generate_kpms_dj_config, load_kpms_dj_config

from keypoint_moseq import (
setup_project,
load_config,
Expand Down Expand Up @@ -339,7 +340,7 @@ def make(self, key):


@schema
class PCAFitting(dj.Computed):
class PCAFit(dj.Computed):
"""Automated fitting of the PCA model.
Attributes:
Expand Down Expand Up @@ -402,14 +403,14 @@ class LatentDimension(dj.Imported):
The analysis aims to select the components that together explain 90% of the variance (fixed threshold).
Attributes:
PCAFitting (foreign key) : PCAFitting Key.
PCAFit (foreign key) : PCAFit Key.
variance_percentage (float) : Variance threshold. Fixed value to 90%.
latent_dimension (int) : Number of principal components required to explain the specified variance.
latent_dim_desc (varchar) : Automated description of the computation result.
"""

definition = """
-> PCAFitting # PCAFitting Key
-> PCAFit # PCAFit Key
---
variance_percentage : float # Variance threshold. Fixed value to 0.9
latent_dimension : int # Number of principal components required to explain the specified variance.
Expand All @@ -421,7 +422,7 @@ def make(self, key):
Make function to compute and store the latent dimensions that explain a 90% variance threshold.
Args:
key (dict): PCAFitting Key.
key (dict): PCAFit Key.
Raises:
Expand Down Expand Up @@ -474,15 +475,15 @@ class PreFittingTask(dj.Manual):
"""Table to specify the parameters for the pre-fitting (AR-HMM) of the model.
Attributes:
PCAFitting (foreign key) : PCA fitting task.
PCAFit (foreign key) : PCAFit task.
pre_latent_dim (int) : Number of latent dimensions to use for the model pre-fitting.
pre_kappa (int) : Kappa value to use for the model pre-fitting.
pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting.
pre_fitting_desc(varchar) : User-defined description of the pre-fitting task.
"""

definition = """
-> PCAFitting # PCAFitting Key
-> PCAFit # PCAFit Key
pre_latent_dim : int # Number of latent dimensions to use for the model pre-fitting
pre_kappa : int # Kappa value to use for the model pre-fitting
pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting.
Expand All @@ -492,7 +493,7 @@ class PreFittingTask(dj.Manual):


@schema
class PreFitting(dj.Computed):
class PreFit(dj.Computed):
"""Automated computation to fit an AR-HMM model.
Attributes:
Expand Down Expand Up @@ -527,7 +528,7 @@ def make(self, key):
5. Initialize the model, creating a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed.
6. Update the model dict with the selected kappa for the AR-HMM fitting
7. Fit the AR-HMM model using the `pre_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint
8. Calculate the duration of the model fitting computation and insert it in the `PreFitting` table
8. Calculate the duration of the model fitting computation and insert it in the `PreFit` table
"""

kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir")
Expand Down Expand Up @@ -590,7 +591,7 @@ class FullFittingTask(dj.Manual):
"""Table to specify the parameters for the full fitting of the model. The full model will generally require a lower value of kappa to yield the same target syllable durations.
Attributes:
PCAFitting (foreign key) : PCAFitting Key.
PCAFit (foreign key) : PCAFit Key.
full_latent_dim (int) : Number of latent dimensions to use for the model full fitting.
full_kappa (int) : Kappa value to use for the model full fitting.
full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting.
Expand All @@ -599,7 +600,7 @@ class FullFittingTask(dj.Manual):
"""

definition = """
-> PCAFitting # PCAFitting Key
-> PCAFit # PCAFit Key
full_latent_dim : int # Number of latent dimensions to use for the model full fitting
full_kappa : int # Kappa value to use for the model full fitting
full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting.
Expand All @@ -609,20 +610,20 @@ class FullFittingTask(dj.Manual):


@schema
class FullFitting(dj.Computed):
class FullFit(dj.Computed):
"""Automated computation to fit the full model.
Attributes:
FullFittingTask (foreign key) : FullFittingTask Key.
model_name (varchar) : Name of the full-fitted model (output_dir/model_name)
full_fitting_duration (float) : Time duration (seconds) of the full fitting model
full_fit_duration (float) : Time duration (seconds) of the full fitting model
"""

definition = """
-> FullFittingTask # FullFittingTask Key
---
model_name : varchar(100) # Name of the full-fitted model (output_dir/model_name)
full_fitting_duration=NULL : float # Time duration (seconds) of the full fitting model
full_fit_duration=NULL : float # Time duration (seconds) of the full fitting model
"""

def make(self, key):
Expand All @@ -645,7 +646,7 @@ def make(self, key):
8. Reindex syllable labels by their frequency in the most recent model snapshot in a checkpoint file. \
This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \
in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on).
8. Calculate the duration of the model fitting computation and insert it in the `PreFitting` table
9. Calculate the duration of the model fitting computation and insert it in the `FullFit` table
"""

kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir")
Expand Down Expand Up @@ -698,6 +699,6 @@ def make(self, key):
kpms_project_output_dir.relative_to(get_kpms_processed_data_dir())
/ model_name
).as_posix(),
"full_fitting_duration": duration_seconds,
"full_fit_duration": duration_seconds,
}
)

0 comments on commit 142ad44

Please sign in to comment.