
[ENH] rework TimeSeriesDataSet using LightningDataModule - experimental #1766

Open

fkiraly opened this issue Feb 10, 2025 · 5 comments

Labels: API design (API design & software architecture), enhancement (New feature or request)

@fkiraly (Collaborator)

fkiraly commented Feb 10, 2025

Umbrella issue for pytorch-forecasting 2.0 design: #1736

In sktime/enhancement-proposals#39, @phoeenniixx suggested a LightningDataModule-based design for the end state of dsipts and pytorch-forecasting 2.0.

As a work item, this implies a rework of TimeSeriesDataSet using LightningDataModule, covering layers D1 and D2 (in the EP's terminology), with the D1 layer based on a refined design following #1757 (but simpler).

@phoeenniixx agreed to give this a go as part of an experimental PR.

fkiraly added the API design and enhancement labels on Feb 10, 2025
@phoeenniixx

with the D1 layer based on a refined design following #1757 (but simpler).

Should we use the "dict" implementation for metadata or keep it separate as it is now?

@fkiraly
Collaborator Author

fkiraly commented Feb 11, 2025

For D1, I think your dict idea is better than the list-based format I suggested, because it keeps the number of metadata fields small. That is, I would go for your original suggestion, but with my subsequent modification of having even fewer dicts and metadata fields.
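
For concreteness, a rough sketch of what such a dict-based metadata object could look like. All key names and codes here are illustrative placeholders, not a settled API:

metadata = {
    "cols": {
        "y": ["target"],            # target column(s)
        "x": ["feat_a", "feat_b"],  # feature column(s)
        "st": ["series_id"],        # static column(s)
    },
    # per-column properties, kept to a small number of dicts and fields
    "col_type": {"feat_a": "F", "feat_b": "F"},   # e.g. "F" = float, "C" = categorical
    "col_known": {"feat_a": "K", "feat_b": "U"},  # "K" = known in the future, "U" = unknown
}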

To start with, though, we can assume everything is float and future-known; I estimate that otherwise there would be a lot of boring boilerplate in handling the different column types etc.

The reason for that is that I think we should get to an end-to-end design quickly and see how it looks and how/whether it works, because we might modify or even abandon it. The work on the boilerplate would then be lost.

Whereas, if this proves to be the way to go, it is still easy to add it on top.

@phoeenniixx

A few problems I found with this approach:

  • the torch DataLoader expects a Dataset, not a LightningDataModule
  • a LightningDataModule is meant to do only the data "handling"; to pass the data to a DataLoader, we need to wrap it in a Dataset

The to_dataloader function of TimeSeriesDataSet passes self to the DataLoader; if we use just a LightningDataModule, that is not possible.

Proposed solutions:

  1. We can add one more layer after LightningDataModule, i.e., a Dataset layer:
from torch.utils.data import Dataset


class TimeSeriesDataset(Dataset):
    def __init__(self, datamodule: "DecoderEncoderDataModule"):
        self.datamodule = datamodule
        self.tsd = datamodule.tsd  # preprocessed TimeSeries data (D1)

    def __len__(self):
        return len(self.tsd)

    def __getitem__(self, idx):
        # fetch raw sample from the D1 TimeSeries layer
        batch = self.tsd[idx]

        # apply all transformations defined inside the datamodule
        transformed_batch = self.datamodule.transformation(batch)

        return transformed_batch
  2. We can remove the LightningDataModule layer completely and use the other proposed approach, of passing the metadata in the __init__ of the D2 layer:
from torch.utils.data import Dataset


class DecoderEncoderData(Dataset):
    def __init__(self, tsd: PandasTSDataSet, **params):
        self.tsd = tsd  # store the D1 dataset reference

        # access metadata from the D1 dataset
        self.metadata = tsd.get_metadata()

    def __len__(self):
        return len(self.tsd)

    def __getitem__(self, idx):
        sample = self.tsd[idx]
        # other required additions to ``sample``
        return sample
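
With either variant, the object handed to the torch DataLoader is a plain Dataset, so vanilla usage would look like this (a sketch reusing the class names above; batch size and the rest are arbitrary):

from torch.utils.data import DataLoader

# variant 1: Dataset wrapping the datamodule
loader = DataLoader(TimeSeriesDataset(datamodule), batch_size=32, shuffle=True)

# variant 2: D2 Dataset built directly on the D1 layer
loader = DataLoader(DecoderEncoderData(tsd), batch_size=32, shuffle=True)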

@fkiraly
Collaborator Author

fkiraly commented Feb 11, 2025

Could you double-check how the LightningDataModule interacts with a LightningModule in the vanilla vignette? Is the additional layer between them really required? That would surprise me.

I feel the LightningDataModule is needed to satisfy the dsipts requirement of being able to specify train and validation splits.
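
A minimal sketch of what I have in mind, with hypothetical names: the DataModule owns the splits and hands plain Dataset objects to its DataLoaders, so no extra layer sits between it and the Trainer:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split


class DecoderEncoderDataModule(LightningDataModule):
    # hypothetical D2 datamodule: owns the train/val split
    def __init__(self, tsd, batch_size=32, val_fraction=0.2):
        super().__init__()
        self.tsd = tsd
        self.batch_size = batch_size
        self.val_fraction = val_fraction

    def setup(self, stage=None):
        dataset = DecoderEncoderData(self.tsd)  # the D2 Dataset from the comment above
        n_val = int(len(dataset) * self.val_fraction)
        self.train_ds, self.val_ds = random_split(dataset, [len(dataset) - n_val, n_val])

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)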

@phoeenniixx

phoeenniixx commented Feb 11, 2025

In this tutorial, the DataLoaders inside the DataModule receive only the Dataset objects, not the DataModule itself.

A LightningDataModule only defines how to train, validate, and test the model. It does not handle dataset indexing, transformations, or batching; that is the job of the DataLoader.

A DataLoader requires a Dataset object to properly iterate over the data, since the Dataset provides the __getitem__ and __len__ implementations.

A clearer example:

import os

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform = transforms.Compose([transforms.ToTensor()])
        mnist_full = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

source : https://pytorch-lightning.readthedocs.io/en/0.10.0/introduction_guide.html#the-engineering

Here you can see that they pass the Dataset objects, not the DataModule itself.

DataModules are useful for training and testing because you then just pass the model and the DataModule, and everything is handled behind the curtain:

dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, dm)

Here, LitMNIST is a LightningModule.

Figure: difference between using and not using Lightning DataModules [Source].
