Skip to content
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Move to Muscat=2.5.0 (for tests and examples support)
- Update repo configuration (actions: rely more on pypi dependencies, action versions)
- Rename types to remove `Type` from name of types: https://github.com/PLAID-lib/plaid/pull/164
- Refactored method names for improved clarity:
  - `Dataset.from_tabular` → `Dataset.add_features_from_tabular`
  - `Dataset.from_features_identifier` → `Dataset.extract_dataset_from_identifier`
  - `Sample.from_features_identifier` → `Sample.extract_sample_from_identifier`
- Refactored all `tree` methods, fixtures, and examples to use `mesh` instead:
  - Methods: `add_tree` → `add_mesh`, `del_tree` → `del_mesh`, `show_tree` → `show_mesh`,
    `init_tree` → `init_mesh`, `link_tree` → `link_mesh`
  - Fixtures: `sample_with_tree` → `sample_with_mesh`, `tree` → `mesh`, etc.
  - Updated usage in tests, docs and examples
  - Kept the old `tree` methods as deprecated aliases for backward compatibility

### Fixes

Expand Down
8 changes: 4 additions & 4 deletions examples/containers/dataset_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def dprint(name: str, dictio: dict, end: str = "\n"):

# %%
# Add a CGNS tree structure to the Sample
sample_01.add_tree(cgns_mesh)
sample_01.add_mesh(cgns_mesh)
print(f"{sample_01 = }")

# %%
Expand Down Expand Up @@ -136,15 +136,15 @@ def dprint(name: str, dictio: dict, end: str = "\n"):
sample_03 = Sample()
sample_03.add_scalar("speed", np.random.randn())
sample_03.add_scalar("rotation", sample_01.get_scalar("rotation"))
sample_03.add_tree(cgns_mesh)
sample_03.add_mesh(cgns_mesh)

# Show Sample CGNS content
sample_03.show_tree()
sample_03.show_mesh()

# %%
# Add a field to the third empty Sample
sample_03.add_field("temperature", np.random.rand(5), "Zone", "Base_2_2")
sample_03.show_tree()
sample_03.show_mesh()

# %% [markdown]
# ### Get Sample data
Expand Down
17 changes: 9 additions & 8 deletions examples/containers/sample_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
# %%
# Print Sample util
def show_sample(sample: Sample):
print(f"sample = {sample}")
sample.show_tree()
print(f"{sample = }")
sample.show_mesh()
print(f"{sample.get_scalar_names() = }")
print(f"{sample.get_field_names() = }")

Expand Down Expand Up @@ -141,10 +141,10 @@ def show_sample(sample: Sample):

# %%
# Add the previously created CGNS tree to the sample
sample.add_tree(tree)
sample.add_mesh(tree)

# Display the Sample CGNS tree
sample.show_tree()
sample.show_mesh()

# %% [markdown]
# ### Set all meshes with their corresponding time step
Expand All @@ -160,6 +160,7 @@ def show_sample(sample: Sample):
new_sample_mult_mesh.set_meshes(meshes_dict)

print(f"{new_sample_mult_mesh.get_all_mesh_times() = }")
# new_sample_mult_mesh.show_mesh(1.)

# %% [markdown]
# ### Link tree from another sample
Expand Down Expand Up @@ -293,7 +294,7 @@ def show_sample(sample: Sample):
tmp_sample = Sample()

# Add the previously created CGNS tree in the Sample
tmp_sample.add_tree(tree)
tmp_sample.add_mesh(tree)

print("element connectivity = \n", f"{tmp_sample.get_elements()}")

Expand Down Expand Up @@ -386,7 +387,7 @@ def show_sample(sample: Sample):
print(f"{sample.get_all_mesh_times() = }")

# Add one CGNS tree at time 1.
sample.add_tree(tree, 1.0)
sample.add_mesh(tree, 1.0)

# After adding new tree
print(f"{sample.get_all_mesh_times() = }")
Expand Down Expand Up @@ -428,11 +429,11 @@ def show_sample(sample: Sample):
print(f"{sample.get_time_assignment() = }", end="\n\n")

# Print the mesh at time 1.0 (the current default time)
sample.show_tree() # == sample.show_tree(1.0)
sample.show_mesh() # == sample.show_mesh(1.0)

# %%
# If time is specified as an argument in a function, it takes precedence over the default time.
sample.show_tree(0.0) # Print the tree at time 0.0 even if default time is 1.0
sample.show_mesh(0.0) # Print the mesh at time 0.0 even if default time is 1.0

# %% [markdown]
# ### Set and use default base and time in a Sample
Expand Down
2 changes: 1 addition & 1 deletion examples/convert_users_data_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def in_notebook():
# Add CGNS mesh to samples with specific time steps
sample = Sample()

sample.add_tree(cgns_tree)
sample.add_mesh(cgns_tree)

# Add random scalar values to the sample
for sname in in_scalars_names:
Expand Down
2 changes: 1 addition & 1 deletion examples/pipelines/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
# In this example, we aim to predict the ``mach`` field based on two input scalars ``angle_in`` and ``mach_out``, and the mesh node coordinates. To limit memory consumption, we restrict the dataset to the features required for this example:

# %%
dataset_train = dataset_train.from_features_identifier(all_feature_id)
dataset_train = dataset_train.extract_dataset_from_identifier(all_feature_id)
print("dataset_train:", dataset_train)
print("scalar names =", dataset_train.get_scalar_names())
print("field names =", dataset_train.get_field_names())
Expand Down
55 changes: 46 additions & 9 deletions src/plaid/containers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from plaid.containers.utils import check_features_size_homogeneity
from plaid.types import Array, Feature, FeatureIdentifier
from plaid.utils.base import DeprecatedError, ShapeError, generate_random_ASCII
from plaid.utils.deprecation import deprecated

logger = logging.getLogger(__name__)
logging.basicConfig(
Expand Down Expand Up @@ -652,7 +653,7 @@ def update_features_from_identifier(
)
return dataset

def from_features_identifier(
def extract_dataset_from_identifier(
self,
feature_identifiers: Union[FeatureIdentifier, list[FeatureIdentifier]],
) -> Self:
Expand All @@ -674,10 +675,24 @@ def from_features_identifier(
dataset.set_infos(copy.deepcopy(self.get_infos()))

for id in self.get_sample_ids():
extracted_sample = self[id].from_features_identifier(feature_identifiers)
extracted_sample = self[id].extract_sample_from_identifier(
feature_identifiers
)
dataset.add_sample(sample=extracted_sample, id=id)
return dataset

@deprecated(
"Use extract_dataset_from_identifier() instead",
version="0.1.8",
removal="0.2",
)
def from_features_identifier(
self,
feature_identifiers: Union[FeatureIdentifier, list[FeatureIdentifier]],
) -> Self:
"""DEPRECATED: Use extract_dataset_from_identifier() instead."""
return self.extract_dataset_from_identifier(feature_identifiers)

def get_tabular_from_homogeneous_identifiers(
self,
feature_identifiers: list[FeatureIdentifier],
Expand Down Expand Up @@ -734,22 +749,28 @@ def get_tabular_from_stacked_identifiers(

return tabular

def from_tabular(
def add_features_from_tabular(
self,
tabular: Array,
feature_identifiers: Union[FeatureIdentifier, list[FeatureIdentifier]],
restrict_to_features: bool = True,
) -> Self:
"""Generates a dataset from tabular data and feature_identifiers.
"""Add or update features in the dataset from tabular data using feature identifiers.

This method takes tabular data and applies it to the dataset, either by updating existing features
or adding new ones based on the provided feature identifiers. The method can either:
1. Extract only the specified features and return a new dataset with just those features (if restrict_to_features=True)
2. Update the specified features in the current dataset while keeping all other existing features (if restrict_to_features=False)

Parameters:
tabular (Array): of size (nb_sample, nb_features) or (nb_sample, nb_features, dim_feature) if dim_feature>1
feature_identifiers (dict or list of dict): One or more feature identifiers.
extract_features (bool, optional): If True, only returns the features from feature identifiers, otherwise keep the other features as well
feature_identifiers (dict or list of dict): One or more feature identifiers specifying which features to update/add.
restrict_to_features (bool, optional): If True, only returns the features from feature identifiers, otherwise keep the other features as well. Defaults to True.

Returns:
Self
A new dataset defined from tabular data and feature_identifiers.
Self: A new dataset with features updated/added from the tabular data. If restrict_to_features=True,
contains only the specified features. If restrict_to_features=False, contains all original
features plus the updated/added ones.

Raises:
AssertionError
Expand All @@ -765,7 +786,7 @@ def from_tabular(
features = {id: tabular[i] for i, id in enumerate(self.get_sample_ids())}

if restrict_to_features:
dataset = self.from_features_identifier(feature_identifiers)
dataset = self.extract_dataset_from_identifier(feature_identifiers)
dataset.update_features_from_identifier(
feature_identifiers=feature_identifiers,
features=features,
Expand All @@ -780,6 +801,22 @@ def from_tabular(

return dataset

@deprecated(
"Use add_features_from_tabular() instead",
version="0.1.8",
removal="0.2",
)
def from_tabular(
self,
tabular: Array,
feature_identifiers: Union[FeatureIdentifier, list[FeatureIdentifier]],
restrict_to_features: bool = True,
) -> Self:
"""DEPRECATED: Use add_features_from_tabular() instead."""
return self.add_features_from_tabular(
tabular, feature_identifiers, restrict_to_features
)

# -------------------------------------------------------------------------#
def add_info(self, cat_key: str, info_key: str, info: str) -> None:
"""Add information to the :class:`Dataset <plaid.containers.dataset.Dataset>`, overwriting existing information if there's a conflict.
Expand Down
Loading
Loading