diff --git a/.eggs/README.txt b/.eggs/README.txt new file mode 100644 index 0000000..5d01668 --- /dev/null +++ b/.eggs/README.txt @@ -0,0 +1,6 @@ +This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins. + +This directory caches those eggs to prevent repeated downloads. + +However, it is safe to delete this directory. + diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/LICENSE b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/LICENSE new file mode 100644 index 0000000..353924b --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/LICENSE @@ -0,0 +1,19 @@ +Copyright Jason R. Coombs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. 
diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/PKG-INFO b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/PKG-INFO new file mode 100644 index 0000000..75c36d6 --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/PKG-INFO @@ -0,0 +1,185 @@ +Metadata-Version: 2.1 +Name: pytest-runner +Version: 6.0.1 +Summary: Invoke py.test as distutils command with dependency resolution +Home-page: https://github.com/pytest-dev/pytest-runner/ +Author: Jason R. Coombs +Author-email: jaraco@jaraco.com +Classifier: Development Status :: 7 - Inactive +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Framework :: Pytest +Requires-Python: >=3.7 +License-File: LICENSE +Provides-Extra: docs +Requires-Dist: sphinx ; extra == 'docs' +Requires-Dist: jaraco.packaging >=9 ; extra == 'docs' +Requires-Dist: rst.linker >=1.9 ; extra == 'docs' +Requires-Dist: jaraco.tidelift >=1.4 ; extra == 'docs' +Provides-Extra: testing +Requires-Dist: pytest >=6 ; extra == 'testing' +Requires-Dist: pytest-checkdocs >=2.4 ; extra == 'testing' +Requires-Dist: pytest-flake8 ; extra == 'testing' +Requires-Dist: pytest-cov ; extra == 'testing' +Requires-Dist: pytest-enabler >=1.0.1 ; extra == 'testing' +Requires-Dist: pytest-virtualenv ; extra == 'testing' +Requires-Dist: types-setuptools ; extra == 'testing' +Requires-Dist: pytest-black >=0.3.7 ; (platform_python_implementation != "PyPy") and extra == 'testing' +Requires-Dist: pytest-mypy >=0.9.1 ; (platform_python_implementation != "PyPy") and extra == 'testing' + +.. image:: https://img.shields.io/pypi/v/pytest-runner.svg + :target: `PyPI link`_ + +.. image:: https://img.shields.io/pypi/pyversions/pytest-runner.svg + :target: `PyPI link`_ + +.. _PyPI link: https://pypi.org/project/pytest-runner + +.. 
image:: https://github.com/pytest-dev/pytest-runner/workflows/tests/badge.svg + :target: https://github.com/pytest-dev/pytest-runner/actions?query=workflow%3A%22tests%22 + :alt: tests + +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black + :alt: Code style: Black + +.. .. image:: https://readthedocs.org/projects/skeleton/badge/?version=latest +.. :target: https://skeleton.readthedocs.io/en/latest/?badge=latest + +.. image:: https://img.shields.io/badge/skeleton-2022-informational + :target: https://blog.jaraco.com/skeleton + +.. image:: https://tidelift.com/badges/package/pypi/pytest-runner + :target: https://tidelift.com/subscription/pkg/pypi-pytest-runner?utm_source=pypi-pytest-runner&utm_medium=readme + +Setup scripts can use pytest-runner to add setup.py test support for pytest +runner. + +Deprecation Notice +================== + +pytest-runner depends on deprecated features of setuptools and relies on features that break security +mechanisms in pip. For example 'setup_requires' and 'tests_require' bypass ``pip --require-hashes``. +See also `pypa/setuptools#1684 `_. + +It is recommended that you: + +- Remove ``'pytest-runner'`` from your ``setup_requires``, preferably removing the ``setup_requires`` option. +- Remove ``'pytest'`` and any other testing requirements from ``tests_require``, preferably removing the ``tests_requires`` option. +- Select a tool to bootstrap and then run tests such as tox. + +Usage +===== + +- Add 'pytest-runner' to your 'setup_requires'. Pin to '>=2.0,<3dev' (or + similar) to avoid pulling in incompatible versions. +- Include 'pytest' and any other testing requirements to 'tests_require'. +- Invoke tests with ``setup.py pytest``. +- Pass ``--index-url`` to have test requirements downloaded from an alternate + index URL (unnecessary if specified for easy_install in setup.cfg). +- Pass additional py.test command-line options using ``--addopts``. 
+- Set permanent options for the ``python setup.py pytest`` command (like ``index-url``) + in the ``[pytest]`` section of ``setup.cfg``. +- Set permanent options for the ``py.test`` run (like ``addopts`` or ``pep8ignore``) in the ``[pytest]`` + section of ``pytest.ini`` or ``tox.ini`` or put them in the ``[tool:pytest]`` + section of ``setup.cfg``. See `pytest issue 567 + `_. +- Optionally, set ``test=pytest`` in the ``[aliases]`` section of ``setup.cfg`` + to cause ``python setup.py test`` to invoke pytest. + +Example +======= + +The most simple usage looks like this in setup.py:: + + setup( + setup_requires=[ + 'pytest-runner', + ], + tests_require=[ + 'pytest', + ], + ) + +Additional dependencies require to run the tests (e.g. mock or pytest +plugins) may be added to tests_require and will be downloaded and +required by the session before invoking pytest. + +Follow `this search on github +`_ +for examples of real-world usage. + +Standalone Example +================== + +This technique is deprecated - if you have standalone scripts +you wish to invoke with dependencies, `use pip-run +`_. + +Although ``pytest-runner`` is typically used to add pytest test +runner support to maintained packages, ``pytest-runner`` may +also be used to create standalone tests. Consider `this example +failure `_, +reported in `jsonpickle #117 +`_ +or `this MongoDB test +`_ +demonstrating a technique that works even when dependencies +are required in the test. + +Either example file may be cloned or downloaded and simply run on +any system with Python and Setuptools. It will download the +specified dependencies and run the tests. Afterward, the the +cloned directory can be removed and with it all trace of +invoking the test. No other dependencies are needed and no +system configuration is altered. 
+ +Then, anyone trying to replicate the failure can do so easily +and with all the power of pytest (rewritten assertions, +rich comparisons, interactive debugging, extensibility through +plugins, etc). + +As a result, the communication barrier for describing and +replicating failures is made almost trivially low. + +Considerations +============== + +Conditional Requirement +----------------------- + +Because it uses Setuptools setup_requires, pytest-runner will install itself +on every invocation of setup.py. In some cases, this causes delays for +invocations of setup.py that will never invoke pytest-runner. To help avoid +this contingency, consider requiring pytest-runner only when pytest +is invoked:: + + needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) + pytest_runner = ['pytest-runner'] if needs_pytest else [] + + # ... + + setup( + #... + setup_requires=[ + #... (other setup requirements) + ] + pytest_runner, + ) + +For Enterprise +============== + +Available as part of the Tidelift Subscription. + +This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use. + +`Learn more `_. + +Security Contact +================ + +To report a security vulnerability, please use the +`Tidelift security contact `_. +Tidelift will coordinate the fix and disclosure. 
diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/RECORD b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/RECORD new file mode 100644 index 0000000..d7ff24d --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/RECORD @@ -0,0 +1,7 @@ +ptr/__init__.py,sha256=0UfzhCooVgCNTBwVEOPOVGEPck4pnl_6PTfsC-QzNGM,6730 +pytest_runner-6.0.1.dist-info/LICENSE,sha256=2z8CRrH5J48VhFuZ_sR4uLUG63ZIeZNyL4xuJUKF-vg,1050 +pytest_runner-6.0.1.dist-info/METADATA,sha256=Ho3FvAFjFHeY5OQ64WFzkLigFaIpuNr4G3uSmOk3nho,7319 +pytest_runner-6.0.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92 +pytest_runner-6.0.1.dist-info/entry_points.txt,sha256=BqezBqeO63XyzSYmHYE58gKEFIjJUd-XdsRQkXHy2ig,58 +pytest_runner-6.0.1.dist-info/top_level.txt,sha256=DPzHbWlKG8yq8EOD5UgEvVNDWeJRPyimrwfShwV6Iuw,4 +pytest_runner-6.0.1.dist-info/RECORD,, diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/WHEEL b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/WHEEL new file mode 100644 index 0000000..98c0d20 --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/entry_points.txt b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/entry_points.txt new file mode 100644 index 0000000..0860670 --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/entry_points.txt @@ -0,0 +1,3 @@ +[distutils.commands] +ptr = ptr:PyTest +pytest = ptr:PyTest diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/requires.txt b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/requires.txt new file mode 100644 index 0000000..1535188 --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/requires.txt @@ -0,0 +1,17 @@ + +[docs] +sphinx +jaraco.packaging>=9 +rst.linker>=1.9 +jaraco.tidelift>=1.4 + +[testing] +pytest>=6 +pytest-checkdocs>=2.4 +pytest-flake8 +pytest-cov +pytest-enabler>=1.0.1 +pytest-virtualenv 
+types-setuptools +pytest-black>=0.3.7 +pytest-mypy>=0.9.1 diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/top_level.txt b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/top_level.txt new file mode 100644 index 0000000..e9148ae --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/EGG-INFO/top_level.txt @@ -0,0 +1 @@ +ptr diff --git a/.eggs/pytest_runner-6.0.1-py3.8.egg/ptr/__init__.py b/.eggs/pytest_runner-6.0.1-py3.8.egg/ptr/__init__.py new file mode 100644 index 0000000..41192fa --- /dev/null +++ b/.eggs/pytest_runner-6.0.1-py3.8.egg/ptr/__init__.py @@ -0,0 +1,216 @@ +""" +Implementation +""" + +import os as _os +import shlex as _shlex +import contextlib as _contextlib +import sys as _sys +import operator as _operator +import itertools as _itertools +import warnings as _warnings + +import pkg_resources +import setuptools.command.test as orig +from setuptools import Distribution + + +@_contextlib.contextmanager +def _save_argv(repl=None): + saved = _sys.argv[:] + if repl is not None: + _sys.argv[:] = repl + try: + yield saved + finally: + _sys.argv[:] = saved + + +class CustomizedDist(Distribution): + + allow_hosts = None + index_url = None + + def fetch_build_egg(self, req): + """Specialized version of Distribution.fetch_build_egg + that respects respects allow_hosts and index_url.""" + from setuptools.command.easy_install import easy_install + + dist = Distribution({'script_args': ['easy_install']}) + dist.parse_config_files() + opts = dist.get_option_dict('easy_install') + keep = ( + 'find_links', + 'site_dirs', + 'index_url', + 'optimize', + 'site_dirs', + 'allow_hosts', + ) + for key in list(opts): + if key not in keep: + del opts[key] # don't use any other settings + if self.dependency_links: + links = self.dependency_links[:] + if 'find_links' in opts: + links = opts['find_links'][1].split() + links + opts['find_links'] = ('setup', links) + if self.allow_hosts: + opts['allow_hosts'] = ('test', self.allow_hosts) + if self.index_url: + 
opts['index_url'] = ('test', self.index_url) + install_dir_func = getattr(self, 'get_egg_cache_dir', _os.getcwd) + install_dir = install_dir_func() + cmd = easy_install( + dist, + args=["x"], + install_dir=install_dir, + exclude_scripts=True, + always_copy=False, + build_directory=None, + editable=False, + upgrade=False, + multi_version=True, + no_report=True, + user=False, + ) + cmd.ensure_finalized() + return cmd.easy_install(req) + + +class PyTest(orig.test): + """ + >>> import setuptools + >>> dist = setuptools.Distribution() + >>> cmd = PyTest(dist) + """ + + user_options = [ + ('extras', None, "Install (all) setuptools extras when running tests"), + ( + 'index-url=', + None, + "Specify an index url from which to retrieve dependencies", + ), + ( + 'allow-hosts=', + None, + "Whitelist of comma-separated hosts to allow " + "when retrieving dependencies", + ), + ( + 'addopts=', + None, + "Additional options to be passed verbatim to the pytest runner", + ), + ] + + def initialize_options(self): + self.extras = False + self.index_url = None + self.allow_hosts = None + self.addopts = [] + self.ensure_setuptools_version() + + @staticmethod + def ensure_setuptools_version(): + """ + Due to the fact that pytest-runner is often required (via + setup-requires directive) by toolchains that never invoke + it (i.e. they're only installing the package, not testing it), + instead of declaring the dependency in the package + metadata, assert the requirement at run time. + """ + pkg_resources.require('setuptools>=27.3') + + def finalize_options(self): + if self.addopts: + self.addopts = _shlex.split(self.addopts) + + @staticmethod + def marker_passes(marker): + """ + Given an environment marker, return True if the marker is valid + and matches this environment. 
+ """ + return ( + not marker + or not pkg_resources.invalid_marker(marker) + and pkg_resources.evaluate_marker(marker) + ) + + def install_dists(self, dist): + """ + Extend install_dists to include extras support + """ + return _itertools.chain( + orig.test.install_dists(dist), self.install_extra_dists(dist) + ) + + def install_extra_dists(self, dist): + """ + Install extras that are indicated by markers or + install all extras if '--extras' is indicated. + """ + extras_require = dist.extras_require or {} + + spec_extras = ( + (spec.partition(':'), reqs) for spec, reqs in extras_require.items() + ) + matching_extras = ( + reqs + for (name, sep, marker), reqs in spec_extras + # include unnamed extras or all if self.extras indicated + if (not name or self.extras) + # never include extras that fail to pass marker eval + and self.marker_passes(marker) + ) + results = list(map(dist.fetch_build_eggs, matching_extras)) + return _itertools.chain.from_iterable(results) + + @staticmethod + def _warn_old_setuptools(): + msg = ( + "pytest-runner will stop working on this version of setuptools; " + "please upgrade to setuptools 30.4 or later or pin to " + "pytest-runner < 5." + ) + ver_str = pkg_resources.get_distribution('setuptools').version + ver = pkg_resources.parse_version(ver_str) + if ver < pkg_resources.parse_version('30.4'): + _warnings.warn(msg) + + def run(self): + """ + Override run to ensure requirements are available in this session (but + don't install them anywhere). 
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class TimeSeriesDataset(Dataset):
    """Sliding-window forecasting dataset built from a folder of CSV files.

    Every ``*.csv`` file in ``data_folder`` is read with its first column as
    the index; the remaining columns become one 2-D array per file. A sample
    is a pair of consecutive row windows taken from a single file:
    ``context_length`` rows followed by the next ``forecast_length`` rows.
    Windows never span two files.

    Args:
        data_folder: Directory that is scanned (non-recursively) for CSV files.
        context_length: Number of rows in the context (input) window.
        forecast_length: Number of rows in the forecast (target) window.
    """

    def __init__(self, data_folder, context_length, forecast_length):
        self.context_length = context_length
        self.forecast_length = forecast_length
        self.data_files = self.load_csv_files(data_folder)
        self.cumulative_lengths = self.get_cumulative_lengths(self.data_files)

    def load_csv_files(self, folder):
        """Return one array per ``*.csv`` file in ``folder``, in sorted name order."""
        # os.listdir() order is arbitrary; sorting keeps the index -> sample
        # mapping identical across runs and machines.
        csv_names = sorted(name for name in os.listdir(folder) if name.endswith('.csv'))
        # Treat the first column as index
        return [pd.read_csv(os.path.join(folder, name), index_col=0).values for name in csv_names]

    def get_cumulative_lengths(self, data_files):
        """Return window-count offsets: entry ``i`` is the first sample index of file ``i``.

        The returned array has ``len(data_files) + 1`` entries; the last one is
        the total number of samples.
        """
        window = self.context_length + self.forecast_length
        # A file with n rows holds n - window + 1 complete windows. Files that
        # are shorter than one window contribute zero samples instead of a
        # negative count, which would make the offsets non-monotonic.
        counts = [max(len(rows) - window + 1, 0) for rows in data_files]
        return np.cumsum([0] + counts)

    def __len__(self):
        # Convert the NumPy scalar so callers always get a plain int.
        return int(self.cumulative_lengths[-1])

    def __getitem__(self, idx):
        """Return ``(context, forecast)`` float tensors for sample ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"sample index out of range for dataset of length {total}")

        # Last file whose starting offset is <= idx. side='right' skips over
        # files that contribute no windows (repeated offsets).
        file_index = int(np.searchsorted(self.cumulative_lengths, idx, side='right')) - 1
        start = int(idx - self.cumulative_lengths[file_index])
        split = start + self.context_length

        rows = self.data_files[file_index]
        context = rows[start:split]
        forecast = rows[split:split + self.forecast_length]
        return torch.tensor(context, dtype=torch.float), torch.tensor(forecast, dtype=torch.float)


# The example below only runs when this file is executed directly. train.py
# imports TimeSeriesDataset from this module, and importing must not load data
# from hard-coded paths or iterate a DataLoader.
if __name__ == "__main__":
    # Create DataLoaders for each dataset
    train_dataset = TimeSeriesDataset(data_folder='./data/train', context_length=10, forecast_length=10)
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    test_dataset = TimeSeriesDataset(data_folder='./data/test', context_length=10, forecast_length=10)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    oos_dataset = TimeSeriesDataset(data_folder='./data/oos', context_length=10, forecast_length=10)
    oos_dataloader = DataLoader(oos_dataset, batch_size=32, shuffle=False)

    # Example of iterating over a DataLoader
    for context, forecast in train_dataloader:
        # Model training code here
        pass
--git a/spacetimeformer/images/MNIST/raw/t10k-labels-idx1-ubyte.gz b/spacetimeformer/images/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..a7e1415 Binary files /dev/null and b/spacetimeformer/images/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte b/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte differ diff --git a/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte.gz b/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/spacetimeformer/images/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte b/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte.gz b/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/spacetimeformer/images/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/spacetimeformer/train.py b/spacetimeformer/train.py index b26b714..9ebce9e 100644 --- a/spacetimeformer/train.py +++ b/spacetimeformer/train.py @@ -9,6 +9,8 @@ import torch import spacetimeformer as stf +from TimeSeriesDataset import TimeSeriesDataset +from torch.utils.data import DataLoader _MODELS = ["spacetimeformer", "mtgnn", "heuristic", "lstm", "lstnet", "linear", "s4"] @@ -16,6 +18,7 @@ "asos", "metr-la", "pems-bay", + "stocks", "exchange", "precip", "toy2", @@ -140,6 +143,10 @@ def create_model(config): x_dim = 6 yc_dim = 137 yt_dim = 137 + elif config.dset == "stocks": + x_dim = 95 + yc_dim = 95 # Can reduce to specific 
features. i.e you could forecast only 'Close' (yc_dim=1) + yt_dim = 95 elif config.dset == "exchange": x_dim = 6 yc_dim = 8 @@ -640,6 +647,43 @@ def create_dset(config): "New Zealand", "Singapore", ] + + elif config.dset == "stocks": + if config.phase == "train": + data_path = './data/train' + elif config.phase == "test": + data_path = './data/test' + else: + data_path = './data/oos' # Assume 'oos' is for out-of-sample or validation + + # data_module = TimeSeriesDataset(data_folder=data_path, context_length=config.context_points, forecast_length=config.target_points) + # data_loader = DataLoader(data_module, batch_size=config.batch_size, shuffle=True) + data_module = TimeSeriesDataset(data_folder=data_path, context_length=config.context_points, forecast_length=config.target_points) + data_loader = DataLoader(data_module, batch_size=config.batch_size, shuffle=True if config.phase == "train" else False) + target_cols = ['open', 'high', 'low', 'Close', 'vclose', 'vopen', 'vhigh', 'vlow', + 'VIX', 'SPY', 'TNX', 'rsi14', 'rsi9', 'rsi24', 'MACD5355macddiff', + 'MACD5355macddiffslope', 'MACD5355macd', 'MACD5355macdslope', + 'MACD5355macdsig', 'MACD5355macdsigslope', 'MACD12269macddiff', + 'MACD12269macddiffslope', 'MACD12269macd', 'MACD12269macdslope', + 'MACD12269macdsig', 'MACD12269macdsigslope', 'lowTail', 'highTail', + 'openTail', 'IntradayBar', 'IntradayRange', 'CloseOverSMA5', + 'CloseOverSMA10', 'CloseOverSMA12', 'CloseOverSMA20', 'CloseOverSMA30', + 'CloseOverSMA65', 'CloseOverSMA50', 'CloseOverSMA100', + 'CloseOverSMA200', 'VolOverSMA5', 'VolOverSMA10', 'VolOverSMA12', + 'VolOverSMA20', 'VolOverSMA30', 'VolOverSMA65', 'VolOverSMA50', + 'VolOverSMA100', 'VolOverSMA200', 'Ret1day', 'Ret4day', 'Ret8day', + 'Ret12day', 'Ret24day', 'Ret72day', 'Ret240day', 'RSC', 'bands_l', + 'bands_u', 'ADX', 'cloudA', 'cloudB', 'closeVsIchA', 'closeVsIchB', + 'IchAvIchB', 'CondVol_1', 'CondVol_4', 'CondVol_8', 'CondVol_12', + 'CondVol_24', 'CondVol_72', 'CondVol_240', 'CV1vCV4', 
'CV4vCV8', + 'CV8vCV12', 'CV12vCV24', 'CV8vCV24', 'CV24vCV240', 'RSC_VIX', + 'RSC_VIX_IV', 'RSC_VIX_real', 'RSC_VIX_IV_real', 'RSC_IV_gar', + 'close_spy_corr22', 'close_tnx_corr22', 'vclose_VIX_corr22', + 'garch_IV_corr22', 'close_spy_corr65', 'close_tnx_corr65', + 'vclose_VIX_corr65', 'garch_IV_corr65', 'close_spy_corr252', + 'close_tnx_corr252', 'vclose_VIX_corr252', 'garch_IV_corr252'] + + elif config.dset == "traffic": if data_path == "auto": data_path = "./data/traffic.csv" @@ -671,17 +715,28 @@ def create_dset(config): INV_SCALER = dset.reverse_scaling SCALER = dset.apply_scaling NULL_VAL = None + if config.dset =='stocks': - return ( - DATA_MODULE, - INV_SCALER, - SCALER, - NULL_VAL, - PLOT_VAR_IDXS, - PLOT_VAR_NAMES, - PAD_VAL, - ) - + return ( + data_loader, + INV_SCALER, + SCALER, + NULL_VAL, + PLOT_VAR_IDXS, + PLOT_VAR_NAMES, + PAD_VAL, + ) + else: + return ( + DATA_MODULE, + INV_SCALER, + SCALER, + NULL_VAL, + PLOT_VAR_IDXS, + PLOT_VAR_NAMES, + PAD_VAL, + ) +# data_loader def create_callbacks(config, save_dir): filename = f"{config.run_name}_" + str(uuid.uuid1()).split("-")[0] @@ -728,142 +783,80 @@ def create_callbacks(config, save_dir): def main(args): - log_dir = os.getenv("STF_LOG_DIR") - if log_dir is None: - log_dir = "./data/STF_LOG_DIR" - print( - "Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable `STF_LOG_DIR`" - ) + # Initialization and Setup + log_dir = os.getenv("STF_LOG_DIR", "./data/STF_LOG_DIR") if not os.path.exists(log_dir): os.makedirs(log_dir) if args.wandb: import wandb - project = os.getenv("STF_WANDB_PROJ") entity = os.getenv("STF_WANDB_ACCT") - assert ( - project is not None and entity is not None - ), "Please set environment variables `STF_WANDB_ACCT` and `STF_WANDB_PROJ` with \n\ - your wandb user/organization name and project title, respectively." 
- experiment = wandb.init( - project=project, - entity=entity, - config=args, - dir=log_dir, - reinit=True, - ) + experiment = wandb.init(project=project, entity=entity, config=args, dir=log_dir, reinit=True) config = wandb.config wandb.run.name = args.run_name wandb.run.save() - logger = pl.loggers.WandbLogger( - experiment=experiment, - save_dir=log_dir, - ) + logger = pl.loggers.WandbLogger(experiment=experiment, save_dir=log_dir) + + # Data Preparation + if args.dset == "stocks": + # Custom DataLoader for 'stocks' + train_loader = DataLoader(TimeSeriesDataset(data_folder='./data/train', context_length=args.context_points, forecast_length=args.target_points), batch_size=args.batch_size, shuffle=True) + test_loader = DataLoader(TimeSeriesDataset(data_folder='./data/test', context_length=args.context_points, forecast_length=args.target_points), batch_size=args.batch_size, shuffle=False) + oos_loader = DataLoader(TimeSeriesDataset(data_folder='./data/oos', context_length=args.context_points, forecast_length=args.target_points), batch_size=args.batch_size, shuffle=False) + else: + # Standard DataModule for other datasets + data_module, inv_scaler, scaler, null_val, plot_var_idxs, plot_var_names, pad_val = create_dset(args) - # Dset - ( - data_module, - inv_scaler, - scaler, - null_val, - plot_var_idxs, - plot_var_names, - pad_val, - ) = create_dset(args) - - # Model - args.null_value = null_val - args.pad_value = pad_val + # Model Training and Evaluation forecaster = create_model(args) - forecaster.set_inv_scaler(inv_scaler) - forecaster.set_scaler(scaler) - forecaster.set_null_value(null_val) - # Callbacks - callbacks = create_callbacks(args, save_dir=log_dir) - test_samples = next(iter(data_module.test_dataloader())) + if args.dset == "stocks": + # Custom Training Loop for 'stocks' + for epoch in range(args.epochs): + 1+1 + # Training Phase + # Include your training logic here using train_loader - if args.wandb and args.plot: - callbacks.append( - 
stf.plot.PredictionPlotterCallback( - test_samples, - var_idxs=plot_var_idxs, - var_names=plot_var_names, - pad_val=pad_val, - total_samples=min(args.plot_samples, args.batch_size), - ) - ) + # Validation Phase (Optional) + # Include your validation logic here using test_loader - if args.wandb and args.dset in ["mnist", "cifar"] and args.plot: - callbacks.append( - stf.plot.ImageCompletionCallback( - test_samples, - total_samples=min(16, args.batch_size), - mode="left-right" if config.dset == "mnist" else "flat", - ) - ) - - if args.wandb and args.dset == "copy" and args.plot: - callbacks.append( - stf.plot.CopyTaskCallback( - test_samples, - total_samples=min(16, args.batch_size), - ) - ) - - if args.wandb and args.model == "spacetimeformer" and args.attn_plot: - - callbacks.append( - stf.plot.AttentionMatrixCallback( - test_samples, - layer=0, - total_samples=min(16, args.batch_size), - ) - ) - - if args.wandb: - config.update(args) - logger.log_hyperparams(config) - - if args.val_check_interval <= 1.0: - val_control = {"val_check_interval": args.val_check_interval} + # Out-of-Sample Testing Phase (Optional) + # Include your testing logic here using oos_loader else: - val_control = {"check_val_every_n_epoch": int(args.val_check_interval)} - - trainer = pl.Trainer( - gpus=args.gpus, - callbacks=callbacks, - logger=logger if args.wandb else None, - accelerator="dp", - gradient_clip_val=args.grad_clip_norm, - gradient_clip_algorithm="norm", - overfit_batches=20 if args.debug else 0, - accumulate_grad_batches=args.accumulate, - sync_batchnorm=True, - limit_val_batches=args.limit_val_batches, - **val_control, - ) + # Standard Training with PyTorch Lightning for other datasets + forecaster.set_inv_scaler(inv_scaler) + forecaster.set_scaler(scaler) + forecaster.set_null_value(null_val) - # Train - trainer.fit(forecaster, datamodule=data_module) + # Callbacks and Trainer Configuration + callbacks = create_callbacks(args, save_dir=log_dir) + trainer = pl.Trainer( + 
gpus=args.gpus, + callbacks=callbacks, + logger=logger if args.wandb else None, + # ... additional trainer configurations ... + ) - # Test - trainer.test(datamodule=data_module, ckpt_path="best") + # Fitting the model + trainer.fit(forecaster, datamodule=data_module) - # Predict (only here as a demo and test) - # forecaster.to("cuda") - # xc, yc, xt, _ = test_samples - # yt_pred = forecaster.predict(xc, yc, xt) + # Testing the model + trainer.test(datamodule=data_module, ckpt_path="best") + # WANDB Experiment Finish (if applicable) if args.wandb: - experiment.finish() - + wandb.finish() if __name__ == "__main__": - # CLI parser = create_parser() args = parser.parse_args() + main(args) + +# if __name__ == "__main__": +# # CLI +# parser = create_parser() +# args = parser.parse_args() - for trial in range(args.trials): - main(args) +# for trial in range(args.trials): +# main(args) diff --git a/spacetimeformer/workbench.ipynb b/spacetimeformer/workbench.ipynb new file mode 100644 index 0000000..f0cc580 --- /dev/null +++ b/spacetimeformer/workbench.ipynb @@ -0,0 +1,44 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Users/alecjeffery/Documents/Playgrounds/Python/spacetimeformer\n", + "Python 3.8.18\n" + ] + } + ], + "source": [ + "!pwd\n", + "!python --version" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spacetimeformer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}