Skip to content

Commit

Permalink
Simplify adding new data distributions (#12)
Browse files Browse the repository at this point in the history
* more reuse in kernel management

* refactored PC and adapted tests

* fix small bug

* update readme and benchmarks

* Update gpucsl/pc/pc.py

Co-authored-by: PeterTsayun <40639972+PeterTsayun@users.noreply.github.com>

* Update gpucsl/pc/pc.py

Co-authored-by: PeterTsayun <40639972+PeterTsayun@users.noreply.github.com>

* requested changes

* put vaildation into extra file

* update docu

* added examples and adapted documentation

* Update docs/Public-api.md

Co-authored-by: PeterTsayun <40639972+PeterTsayun@users.noreply.github.com>

* Update docs/examples/README.md

Co-authored-by: PeterTsayun <40639972+PeterTsayun@users.noreply.github.com>

* apply linting

* remove pc from docs

Co-authored-by: PeterTsayun <40639972+PeterTsayun@users.noreply.github.com>
  • Loading branch information
BraunTom and PeterTsayun authored Apr 4, 2022
1 parent 110e121 commit 833bca9
Show file tree
Hide file tree
Showing 21 changed files with 480 additions and 319 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Note, that `GPUCSL` provides kernel implementations that cover conditional indep

## <a name="usage"></a> Usage

Linux and a NVIDIA GPU with CUDA are required. We support running on multiple GPUs (experimental; for now, only for Gaussian CI kernel - `DataDistribution.GAUSSIAN`).
Linux and a NVIDIA GPU with CUDA are required. We support running on multiple GPUs (experimental; for now, only for Gaussian CI kernel - `GaussianPC`).

### CLI

Expand All @@ -33,6 +33,7 @@ With the CLI, the PC algorithm is executed on the specified datasets. Three outp
- {dataset}.gml - the resulting CPDAG containing the causal relationships
- {dataset}_pmax.csv - the maximum pvalues used for the conditional independence tests
- {dataset}_sepset.csv - the separation sets for the removed edges
- {dataset}_config.txt - the parameters the CLI got called with

All paths you give to the CLI are relative to your current directory.
An example call for `GPUCSL` with a CI test for multivariate normal or Gaussian distributed data could look like this (assuming your data is in "./data.csv"):
Expand All @@ -45,35 +46,35 @@ python3 -m gpucsl --gaussian -d ./data.csv -o . -l 3

`GPUCSL` provides a python API for:

- `pc` (`DataDistribution.GAUSSIAN`, `DataDistribution.DISCRETE`) - implements the full PC algorithm for discrete and Gaussian data. Outputs the CPDAG from observational data. Similar to the CLI.
- `GaussianPC` - implements the full PC algorithm for multivariate normal data. Outputs the CPDAG from observational data. Similar to the CLI.
- `DiscretePC` -implements the full PC algorithm for discrete data. Outputs the CPDAG from observational data. Similar to the CLI.
- `discover_skeleton_gpu_gaussian` - determines the undirected skeleton graph for gaussian distribution
- `discover_skeleton_gpu_discrete` - determines the undirected skeleton graph for discrete distribution
- `orient_edges` - orients the edges of the undirected skeleton graph by detection of v-structures and application of Meek's orientation rules. Outputs the CPDAG from skeleton.

Additional detail is found in the [API description](https://github.com/hpi-epic/gpucsl/blob/main/docs/Public-api.md).

The following code snippet provides a small example for calling the `pc` function:
The following code snippet provides a small example for using `GaussianPC`:
```python
import numpy as np
from gpucsl.pc.pc import pc, DataDistribution
from gpucsl.pc.pc import GaussianPC

samples = np.random.rand(1000, 10)
max_level = 3
alpha = 0.05
((directed_graph, separation_sets, pmax, discover_skeleton_runtime,
edge_orientation_runtime, discover_skeleton_kernel_runtime),
pc_runtime) = pc(samples,
DataDistribution.GAUSSIAN,
pc_runtime) = GaussianPC(samples,
max_level,
alpha)
alpha).set_distribution_specific_options().execute()

```

Additional usage examples can be found in `benchmarks/benchmark_gpucsl.py`.
Additional usage examples can be found in `docs/examples/`.

### Multi GPU support

Multi GPU support is currently only implemented for the gaussian CI kernel (`DataDistribution.GAUSSIAN`) for skeleton discovery. The adjacency matrix (skeleton) is partitioned horizontally, and each GPU executes the CI tests on the assigned partition. For example, in the case of the dataset with 6 variables and 3 GPUs, the first GPU executes CI tests on edges 0-i, 1-i, where i is in {0..5\} (0-indexing), the second GPU executes CI tests on edges 2-i, 3-i and so on.
Multi GPU support is currently only implemented for the gaussian CI kernel (`GaussianPC`) for skeleton discovery. The adjacency matrix (skeleton) is partitioned horizontally, and each GPU executes the CI tests on the assigned partition. For example, in the case of the dataset with 6 variables and 3 GPUs, the first GPU executes CI tests on edges 0-i, 1-i, where i is in {0..5\} (0-indexing), the second GPU executes CI tests on edges 2-i, 3-i and so on.

In case of an edge being deleted on multiple GPUs in the same level (for example, the edge 1-3 is deleted on the first GPU, the edge 3-1 is deleted on the second GPU in the example above), the separation set with the highest p-value is written to the end result (along with the corresponding p-value).

Expand Down
26 changes: 10 additions & 16 deletions benchmarks/benchmark_gpucsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
import argparse
import csv

from gpucsl.pc.pc import pc
from gpucsl.pc.kernel_management import (
DiscreteKernel,
Kernels,
CompactKernel,
GaussianKernel,
)
from gpucsl.pc.kernel_management import Kernels
import cupy as cp
from timeit import default_timer as timer
from sklearn.preprocessing import OrdinalEncoder

from gpucsl.pc.helpers import correlation_matrix_of
from gpucsl.pc.pc import DataDistribution
from gpucsl.pc.pc import GaussianPC, DiscretePC


SCRIPT_PATH = Path(__file__).parent.resolve()
Expand Down Expand Up @@ -115,16 +109,14 @@ def run_benchmark_gaussian(samples, max_level, devices, sync_device):
max_level, samples.shape[1], devices
)

(pc_result, pc_runtime) = pc(
pc = GaussianPC(
samples,
DataDistribution.GAUSSIAN,
max_level,
0.05,
gaussian_correlation_matrix=correlation_matrix,
kernels=kernels,
devices=devices,
sync_device=sync_device,
)
).set_distribution_specific_options(devices, sync_device, correlation_matrix)

(pc_result, pc_runtime) = pc.execute()

duration_incl_compilation = timer() - start
return (
Expand All @@ -148,8 +140,10 @@ def run_benchmark_discrete(samples, max_level, devices, sync_device):
samples.shape[0],
)

(pc_result, pc_runtime) = pc(
samples, DataDistribution.DISCRETE, max_level, alpha=0.05, kernels=kernels
(pc_result, pc_runtime) = (
DiscretePC(samples, max_level, alpha=0.05, kernels=kernels)
.set_distribution_specific_options()
.execute()
)

duration_incl_compilation = timer() - start
Expand Down
24 changes: 8 additions & 16 deletions docs/How-to-implement-a-new-ci-test.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,16 @@ To write your own Kernel Management you should inherit from the abstract base cl
is_debug: bool = False,
should_log: bool = False,
):
super().__init__(is_debug=is_debug, should_log=should_log)
self.your_choosen_parameters = your_choosen_parameters
# calculate all the kernel function signature names you later on want to be able to call (something needed by CuPy)
# has to be an array, but if you only need one signature you can change this to be an array with one entry
kernel_function_signatures = [
self.kernel_function_signature_for_level(level)
for level in range(0, max_level + 1)
]
# pass the kernel function signatures and compile the CUDA code
self.define_module("gaussian_ci.cu", kernel_function_signatures)
super().__init__(is_debug=is_debug, should_log=should_log)
```
- kernel_function_name_for_level: Return the name of the kernel for the given level. Can return different names for different levels (as you maybe have written
optimizations for a specific level in form of an extra CUDA function, as we did for DiscreteKernel. Otherwise, you can also return just a static name like the CompactKernel)
- template_specification_for_level: If you templated your kernel in the CUDA code CuPy needs the filled-out template to access the instantiated functions. Here you provide a string with the parameters you give the template for the current level.
- grid_and_block_mapping: Defines how your kernel is mapped into grids and blocks
- cuda_file: Return the filename of the cuda file containing the kernel you want to use (the path is relative to the `gpucsl/pc/cuda` directory)
- every_accessable_function_signature: Return an array containing every function signature that should be accessable later on. (You probably want use the _kernel_function_signature_for_level_ method here. It returns the kernel function signature for a level based on your _kernel_function_name_for_level_ and _template_specification_for_level_ methods)

Optional:
- pre_kernel_launch_check: A hook executed before the raw CuPy kernel is accessed and run. Can be used to check if everything is ok before the launch happens.
Expand All @@ -73,7 +66,7 @@ Following is a template you can use for the skeleton discovery. Just copy and ad

```
@timed
def discover_skeleton_gpu_discrete(
def discover_skeleton(
skeleton: np.ndarray,
data: np.ndarray,
alpha: float,
Expand Down Expand Up @@ -131,10 +124,9 @@ def discover_skeleton_gpu_discrete(

Note: make sure the types you initialize your data with on Python side match the data types the CUDA kernel takes. An example in the template is d_skeleton which is initialized as cp.uint16.


After implementing the kernel discovery you can extend the pc function (`gpucsl/pc/pc.py`). First, add your distribution to the DataDistribution enum in the same file. Then in pc test for your DataDistribution and call your skeleton discovery like it is done for _DataDistribution.GAUSSIAN_ or _DataDistribution.DISCRETE_. Depending on whether you need more arguments you have to extend the pc functions argument list.

As long as you return a correct SkeletonResult from your skeleton discovery you do not need to change anything else. The edge orientation should work.
After implementing the kernel discovery you inherit from the abstract _PC_(`gpucsl/pc/pc.py`) class. Now implement the methods:
- _set_distribution_specific_options_ method: Takes the arguments specific to your distribution and saves them as instance variables. You can return the object itself in the end in order to chain the execute call to it easier.
- _discover_skeleton_: Returns the result of the earlier implemented skeleton discovery. You can just pass the locally saved parameters to your earlier implemented _discover_skeleton_ function or implement the complete skeleton discovery here and just use the instance variables directly


## Extend CLI (optional)
Expand All @@ -144,5 +136,5 @@ You mainly have to extend the command line parser (gpucsl/pc/command_line_parser
basically used the argparse package, so please refer to https://docs.python.org/3/library/argparse.html on how to use it.

One mandatory step will be to add a new distribution flag (example for distribution flag: gaussian). Then you have to extend the function _gpucsl_cli_
(`gpucsl/cli/cli.py`). The main points are error checking for new parameters you introduced and passing the sanitized values to _the run_on_dataset_ function.
(`gpucsl/cli/cli.py`). The main points are error checking for new parameters you introduced, instanciate your implemented subclass of _PC_ and pass the sanitized values to the class by creating it with the general parameters the constructor takes and calling the _set_distribution_specific_options_ method with the distribution specific options as arguments.
You do not need to change anything to write the results as done currently.
98 changes: 76 additions & 22 deletions docs/Public-api.md
Original file line number Diff line number Diff line change
@@ -1,38 +1,92 @@
## pc (gpucsl/pc/pc.py)

pc(
data: np.ndarray,
data_distribution: DataDistribution,
max_level: int,
alpha=0.05,
kernels=None,
is_debug: bool = False,
should_log: bool = False,
devices: List[int] = [0],
sync_device: int = None,
gaussian_correlation_matrix: np.ndarray = None,
) -> PCResult

Executes the PC algorithm.

### Parameters
## class PC (gpucsl/pc/pc.py)

The abstract base class to inherit from to add a new data distribution.

### \_\_init\_\_
def \_\_init\_\_(
self,
data: np.ndarray,
max_level: int,
alpha=0.05,
kernels=None,
is_debug: bool = False,
should_log: bool = False
)

#### Parameters
- data: The data to analyze
- data_distribution: Either DataDistribution.DISCRETE or DataDistribution.GAUSSIAN depending on the assumed distribution of the data
- max_level: max level until which the pc algorithm will run (inclusive). Depending on the max level data structures will get allocated on the GPU, so you want to keep it small to avoid out of memory problems
- alpha: Alpha value for the statistical tests
- kernels: You can compile the kernels that should be used yourself and pass them to the function. Used for time measurements where the compile time should be excluded. Leave None and GPUCSL will compile the kernels for you
- is_debug: If set to true kernels will get compiled in debug mode
- should_log: Sets a macro 'LOG' while compiling the CUDA kernels. Can be used for custom logging from kernels

### (abstract) discover_skeleton

Subclasses should implement the skeleton discovery for their respective distribution.

### (abstract) set_distribution_specific_options

Subclasses should get their specific paramters here, validate them and save them as instance variables. As a convenience this method should return the current object so the execute method can get chained to it.

### execute
execute()

#### Returns

The CPDAG that results from causal structure learning on your data, the separation sets, the maximum p values, and time measurements for the skeleton discovery and edge orientation of the pc algorithm, as well as time measurement for the execution of the kernels.

#### Return Value
- PCResult

Executes the pc algorithm. Presuppose you run _set_distribution_specific_options_ before!




<br/><br/>
## class GaussianPC(PC) (gpucsl/pc/pc.py)

Concrete implementation of the PC algorithm for the multivariate normal data distribution.

### set_distribution_specific_options
set_distribution_specific_options(self, devices: List[int] = [0], sync_device: int = None, correlation_matrix: np.ndarray = None)

#### Parameters
- devices: Device IDs of GPUs to be used.
- sync_device: Device ID of the GPU used for state synchronization in the multi GPU case (Notice: sync_device has to be in the devices list!)
gaussian_correlation_matrix: A correlation matrix can be passed so time measurements do not inlcude the calculation. Only possible when using DataDistribution.GAUSSIAN. If None given GPUCSL calculates it itself.
- correlation_matrix: The correlation matrix calculated from data

#### Returns
- Itself for convenience

#### Return Value
- self




<br/><br/>
## class DiscretePC(PC) (gpucsl/pc/pc.py)

Concrete implementation of the PC algorithm for the discrete data distribution.

### set_distribution_specific_options
set_distribution_specific_options(self, memory_restriction=None)

#### Parameters
- memory_restriction: The maximum space to allocate for the working structures. Small values decrease the parallelisation. If None given defaults to 95% of the total available memory on GPU.

#### Returns
- Itself for convenience

#### Return Value
- self

### Returns

The CPDAG that results from causal structure learning on your data, the separation sets, the maximum p values, and time measurements for the skeleton discovery and edge orientation of the pc algorithm, as well as time measurement for the execution of the kernels.

### Return Value
- PCResult

<br/><br/>
## discover_skeleton_gpu_gaussian (gpucsl/pc/discover_skeleton_gaussian.py)
Expand Down
3 changes: 3 additions & 0 deletions docs/examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
All examples in these files are executable from the project main directory.
For the examples to work it is assumed that the `data` directory exists in the project's main directory and it contains `alarm/alarm.csv` and `coolingData/coolingData.csv`.
(To get the data files execute the `download-data.sh` script from the scripts folder - see `download-data.sh` section in `README.md` in the top directory for details)
20 changes: 20 additions & 0 deletions docs/examples/discrete-pc-example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
from gpucsl.pc.pc import DiscretePC

samples = pd.read_csv("data/alarm/alarm.csv", header=None).to_numpy()
max_level = 3
alpha = 0.05

(
(
directed_graph,
separation_sets,
pmax,
discover_skeleton_runtime,
edge_orientation_runtime,
discover_skeleton_kernel_runtime,
),
pc_runtime,
) = (
DiscretePC(samples, max_level, alpha).set_distribution_specific_options().execute()
)
20 changes: 20 additions & 0 deletions docs/examples/edge-orientation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
from gpucsl.pc.pc import GaussianPC
import networkx as nx
from gpucsl.pc.edge_orientation.edge_orientation import orient_edges

samples = pd.read_csv("data/coolingData/coolingData.csv", header=None).to_numpy()
max_level = 3
alpha = 0.05

# you will need the skeleton and separation sets from the skeleton discovery

pc = GaussianPC(samples, max_level, alpha).set_distribution_specific_options()

((skeleton, separation_sets, _, _), _) = pc.discover_skeleton()

# do stuff

(directed_graph, edge_orientation_time) = orient_edges(
nx.DiGraph(skeleton), separation_sets
)
20 changes: 20 additions & 0 deletions docs/examples/gaussian-pc-example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
from gpucsl.pc.pc import GaussianPC

samples = pd.read_csv("data/coolingData/coolingData.csv", header=None).to_numpy()
max_level = 3
alpha = 0.05

(
(
directed_graph,
separation_sets,
pmax,
discover_skeleton_runtime,
edge_orientation_runtime,
discover_skeleton_kernel_runtime,
),
pc_runtime,
) = (
GaussianPC(samples, max_level, alpha).set_distribution_specific_options().execute()
)
31 changes: 31 additions & 0 deletions docs/examples/gaussian-skeleton-discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pandas as pd
from gpucsl.pc.discover_skeleton_gaussian import discover_skeleton_gpu_gaussian
from gpucsl.pc.helpers import init_pc_graph
from gpucsl.pc.pc import GaussianPC

samples = pd.read_csv("data/coolingData/coolingData.csv", header=None).to_numpy()
max_level = 3
alpha = 0.05

# way 1: use the PC class

pc = GaussianPC(samples, max_level, alpha).set_distribution_specific_options()

(
(skeleton, separation_sets, pmax, computation_time),
discovery_runtime,
) = pc.discover_skeleton()


# way 2: call the discover method yourself

graph = init_pc_graph(samples)
num_variables = samples.shape[1]
num_observations = samples.shape[0]

(
(skeleton, separation_sets, pmax, computation_time),
discovery_runtime,
) = discover_skeleton_gpu_gaussian(
graph, samples, None, alpha, max_level, num_variables, num_observations
)
Loading

0 comments on commit 833bca9

Please sign in to comment.