Correct Shape of scdl output matrix #708

Open
wants to merge 2 commits into base: main
@@ -240,6 +240,7 @@ def __init__(
         paginated_load_cutoff: int = 10_000,
         load_block_row_size: int = 1_000_000,
         feature_index_name="feature_id",
+        return_padded: bool = False,
Collaborator:

High level, how does the code know what the pad size is? Or max size to pad to?

Author:

This uses get_row_padded in __getitem__, which returns the row padded out to the dataset's full feature dimension.
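
For context on the pad-size question, a minimal sketch, assuming get_row_padded zero-fills each row out to the feature count from number_of_variables(); the helper's body is not shown in this diff, so the name and internals below are assumptions:

import numpy as np

# Hypothetical illustration only: scatter a row's stored nonzeros into a
# dense zero-filled vector whose length is the dataset's feature count.
def get_row_padded_sketch(values: np.ndarray, columns: np.ndarray, n_features: int) -> np.ndarray:
    padded = np.zeros(n_features, dtype=values.dtype)
    padded[columns.astype(np.int64)] = values  # place each value at its column index
    return padded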

     ) -> None:
         """Instantiate the class.
 
@@ -264,6 +265,7 @@ def __init__(
         self.data: Optional[np.ndarray] = None
         self.row_index: Optional[np.ndarray] = None
+        self.return_padded = return_padded

         # Metadata and attributes
         self.metadata: Dict[str, int] = {}

@@ -699,7 +701,10 @@ def __len__(self):

     def __getitem__(self, idx: int) -> torch.Tensor:
         """Get the row values located at index idx."""
-        return torch.from_numpy(np.stack(self.get_row(idx)[0]))
+        if self.return_padded:
+            return torch.from_numpy(self.get_row_padded(idx)[0])
+        else:
+            return torch.from_numpy(np.stack(self.get_row(idx)[0])), self.number_of_variables()

     def number_of_variables(self) -> List[int]:
         """Get the number of features in every entry in the dataset.
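
As a shape check on the __getitem__ change, here is a small hand-built sample in the default (non-padded) format; the 10-feature count is illustrative, and only the (values, columns) stacking and the number_of_variables() list are established by this diff:

import torch

values = torch.tensor([1.0, 3.0, 5.0])
columns = torch.tensor([0.0, 4.0, 9.0])
sample = (torch.stack([values, columns]), [10])  # mimics __getitem__ with return_padded=False

sparse_matrix, n_features = sample
assert sparse_matrix.shape == (2, 3)  # row 0 holds values, row 1 holds column indices
assert n_features[0] == 10            # feature count consumed by the collate function below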
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
+
 import torch
 
 
-def collate_sparse_matrix_batch(batch: list[torch.Tensor]) -> torch.Tensor:
+def collate_sparse_matrix_batch(batch: tuple) -> torch.Tensor:
"""Collate function to create a batch out of sparse tensors.

This is necessary to collate sparse matrices of various lengths.
@@ -27,14 +29,17 @@ def collate_sparse_matrix_batch(batch: list[torch.Tensor]) -> torch.Tensor:
     Returns:
         The tensors collated into a CSR (Compressed Sparse Row) Format.
     """
+    # every batch element is a tuple of (sparse_matrix, n_features), where
+    # n_features is the list returned by number_of_variables(); read the
+    # feature count from the first element of the batch
+    n_features = batch[0][1][0]
Collaborator:

Could you add a comment explaining the hardcoded indexes?

Author:

Added a comment

     batch = [x[0] for x in batch]
     batch_rows = torch.cumsum(
         torch.tensor([0] + [sparse_representation.shape[1] for sparse_representation in batch]), dim=0
     )
     batch_cols = torch.cat([sparse_representation[1] for sparse_representation in batch]).to(torch.int32)
     batch_values = torch.cat([sparse_representation[0] for sparse_representation in batch])
-    if len(batch_cols) == 0:
-        max_pointer = 0
-    else:
-        max_pointer = int(batch_cols.max().item() + 1)
-    batch_sparse_tensor = torch.sparse_csr_tensor(batch_rows, batch_cols, batch_values, size=(len(batch), max_pointer))

+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        batch_sparse_tensor = torch.sparse_csr_tensor(batch_rows, batch_cols, batch_values, size=(len(batch), n_features))
     return batch_sparse_tensor
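
A usage sketch of the corrected collate path with two hand-built samples (assuming collate_sparse_matrix_batch from this PR is in scope); the batch width is now the full feature count rather than max(column index) + 1:

import torch

# Two fake samples in the (sparse_matrix, n_features) format produced by
# __getitem__ with return_padded=False; both report 10 features.
sample_a = (torch.stack([torch.tensor([1.0, 2.0]), torch.tensor([0.0, 3.0])]), [10])
sample_b = (torch.stack([torch.tensor([5.0]), torch.tensor([7.0])]), [10])

batch = collate_sparse_matrix_batch([sample_a, sample_b])
print(batch.shape)  # torch.Size([2, 10]): padded to all 10 features even though
                    # the largest stored column index is only 7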