
Commit f584007

Merge pull request #89 from astrolabsoftware/serde
Update on pyspark3d: spatial partitioning
2 parents f4e5ad9 + 582bd3f commit f584007

25 files changed: +690 −36 lines

docs/03_partitioning_python.md

Lines changed: 69 additions & 0 deletions
@@ -15,4 +15,73 @@ Unfortunately, re-partitioning the space involves potentially large shuffle betw

## Available partitioning and partitioners

There are currently two partitioning strategies implemented in the library (a minimal sketch of each idea follows the list):

- **Onion Partitioning:** See [here](https://github.com/astrolabsoftware/spark3D/issues/11) for a description. This is mostly intended for processing astrophysical data, as it partitions the space into 3D shells along the radial axis, with the possibility of projecting into 2D shells (and then partitioning the shells using Healpix).
- **Octree:** An octree extends a quadtree by using three orthogonal splitting planes to subdivide a tile into eight children. Like quadtrees, octrees allow variations such as non-uniform subdivision, tight bounding volumes, and overlapping children.
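
To fix ideas, here is a minimal sketch of both ideas in plain Python. The function names, the linear shell spacing, and the child-indexing convention are assumptions for illustration only, not the library's implementation:

```python
import numpy as np

def onion_shell_index(r, rmax, nshells):
    # Assign a radial distance r to one of nshells shells, linearly
    # spaced in radius (LINEARONIONGRID-style; the spacing is an assumption)
    return min(int(r / rmax * nshells), nshells - 1)

def octree_child_index(point, center):
    # Octant index (0-7) of a 3D point relative to a node center: each of
    # the three orthogonal splitting planes contributes one bit, giving
    # the eight children of an octree node
    return (int(point[0] >= center[0])
            + 2 * int(point[1] >= center[1])
            + 4 * int(point[2] >= center[2]))

# Example: 5 shells over radii in [0, 1), and one octant lookup
radii = np.random.rand(10)
print([onion_shell_index(r, 1.0, 5) for r in radii])
print(octree_child_index([0.3, -0.2, 0.7], [0.0, 0.0, 0.0]))
```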

### Onion Partitioning

In the following example, we load `Point3D` data and re-partition it with the onion partitioning:

```python
from pyspark3d import load_user_conf, get_spark_session
from pyspark3d.spatial3DRDD import Point3DRDD

# Load user config and the Spark session
dic = load_user_conf()
spark = get_spark_session(dicconf=dic)

# Load the data
fn = "src/test/resources/astro_obs.fits"
p3drdd = Point3DRDD(spark, fn, "Z_COSMO,RA,DEC", True, "fits", {"hdu": "1"})

# nPart is the wanted number of partitions.
# Default is rdd.rawRDD() partition number.
npart = 5
gridtype = "LINEARONIONGRID"
rdd_part = p3drdd.spatialPartitioningPython(gridtype, npart)
```

| Raw data set | Re-partitioned data set
|:---------:|:---------:
| ![raw]({{ "/assets/images/onion_nopart_python.png" | absolute_url }}) | ![repartitioning]({{ "/assets/images/onion_part_python.png" | absolute_url }})

Color code indicates the partitions (all objects with the same color belong to the same partition).

### Octree Partitioning

In the following example, we load `ShellEnvelope` data (spheres) and re-partition it with the octree partitioning:

```python
from pyspark3d import load_user_conf, get_spark_session
from pyspark3d.spatial3DRDD import SphereRDD

# Load user config and the Spark session
dic = load_user_conf()
spark = get_spark_session(dicconf=dic)

# Load the data
fn = "src/test/resources/cartesian_spheres.fits"
srdd = SphereRDD(spark, fn, "x,y,z,radius", False, "fits", {"hdu": "1"})

# nPart is the wanted number of partitions (floored to a power of 8).
# Default is rdd.rawRDD() partition number.
npart = 10
gridtype = "OCTREE"
rdd_part = srdd.spatialPartitioningPython(gridtype, npart)
```
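
Note the flooring mentioned in the comment above: asking for `npart = 10` yields 8 partitions. A quick sketch, assuming the rule is the largest power of 8 not exceeding `nPart`:

```python
npart = 10

# Largest power of 8 not exceeding npart (assumed flooring rule)
effective = 1
while effective * 8 <= npart:
    effective *= 8

print(effective)  # 8 when npart = 10
```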

We also advise caching the re-partitioned sets, to speed up future calls by not performing the re-partitioning again (see the sketch below). If you are short on memory, unpersist the rawRDD first, before caching the re-partitioned RDD.
However, keep in mind that while a large `nPart` decreases the cost of performing future queries (cross-match, KNN, ...), it increases the partitioning cost, as more partitions imply more data shuffle between partitions. There is no magic number for `nPart` that applies in general; you will need to set it according to the needs of your problem. My only advice would be: re-partitioning is typically done once, queries can be multiple...
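
Continuing the octree example above, a minimal caching sketch (`cache` and `unpersist` are the standard Spark methods; their availability on these wrapped RDDs is an assumption):

```python
# Cache the re-partitioned RDD; the first action materialises the cache
rdd_part.cache()
count = rdd_part.count()

# If memory is tight, drop the raw RDD first (assumption: the underlying
# JVM RDD exposes unpersist through py4j)
# srdd.rawRDD().unpersist()
```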

| Raw data set | Re-partitioned data set
|:---------:|:---------:
| ![raw]({{ "/assets/images/octree_nopart_python.png" | absolute_url }}) | ![repartitioning]({{ "/assets/images/octree_part_python.png" | absolute_url }})

The size of the markers is proportional to the radius. Color code indicates the partitions (all objects with the same color belong to the same partition).

## Current benchmark

TBD

docs/03_partitioning_scala.md

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,8 @@ val pointRDD_partitioned = pointRDD.spatialPartitioning(GridType.LINEARONIONGRID
|:---------:|:---------:
| ![raw]({{ "/assets/images/myOnionFigRaw.png" | absolute_url }}) | ![repartitioning]({{ "/assets/images/myOnionFig.png" | absolute_url }})

Color code indicates the partitions (all objects with the same color belong to the same partition).

### Octree Partitioning

In the following example, we load `ShellEnvelope` data (spheres), and we re-partition it with the octree partitioning

@@ -94,6 +96,8 @@ However keep in mind that if a large `nPart` decreases the cost of performing fu
|:---------:|:---------:
| ![raw]({{ "/assets/images/rawData_noOctree.png" | absolute_url }}) | ![repartitioning]({{ "/assets/images/rawData_withOctree.png" | absolute_url }})

Color code indicates the partitions (all objects with the same color belong to the same partition).

## Current benchmark

TBD
[Four binary image files added: 295 KB, 263 KB, 213 KB, 204 KB (previews not shown)]
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# Copyright 2018 Julien Peloton
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pyspark.sql import SparkSession

import numpy as np

from pyspark3d import set_spark_log_level
from pyspark3d import load_user_conf
from pyspark3d import get_spark_session
from pyspark3d.spatial3DRDD import SphereRDD

import argparse

def addargs(parser):
    """ Add command line arguments for pyspark3d_part """

    ## Arguments
    parser.add_argument(
        '-inputpath', dest='inputpath',
        required=True,
        help='Path to a FITS file')

    ## Arguments
    parser.add_argument(
        '-hdu', dest='hdu',
        required=True,
        help='HDU index to load.')

    ## Arguments
    parser.add_argument(
        '-part', dest='part',
        default=None,
        help='Type of partitioning')

    ## Arguments
    parser.add_argument(
        '-npart', dest='npart',
        default=10,
        type=int,
        help='Number of partitions')

    ## Arguments
    parser.add_argument(
        '--plot', dest='plot',
        action="store_true",
        help='Plot the re-partitioned data set')


if __name__ == "__main__":
    """
    Re-partition an RDD with the OCTREE partitioning of pyspark3d
    """
    parser = argparse.ArgumentParser(
        description="""
        Re-partition an RDD with the OCTREE partitioning of pyspark3d
        """)
    addargs(parser)
    args = parser.parse_args(None)

    # Load user conf and Spark session
    dic = load_user_conf()
    spark = get_spark_session(dicconf=dic)

    # Set logs to be quiet
    set_spark_log_level()

    # Load raw data
    fn = args.inputpath
    rdd = SphereRDD(
        spark, fn, "x,y,z,radius", False, "fits", {"hdu": args.hdu})

    # Perform the re-partitioning
    npart = args.npart
    gridtype = args.part

    if gridtype is not None:
        rdd_part = rdd.spatialPartitioningPython(gridtype, npart)
    else:
        rdd_part = rdd.rawRDD().toJavaRDD().repartition(npart)

    if not args.plot:
        count = rdd_part.count()
        print("{} elements".format(count))
    else:
        # Plot the result.
        # Collect the data on the driver -- just for visualisation purposes.
        # Do not do that with a full data set or you will destroy your driver!
        import pylab as pl
        from mpl_toolkits.mplot3d import Axes3D

        fig = pl.figure()
        ax = Axes3D(fig)

        # Convert the data for the plot:
        # List[all partitions] of List[all spheres per partition]
        data_glom = rdd_part.glom().collect()

        # Take only a few points (400 per partition) to speed things up.
        # For each sphere (el), take the center, grab its coordinates and
        # make them a Python list (it is a JavaList by default).
        data_all = [
            np.array(
                [list(el.center().getCoordinatePython())
                 for el in part[0:400]]).T
            for part in data_glom]

        # Collect the radii
        radius = [
            np.array(
                [el.outerRadius() for el in part[0:400]])
            for part in data_glom]

        # Plot partition-by-partition, with marker size proportional
        # to the sphere radius
        for i in range(len(data_all)):
            s = radius[i] * 3000
            ax.scatter(data_all[i][0], data_all[i][1], data_all[i][2], s=s)

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")

        # Save the result on disk
        if gridtype is not None:
            pl.savefig("octree_part_python.png")
        else:
            pl.savefig("octree_nopart_python.png")
        pl.show()
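
A possible invocation of this script (its file name is not shown in this diff, hence the placeholder, and running through `spark-submit` is an assumption): `spark-submit <octree_script.py> -inputpath src/test/resources/cartesian_spheres.fits -hdu 1 -part OCTREE -npart 8 --plot`. Omitting `-part` falls back to a plain Spark `repartition`, which is what produces the `octree_nopart_python.png` baseline figure.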
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Copyright 2018 Julien Peloton
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pyspark.sql import SparkSession

import numpy as np

from pyspark3d import set_spark_log_level
from pyspark3d import load_user_conf
from pyspark3d import get_spark_session
from pyspark3d import load_from_jvm
from pyspark3d.spatial3DRDD import Point3DRDD

import argparse

def addargs(parser):
    """ Add command line arguments for pyspark3d_part """

    ## Arguments
    parser.add_argument(
        '-inputpath', dest='inputpath',
        required=True,
        help='Path to a FITS file')

    ## Arguments
    parser.add_argument(
        '-hdu', dest='hdu',
        required=True,
        help='HDU index to load.')

    ## Arguments
    parser.add_argument(
        '-part', dest='part',
        default=None,
        help='Type of partitioning')

    ## Arguments
    parser.add_argument(
        '-npart', dest='npart',
        default=10,
        type=int,
        help='Number of partitions')

    ## Arguments
    parser.add_argument(
        '--plot', dest='plot',
        action="store_true",
        help='Plot the re-partitioned data set')


if __name__ == "__main__":
    """
    Re-partition an RDD with the ONION partitioning of pyspark3d
    """
    parser = argparse.ArgumentParser(
        description="""
        Re-partition an RDD with the ONION partitioning of pyspark3d
        """)
    addargs(parser)
    args = parser.parse_args(None)

    # Load user conf and Spark session
    dic = load_user_conf()
    spark = get_spark_session(dicconf=dic)

    # Set logs to be quiet
    set_spark_log_level()

    # Load raw data
    fn = args.inputpath
    rdd = Point3DRDD(
        spark, fn, "Z_COSMO,RA,DEC", True, "fits", {"hdu": args.hdu})

    # Perform the re-partitioning
    npart = args.npart
    gridtype = args.part

    if gridtype is not None:
        rdd_part = rdd.spatialPartitioningPython(gridtype, npart)
    else:
        rdd_part = rdd.rawRDD().toJavaRDD().repartition(npart)

    if not args.plot:
        count = rdd_part.count()
        print("{} elements".format(count))
    else:
        # Plot the result.
        # Collect the data on the driver -- just for visualisation purposes.
        # Do not do that with a full data set or you will destroy your driver!
        import pylab as pl
        from mpl_toolkits.mplot3d import Axes3D

        fig = pl.figure()
        ax = Axes3D(fig)

        # Converter from the spherical to the cartesian coordinate system;
        # it takes a Point3D and returns a Point3D
        mod = "com.astrolabsoftware.spark3d.utils.Utils.sphericalToCartesian"
        converter = load_from_jvm(mod)

        # Convert the data for the plot -- List of List of Point3D
        data_glom = rdd_part.glom().collect()

        # Take only a few points (400 per partition) to speed things up.
        # For each Point3D (el), grab the coordinates, convert them from
        # the spherical to the cartesian coordinate system (for the plot)
        # and make them a Python list (it is a JavaList by default).
        data_all = [
            np.array(
                [list(converter(el).getCoordinatePython())
                 for el in part[0:400]]).T
            for part in data_glom]

        # Plot partition-by-partition
        for i in range(len(data_all)):
            ax.scatter(data_all[i][0], data_all[i][1], data_all[i][2])

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")

        # Save the result on disk
        if gridtype is not None:
            pl.savefig("onion_part_python.png")
        else:
            pl.savefig("onion_nopart_python.png")
        pl.show()
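
As for the octree script, a possible invocation (file name placeholder, `spark-submit` assumed): `spark-submit <onion_script.py> -inputpath src/test/resources/astro_obs.fits -hdu 1 -part LINEARONIONGRID -npart 5 --plot`. Without `-part`, the raw partitioning is plotted and saved as `onion_nopart_python.png`.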

pic/pyspark3d_lib_0.2.1.png (172 KB, binary file added)

pic/spark3d_lib_0.2.1.png (188 KB, binary file added)

pyspark3d/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -17,6 +17,8 @@
 from typing import Any, List, Dict

+import sys
+
 from version import __version__
 from pyspark3d_conf import extra_jars, extra_packages, log_level
@@ -288,4 +290,5 @@ def set_spark_log_level(log_level_manual=None):
     np.set_printoptions(legacy="1.13")

     # Run the test suite
-    doctest.testmod()
+    failure_count, test_count = doctest.testmod()
+    sys.exit(failure_count)
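
Propagating the doctest failure count through `sys.exit` makes the test run exit non-zero whenever a doctest fails, so a calling test harness can detect failures instead of always seeing a zero exit code.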
