Skip to content

Commit

Permalink
New examples and new mapreduce function that uses scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Darren Govoni committed Feb 28, 2022
1 parent c1a8f7d commit ab26e85
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 12 deletions.
1 change: 1 addition & 0 deletions blazer/examples/example3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ def add(x, y=0):
result = reduce(add, result)

blazer.print(result)

27 changes: 27 additions & 0 deletions blazer/examples/example5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import blazer
from blazer.hpc.mpi import mapreduce, reduce, rank


def sqr(x):
    """Return the square of *x*."""
    squared = x * x
    return squared


def add(values):
    """Return the sum of *values*, or 0 when *values* is None or empty.

    Fix: the original guard ``if values and len(values)`` was redundant —
    an empty sequence is already falsy, so the ``len`` call added nothing.
    A single truthiness check covers both None and the empty case.
    """
    return sum(values) if values else 0


with blazer.begin():
    # Only the root rank builds the input; every other rank passes None
    # and receives its share through mapreduce's scatter.
    if blazer.ROOT:
        data = list(range(0, 100))
        print("DATA: ", data)
        print("EXPECTING: ", sum(data))
    else:
        data = None

    result = mapreduce(add, add, data)

    blazer.print("RESULT:", result)

44 changes: 44 additions & 0 deletions blazer/examples/example6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import blazer
from blazer.hpc.mpi import parallel, pipeline, partial as p, scatter, where, select, filter, rank, size


def calc_some(value, *args):
    """ Do some calculations """
    # Wrap the value in a single-key dict; extra args are ignored.
    return {'some': value}


def calc_stuff(value, *args):
    """ Do some calculations """
    # Wrap the value in a single-key dict; extra args are ignored.
    return {'this': value}


def add_date(result):
    """Stamp a dict result with the current wall-clock time.

    Mutates *result* in place when it is a dict, adding a ``'date'`` key
    holding ``str(datetime.now())``; any non-dict input is returned
    unchanged.

    Fix: use ``isinstance`` instead of ``type(result) is dict`` so dict
    subclasses (OrderedDict, defaultdict, ...) are also stamped.
    """
    from datetime import datetime

    if isinstance(result, dict):
        result['date'] = str(datetime.now())
    return result


def calc_more_stuff(result):
    """ Do some more calculations """
    # Fix: isinstance replaces the ``type(x) is`` anti-pattern so list/dict
    # subclasses get the same treatment as the plain types.
    if isinstance(result, list):
        # append mutates in place, same effect as the original ``+= [...]``
        result.append({'more': 'stuff'})
    elif isinstance(result, dict):
        result['more'] = 'stuff'
    return result


INPUT_DATA = 'that'

with blazer.begin():

    def get_data():
        """ Data generator """
        # Two items per MPI rank, so every rank receives work.
        yield from range(0, size * 2)

    result = scatter(get_data(), calc_some)
    blazer.print("SCATTER:", result)
3 changes: 2 additions & 1 deletion blazer/hpc/mpi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .primitives import (parallel, scatter, pipeline, reduce, map, rank, size)
from .primitives import (parallel, scatter, pipeline, reduce, map, mapreduce, rank, size)
from functools import partial
from pipe import select, where
from pydash import flatten, chunk, omit, get, filter_ as filter
Expand All @@ -9,6 +9,7 @@
'scatter',
'pipeline',
'map',
'mapreduce',
'reduce',
'partial',
'select',
Expand Down
50 changes: 39 additions & 11 deletions blazer/hpc/mpi/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,25 @@
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
host= MPI.Get_processor_name()

global loop
loop = True

MASTER = rank == 0

logging.info(f"MY RANK {rank} and size {size}")
logging.debug(f"MY {host} RANK {rank} and size {size}")


@contextmanager
def begin(*args, **kwds):
    """Context manager bracketing an MPI computation.

    Yields the global communicator; on exit only rank 0 invokes stop(),
    which broadcasts the shutdown signal to the worker ranks.

    Fix: the rendered diff retained both the pre-change cleanup lines and
    the committed ones, which would log twice and call stop() from every
    rank; this is the committed (rank-0-only) version.
    """
    try:
        logging.debug("Yielding comm")
        yield comm
    finally:
        logging.debug("Invoking stop(%s)", rank)
        if rank == 0:
            stop()


def mprint(*args):
Expand All @@ -41,8 +47,10 @@ def mprint(*args):


def stop():
""" Stop all workers """
global loop


logging.debug("Stopping %s, %s", rank,host)
if rank == 0:
logging.debug("Sending break to all ranks")
for i in range(1, size):
Expand All @@ -54,7 +62,7 @@ def stop():

if rank != 0:
def run():
while True:
while loop:
logging.debug("thread rank %s waiting on defer", rank)
defer = comm.recv(source=0)
logging.debug("thread rank %s got data %s", rank, defer)
Expand Down Expand Up @@ -145,6 +153,25 @@ def enumrate(gen):
i += 1


def mapreduce(_map: Callable, _reduce: Callable, data: Any):
    """ Use scatter for map/reduce in one call """
    import numpy as np

    results = []
    # Root splits the full dataset into one numpy chunk per rank;
    # non-root ranks contribute None to the scatter.
    if rank == 0:
        _chunks = np.array_split(data, size)
    else:
        _chunks = None
    # Each rank receives its own chunk (a numpy array) from root.
    data = comm.scatter(_chunks, root=0)

    # Map this rank's chunk as a plain Python list.
    _data = _map(data.tolist())
    # gather returns the list of all per-rank results on root,
    # and None on every other rank.
    newData = comm.gather(_data, root=0)
    results += [newData]
    _flattened = flatten(results)
    # NOTE(review): on non-root ranks the gather yields None, so the reduce
    # is skipped there and the local mapped value is returned instead of
    # None — confirm this rank-asymmetric return is intended.
    if None not in _flattened:
        _data = _reduce(_flattened)
    return _data

def map(func: Callable, data: Any):
_funcs = []
for arg in data:
Expand All @@ -162,10 +189,8 @@ def map(func: Callable, data: Any):
def reduce(func: Callable, data: Any):
_funcs = []
if data is None:
logging.debug("Returning None")
return None
for arg in data:
logging.debug("ARG %s", arg)
if iterable(arg):
_funcs += [partial(func, *arg)]
else:
Expand All @@ -179,18 +204,21 @@ def reduce(func: Callable, data: Any):


def scatter(data: Any, func: Callable):
""" Scatter """

""" This will create a generator to chunk the incoming data (which itself can be a generator)
Each chunk (which can itself be a list of data) will then be scattered with the function to all
ranks. """

def chunker(generator, chunksize):
chunk = []
for i, c in enumrate(generator):
chunk += [c]
if len(chunk) == size:
if len(chunk) == chunksize:
yield chunk
chunk = []
if len(chunk) > 0:
yield chunk


chunked_data = chunker(data, size)
results = []
for i, chunk in enumrate(chunked_data):
Expand All @@ -199,12 +227,12 @@ def chunker(generator, chunksize):

data = comm.scatter(chunk, root=0)
_data = func(data)
logging.debug("scatter[%s, %s]: Chunk %s %s, Func is %s Data is %s Result is %s", rank,host,i, chunk, func, data, _data)
newData = comm.gather(_data, root=0)
results += [newData]

return flatten(results)


def pipeline(defers: List, *args):
""" This will use the master node 0 scheduler to orchestrate results """
logging.debug("pipeline rank %s %s", rank, args)
Expand Down

0 comments on commit ab26e85

Please sign in to comment.