diff --git a/README.md b/README.md
index 3588f14..1b3c81d 100644
--- a/README.md
+++ b/README.md
@@ -23,12 +23,7 @@ and to the video presentation [here](https://www.youtube.com/watch?v=Q7Q9o7ywXx8
 
 ## Install TimeSHAP
 
-##### Via Pip
-```
-pip install timeshap
-```
-
-##### Via Github
+#### [Recommended] Via Github
 Clone the repository into a local directory using:
 ```
 git clone https://github.com/feedzai/timeshap.git
@@ -41,8 +36,12 @@ cd timeshap
 pip install .
 ```
 
+#### Via Pip
+```
+pip install timeshap
+```
 
-##### Test your installation
+#### Test your installation
 Start a Python session in your terminal using
 
 ```
diff --git a/notebooks/AReM/AReM.ipynb b/notebooks/AReM/AReM.ipynb
index 273dbab..db2d3bf 100644
--- a/notebooks/AReM/AReM.ipynb
+++ b/notebooks/AReM/AReM.ipynb
@@ -69,7 +69,7 @@
    {
     "data": {
      "text/plain": [
-      "'1.0.2'"
+      "'1.1.0'"
      ]
     },
     "execution_count": 2,
@@ -124,80 +124,80 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "sitting/dataset5.csv ------ (480, 7)\n",
-     "sitting/dataset11.csv ------ (480, 7)\n",
-     "sitting/dataset3.csv ------ (480, 7)\n",
-     "sitting/dataset10.csv ------ (480, 7)\n",
-     "sitting/dataset12.csv ------ (480, 7)\n",
-     "sitting/dataset4.csv ------ (480, 7)\n",
-     "sitting/dataset13.csv ------ (480, 7)\n",
-     "sitting/dataset9.csv ------ (480, 7)\n",
-     "sitting/dataset2.csv ------ (480, 7)\n",
-     "sitting/dataset15.csv ------ (480, 7)\n",
-     "sitting/dataset1.csv ------ (480, 7)\n",
-     "sitting/dataset14.csv ------ (480, 7)\n",
-     "sitting/dataset6.csv ------ (480, 7)\n",
-     "sitting/dataset7.csv ------ (480, 7)\n",
-     "standing/dataset5.csv ------ (480, 7)\n",
-     "standing/dataset11.csv ------ (480, 7)\n",
-     "standing/dataset3.csv ------ (480, 7)\n",
-     "standing/dataset8.csv ------ (480, 7)\n",
-     "standing/dataset10.csv ------ (480, 7)\n",
-     "standing/dataset12.csv ------ (480, 7)\n",
-     "standing/dataset4.csv ------ (480, 7)\n",
-     "standing/dataset13.csv ------ (480, 7)\n",
-     "standing/dataset9.csv ------ (480, 7)\n",
-     "standing/dataset2.csv ------ (480, 7)\n",
-     "standing/dataset15.csv ------ (480, 7)\n",
-     "standing/dataset1.csv ------ (480, 7)\n",
-     "standing/dataset14.csv ------ (480, 7)\n",
-     "standing/dataset6.csv ------ (480, 7)\n",
-     "standing/dataset7.csv ------ (480, 7)\n",
+     "walking/dataset7.csv ------ (480, 7)\n",
+     "walking/dataset6.csv ------ (480, 7)\n",
+     "walking/dataset4.csv ------ (480, 7)\n",
      "walking/dataset5.csv ------ (480, 7)\n",
-     "walking/dataset11.csv ------ (480, 7)\n",
+     "walking/dataset1.csv ------ (480, 7)\n",
+     "walking/dataset2.csv ------ (480, 7)\n",
      "walking/dataset3.csv ------ (480, 7)\n",
-     "walking/dataset8.csv ------ (480, 7)\n",
      "walking/dataset10.csv ------ (480, 7)\n",
-     "walking/dataset12.csv ------ (480, 7)\n",
-     "walking/dataset4.csv ------ (480, 7)\n",
+     "walking/dataset11.csv ------ (480, 7)\n",
      "walking/dataset13.csv ------ (480, 7)\n",
-     "walking/dataset9.csv ------ (480, 7)\n",
-     "walking/dataset2.csv ------ (480, 7)\n",
+     "walking/dataset12.csv ------ (480, 7)\n",
      "walking/dataset15.csv ------ (480, 7)\n",
-     "walking/dataset1.csv ------ (480, 7)\n",
      "walking/dataset14.csv ------ (480, 7)\n",
-     "walking/dataset6.csv ------ (480, 7)\n",
-     "walking/dataset7.csv ------ (480, 7)\n",
-     "cycling/dataset5.csv ------ (480, 7)\n",
-     "cycling/dataset11.csv ------ (480, 7)\n",
-     "cycling/dataset3.csv ------ (480, 7)\n",
-     "cycling/dataset8.csv ------ (480, 7)\n",
-     "cycling/dataset10.csv ------ (480, 7)\n",
-     "cycling/dataset12.csv ------ (480, 7)\n",
-     "cycling/dataset4.csv ------ (480, 7)\n",
"cycling/dataset13.csv ------ (480, 7)\n", - "cycling/dataset9.csv ------ (480, 7)\n", - "cycling/dataset2.csv ------ (480, 7)\n", - "cycling/dataset15.csv ------ (480, 7)\n", - "cycling/dataset1.csv ------ (480, 7)\n", - "cycling/dataset14.csv ------ (480, 7)\n", - "cycling/dataset6.csv ------ (480, 7)\n", - "cycling/dataset7.csv ------ (480, 7)\n", + "walking/dataset8.csv ------ (480, 7)\n", + "walking/dataset9.csv ------ (480, 7)\n", + "standing/dataset7.csv ------ (480, 7)\n", + "standing/dataset6.csv ------ (480, 7)\n", + "standing/dataset4.csv ------ (480, 7)\n", + "standing/dataset5.csv ------ (480, 7)\n", + "standing/dataset1.csv ------ (480, 7)\n", + "standing/dataset2.csv ------ (480, 7)\n", + "standing/dataset3.csv ------ (480, 7)\n", + "standing/dataset10.csv ------ (480, 7)\n", + "standing/dataset11.csv ------ (480, 7)\n", + "standing/dataset13.csv ------ (480, 7)\n", + "standing/dataset12.csv ------ (480, 7)\n", + "standing/dataset15.csv ------ (480, 7)\n", + "standing/dataset14.csv ------ (480, 7)\n", + "standing/dataset8.csv ------ (480, 7)\n", + "standing/dataset9.csv ------ (480, 7)\n", + "sitting/dataset7.csv ------ (480, 7)\n", + "sitting/dataset6.csv ------ (480, 7)\n", + "sitting/dataset4.csv ------ (480, 7)\n", + "sitting/dataset5.csv ------ (480, 7)\n", + "sitting/dataset1.csv ------ (480, 7)\n", + "sitting/dataset2.csv ------ (480, 7)\n", + "sitting/dataset3.csv ------ (480, 7)\n", + "sitting/dataset10.csv ------ (480, 7)\n", + "sitting/dataset11.csv ------ (480, 7)\n", + "sitting/dataset13.csv ------ (480, 7)\n", + "sitting/dataset12.csv ------ (480, 7)\n", + "sitting/dataset15.csv ------ (480, 7)\n", + "sitting/dataset14.csv ------ (480, 7)\n", + "sitting/dataset9.csv ------ (480, 7)\n", + "lying/dataset7.csv ------ (480, 7)\n", + "lying/dataset6.csv ------ (480, 7)\n", + "lying/dataset4.csv ------ (480, 7)\n", "lying/dataset5.csv ------ (480, 7)\n", - "lying/dataset11.csv ------ (480, 7)\n", + "lying/dataset1.csv ------ (480, 7)\n", + "lying/dataset2.csv ------ (480, 7)\n", "lying/dataset3.csv ------ (480, 7)\n", - "lying/dataset8.csv ------ (480, 7)\n", "lying/dataset10.csv ------ (480, 7)\n", - "lying/dataset12.csv ------ (480, 7)\n", - "lying/dataset4.csv ------ (480, 7)\n", + "lying/dataset11.csv ------ (480, 7)\n", "lying/dataset13.csv ------ (480, 7)\n", - "lying/dataset9.csv ------ (480, 7)\n", - "lying/dataset2.csv ------ (480, 7)\n", + "lying/dataset12.csv ------ (480, 7)\n", "lying/dataset15.csv ------ (480, 7)\n", - "lying/dataset1.csv ------ (480, 7)\n", "lying/dataset14.csv ------ (480, 7)\n", - "lying/dataset6.csv ------ (480, 7)\n", - "lying/dataset7.csv ------ (480, 7)\n" + "lying/dataset8.csv ------ (480, 7)\n", + "lying/dataset9.csv ------ (480, 7)\n", + "cycling/dataset7.csv ------ (480, 7)\n", + "cycling/dataset6.csv ------ (480, 7)\n", + "cycling/dataset4.csv ------ (480, 7)\n", + "cycling/dataset5.csv ------ (480, 7)\n", + "cycling/dataset1.csv ------ (480, 7)\n", + "cycling/dataset2.csv ------ (480, 7)\n", + "cycling/dataset3.csv ------ (480, 7)\n", + "cycling/dataset10.csv ------ (480, 7)\n", + "cycling/dataset11.csv ------ (480, 7)\n", + "cycling/dataset13.csv ------ (480, 7)\n", + "cycling/dataset12.csv ------ (480, 7)\n", + "cycling/dataset15.csv ------ (480, 7)\n", + "cycling/dataset14.csv ------ (480, 7)\n", + "cycling/dataset8.csv ------ (480, 7)\n", + "cycling/dataset9.csv ------ (480, 7)\n" ] } ], @@ -524,125 +524,67 @@ "name": "stderr", "output_type": "stream", "text": [ - " 12%|█▎ | 1/8 [00:02<00:15, 2.23s/it]" - ] - }, - { - 
"name": "stdout", - "output_type": "stream", - "text": [ - "Train loss: 0.6997937560081482 --- Test loss 0.5927574634552002 \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - " 25%|██▌ | 2/8 [00:04<00:12, 2.01s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train loss: 0.5918577313423157 --- Test loss 0.4946227967739105 \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - " 38%|███▊ | 3/8 [00:05<00:08, 1.72s/it]" + " 25%|██▌ | 2/8 [00:00<00:01, 5.63it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train loss: 0.49630260467529297 --- Test loss 0.40795645117759705 \n" + "Train loss: 0.7000859975814819 --- Test loss 0.5926060080528259 \n", + "Train loss: 0.593891441822052 --- Test loss 0.4921618103981018 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\r", - " 50%|█████ | 4/8 [00:07<00:06, 1.69s/it]" + " 50%|█████ | 4/8 [00:00<00:00, 5.78it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train loss: 0.4160643219947815 --- Test loss 0.343948632478714 \n" + "Train loss: 0.4992573857307434 --- Test loss 0.40382787585258484 \n", + "Train loss: 0.41805166006088257 --- Test loss 0.34258148074150085 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\r", - " 62%|██████▎ | 5/8 [00:10<00:06, 2.16s/it]" + " 75%|███████▌ | 6/8 [00:01<00:00, 5.82it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train loss: 0.3660268187522888 --- Test loss 0.2917708158493042 \n" + "Train loss: 0.3631644546985626 --- Test loss 0.30080685019493103 \n", + "Train loss: 0.32556360960006714 --- Test loss 0.260188490152359 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\r", - " 75%|███████▌ | 6/8 [00:12<00:04, 2.37s/it]" + "100%|██████████| 8/8 [00:01<00:00, 5.51it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Train loss: 0.334475576877594 --- Test loss 0.2376544028520584 \n" + "Train loss: 0.28500068187713623 --- Test loss 0.22024422883987427 \n", + "Train loss: 0.23866622149944305 --- Test loss 0.18811321258544922 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\r", - " 88%|████████▊ | 7/8 [00:16<00:02, 2.69s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train loss: 0.29900506138801575 --- Test loss 0.1931639164686203 \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 8/8 [00:19<00:00, 2.43s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train loss: 0.25920480489730835 --- Test loss 0.17529389262199402 \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "100%|██████████| 8/8 [00:01<00:00, 5.64it/s]\n" ] } ], @@ -764,12 +706,12 @@ " \n", " \n", " 0\n", - " 0.095791\n", - " -0.531224\n", - " 0.178574\n", - " -0.416398\n", - " 0.129324\n", - " -0.392722\n", + " 0.084481\n", + " -0.520933\n", + " 0.143288\n", + " -0.410888\n", + " 0.108927\n", + " -0.388381\n", " \n", " \n", "\n", @@ -777,10 +719,10 @@ ], "text/plain": [ " p_avg_rss12_normalized p_var_rss12_normalized p_avg_rss13_normalized \\\n", - "0 0.095791 -0.531224 0.178574 \n", + "0 0.084481 -0.520933 0.143288 \n", "\n", " p_var_rss13_normalized p_avg_rss23_normalized p_var_rss23_normalized \n", - "0 -0.416398 0.129324 -0.392722 " + "0 -0.410888 0.108927 -0.388381 " ] }, "execution_count": 18, @@ -820,19 +762,19 @@ { "data": { "text/plain": [ - "array([[0.0000e+00, 4.0875e+01, 
5.0000e-01, 1.4625e+01, 1.0500e+00,\n", - " 1.5250e+01],\n", - " [2.5000e+02, 4.0585e+01, 6.0500e-01, 1.4875e+01, 7.9000e-01,\n", - " 1.4875e+01],\n", - " [5.0000e+02, 3.9835e+01, 7.5500e-01, 1.4125e+01, 8.7000e-01,\n", - " 1.7000e+01],\n", + "array([[0.0000e+00, 4.0500e+01, 5.0000e-01, 1.4500e+01, 9.4000e-01,\n", + " 1.5500e+01],\n", + " [2.5000e+02, 4.0000e+01, 8.3000e-01, 1.5000e+01, 9.4000e-01,\n", + " 1.5000e+01],\n", + " [5.0000e+02, 4.0000e+01, 7.1000e-01, 1.5000e+01, 8.7000e-01,\n", + " 1.6500e+01],\n", " ...,\n", - " [1.1925e+05, 3.9625e+01, 6.6000e-01, 1.4400e+01, 9.7000e-01,\n", - " 1.4165e+01],\n", - " [1.1950e+05, 3.9500e+01, 5.0000e-01, 1.2710e+01, 1.2050e+00,\n", + " [1.1925e+05, 4.0000e+01, 8.3000e-01, 1.5330e+01, 9.4000e-01,\n", + " 1.4000e+01],\n", + " [1.1950e+05, 3.9500e+01, 5.0000e-01, 1.4000e+01, 1.2200e+00,\n", " 1.5000e+01],\n", - " [1.1975e+05, 3.9625e+01, 6.6000e-01, 1.4000e+01, 1.0150e+00,\n", - " 1.4750e+01]])" + " [1.1975e+05, 3.9750e+01, 5.0000e-01, 1.4500e+01, 1.0900e+00,\n", + " 1.4250e+01]])" ] }, "execution_count": 20, @@ -924,29 +866,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1022,6 +969,9 @@ "name": "stdout", "output_type": "stream", "text": [ + "The defined path for pruning data already exists and the append option is turned off. TimeSHAP will only read from this file and will not create new explanation data\n", + "The defined path for event explanations already exists and the append option is turned off. TimeSHAP will only read from this file and will not create new explanation data\n", + "The defined path for feature explanations already exists and the append option is turned off. TimeSHAP will only read from this file and will not create new explanation data\n", "Calculating pruning algorithm\n", "Calculating event data\n", "Calculating feat data\n", @@ -1058,14 +1008,14 @@ " \n", " 0\n", " 0.05\n", - " 12.75\n", - " 1.258306\n", + " 14.50\n", + " 3.415650\n", " \n", " \n", " 1\n", " 0.075\n", - " 11.25\n", - " 1.258306\n", + " 12.75\n", + " 2.629956\n", " \n", " \n", " 2\n", @@ -1079,8 +1029,8 @@ ], "text/plain": [ " Tolerance Mean Std\n", - "0 0.05 12.75 1.258306\n", - "1 0.075 11.25 1.258306\n", + "0 0.05 14.50 3.415650\n", + "1 0.075 12.75 2.629956\n", "2 No Pruning 480.00 0.000000" ] }, @@ -1111,29 +1061,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1223,29 +1178,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1309,29 +1269,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1393,29 +1358,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1477,29 +1447,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1609,14 +1584,14 @@ " \n", " 0\n", " 0.05\n", - " 12.75\n", - " 1.258306\n", + " 14.50\n", + " 3.415650\n", " \n", " \n", " 1\n", " 0.075\n", - " 11.25\n", - " 1.258306\n", + " 12.75\n", + " 2.629956\n", " \n", " \n", " 2\n", @@ -1630,8 +1605,8 @@ ], "text/plain": [ " Tolerance Mean Std\n", - "0 0.05 12.75 1.258306\n", - "1 0.075 11.25 1.258306\n", + "0 0.05 14.50 3.415650\n", + "1 0.075 12.75 2.629956\n", "2 No Pruning 480.00 0.000000" ] }, @@ -1665,29 +1640,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1749,29 +1729,36 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ @@ -1822,14 +1807,6 @@ "metadata": {}, "outputs": [], "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "99559082", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -1848,7 +1825,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.8.19" } }, "nbformat": 4, diff --git a/requirements/main.txt b/requirements/main.txt index 50ba133..6fd6cfb 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -1,9 +1,9 @@ -pandas>=1.3.2 +pandas>=1.3.2,<=1.5.3 scikit-learn>=0.23.2 seaborn>=0.11.1 matplotlib>=3.3.2 numpy>=1.19.2 -shap>=0.37.0 +shap>=0.37.0,<=0.42.1 scipy>=1.5.2 plotly>=4.6 seaborn>=0.11.2 diff --git a/setup.py b/setup.py index 7224192..ee1c124 100644 --- a/setup.py +++ b/setup.py @@ -46,26 +46,25 @@ def stream_requirements(fd): version=__version__, description="KernelSHAP adaptation for recurrent models.", keywords=['explainability', 'TimeShap'], - long_description=(README_PATH).read_text(), long_description_content_type="text/markdown", - author="Feedzai", url="https://github.com/feedzai/timeshap", - package_dir={'': 'src'}, packages=find_packages('src', exclude=['tests', 'tests.*']), package_data={ '': ['*.yaml, *.yml'], }, include_package_data=True, - python_requires='>=3.6', - install_requires=requirements, - zip_safe=False, - test_suite='tests', tests_require=requirements_test, + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + ], ) diff --git a/src/timeshap/version.py b/src/timeshap/version.py index 551c4c4..ee4eb0d 100644 --- a/src/timeshap/version.py +++ b/src/timeshap/version.py @@ -1,4 +1,3 @@ """File to keep the package version in one place""" -__version__ = "1.0.4" +__version__ = "1.1.0" __version_info__ = tuple(__version__.split(".")) -