Skip to content

Commit

Permalink
adapt tests to new tica
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvoegele committed Aug 5, 2023
1 parent afab4d8 commit 594e63b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pensa/dimensionality/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def calculate_pca(data, dim=-1):

def pca_eigenvalues_plot(pca, num=12, plot_file=None):
"""
Plots the highest eigenvalues over the numberr of the principal components.
Plots the highest eigenvalues over the number of the principal components.
Parameters
----------
Expand Down
3 changes: 1 addition & 2 deletions pensa/dimensionality/tica.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import deeptime
from pyemma.util.contexts import settings
import MDAnalysis as mda
import matplotlib.pyplot as plt
from pensa.preprocessing import sort_coordinates, merge_and_sort_coordinates
Expand All @@ -14,7 +13,7 @@

def calculate_tica(data, dim=None, lag=10):
"""
Performs a PyEMMA TICA on the provided data.
Performs time-lagged independent component analysis (TICA) on the provided data.
Parameters
----------
Expand Down
50 changes: 25 additions & 25 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import unittest, gc, mdshare, pyemma, os, requests, scipy.spatial, scipy.spatial.distance, pytest, importlib, scipy.stats, sys, os
import scipy as sp
import MDAnalysis as mda
import unittest, gc, os
import matplotlib.pyplot as plt
import numpy as np

from pyemma.util.contexts import settings
from pensa.clusters import *
from pensa.comparison import *
from pensa.features import *
Expand Down Expand Up @@ -181,7 +178,8 @@ def setUp(self):
combined_data_tors = np.concatenate([bbt_a, bbt_b],0)

self.pca_combined = calculate_pca(combined_data_tors)
self.tica_combined = calculate_tica(combined_data_tors)
self.tica_bbt_a = calculate_tica(bbt_a)
self.tica_bbt_b = calculate_tica(bbt_b)

# -- PCA features
self.graph, self.corr = pca_features(
Expand Down Expand Up @@ -245,7 +243,7 @@ def setUp(self):


# ** ENSEMBLE COMPARISON **


# -- relative_entropy_analysis()
def test_01_relative_entropy_analysis(self):
Expand Down Expand Up @@ -319,12 +317,13 @@ def test_04_ssi_sem_analysis(self):


# -- sort_features()
def test_sort_features(self):
def test_05_sort_features(self):
sf = sort_features(self.names_bbtors, self.jsd_bbtors)
self.assertEqual(len(sf), 574)


# -- residue_visualization()
def test_residue_visualization(self):
def test_06_residue_visualization(self):
ref_filename = test_data_path + "/traj/condition-a_receptor.gro"
out_filename = "receptor_bbtors-deviations_tremd"
vis = residue_visualization(
Expand All @@ -340,8 +339,9 @@ def test_residue_visualization(self):
plt.close()
del vis


# -- distances_visualization()
def test_distances_visualization(self):
def test_07_distances_visualization(self):
matrix = distances_visualization(
self.names_bbdist, self.jsd_bbdist,
test_data_path + "/plots/receptor_jsd-bbdist.pdf",
Expand All @@ -360,7 +360,7 @@ def test_distances_visualization(self):


# -- calculate_pca()
def test_calculate_pca(self):
def test_08_calculate_pca(self):
self.assertEqual(len(self.pca_combined.mean), 460)
self.assertEqual(self.pca_combined.dim, -1)
self.assertEqual(self.pca_combined.skip, 0)
Expand All @@ -369,13 +369,13 @@ def test_calculate_pca(self):


# -- calculate_tica
def test_tica_combined(self):
self.assertEqual(self.tica_combined.lag, 10)
self.assertEqual(self.tica_combined.kinetic_map, True)
def test_09_calculate_tica(self):
self.assertEqual(self.tica_bbt_a.koopman_matrix.size, 841)
self.assertEqual(self.tica_bbt_b.koopman_matrix.size, 841)


# -- pca_eigenvalues_plot()
def test_pca_eigenvalues_plot(self):
def test_10_pca_eigenvalues_plot(self):
arr = pca_eigenvalues_plot(
self.pca_combined, num=12,
plot_file=test_data_path+'/plots/combined_tmr_pca_ev.pdf'
Expand All @@ -387,17 +387,17 @@ def test_pca_eigenvalues_plot(self):


# -- tica_eigenvalues_plot()
def test_tica_eigenvalues_plot(self):
def test_11_tica_eigenvalues_plot(self):
arr_1, arr_2 = tica_eigenvalues_plot(
self.tica_combined, num=12,
plot_file=test_data_path+'/plots/combined_tmr_tica_ev.pdf'
self.tica_bbt_a, num=12,
plot_file=test_data_path+'/plots/combined_tmr_tica_bbt_a_ev.pdf'
)
self.assertEqual(len(arr_1), 12)
self.assertEqual(len(arr_2), 12)


#-- pca_features()
def test_pca_features(self):
def test_12_pca_features(self):
self.assertEqual(len(self.graph), 3)
plt.close()
# -- Graph
Expand All @@ -408,21 +408,21 @@ def test_pca_features(self):


# -- tica_features()
def test_tica_features(self):
def test_13_tica_features(self):
test_feature = tica_features(
self.tica_combined, self.sim_a_tmr_feat['bb-torsions'], 3, 0.4
self.tica_bbt_a, self.sim_a_tmr_feat['bb-torsions'], 3, 0.4
)
self.assertEqual(len(test_feature), 460)


# -- sort_trajs_along_common_pc() + sort_traj_along_pc() + project_on_pc()
def test_sort_trajs_along_pc(self):
def test_14_sort_trajs_along_pc(self):
for ele in self.sort_common_traj:
self.assertEqual(len(ele), 3)
self.assertEqual(len(self.all_sort), 3)

# -- sort_trajs_along_common_tic()
def test_sort_trajs_along_common_tic(self):
def test_15_sort_trajs_along_common_tic(self):
sproj, sidx_data, sidx_traj = sort_trajs_along_common_tic(
self.sim_a_tmr_data['bb-torsions'],
self.sim_b_tmr_data['bb-torsions'],
Expand All @@ -438,20 +438,20 @@ def test_sort_trajs_along_common_tic(self):


# -- sort_traj_along_tic()
def test_sort_traj_along_tic(self):
def test_16_sort_traj_along_tic(self):
all_sort, _, _ = sort_traj_along_tic(
self.sim_a_tmr_data['bb-torsions'],
test_data_path + "/traj/condition-a_receptor.gro",
test_data_path + "/traj/condition-a_receptor.xtc",
test_data_path + "/pca/condition-a_receptor_by_tmr",
tica = self.tica_combined,
tica = self.tica_bbt_a,
num_ic=3
)
self.assertEqual(len(all_sort), 3)


# -- compare_projections()
def test_compare_projections(self):
def test_17_compare_projections(self):

self.assertEqual(len(self.val), 3)
self.assertEqual(len(self.val[0]), 2)
Expand Down

0 comments on commit 594e63b

Please sign in to comment.