An Introduction to MPI in Python - Part 1
Overview
MPI (Message Passing Interface) is a standard for communication between parallel processes. It’s an abstraction layer on top of the physical interconnect that lets engineers write code that scales to a large number of processes, far beyond the core count of even the largest servers. Whether the cores in use live on the same physical machine or have to exchange data over a wired or wireless network connection (e.g. over TCP), an engineer can keep writing their code the same way. As with all high-level APIs that abstract away the finer details, however, this can cause performance problems and significant latency spikes if used incorrectly. Throughout this blog series, we will move from the basic terminology to simple code before finally writing some scalable functions to calculate various statistics of large data sets. We’ll look at different approaches to these problems as we build an understanding of how these bottlenecks occur and how to write more resilient code in the future.
As an aside, while this blog series focuses on MPI code, the principles discussed here also form the backbone of writing CUDA and ROCm code for GPUs. In fact, MPI is often used in tandem with GPU-level code when training LLMs because of the massive size of these neural networks. In these systems, different clusters of GPUs often communicate with each other via MPI to sync data or state throughout training.
Getting Set Up
I’ll be using Python throughout this series. If you don’t already have MPI set up on your machine, I highly recommend using conda (or your preferred package manager) to install mpi4py and the MPI implementation of your choice (I’m using Open MPI):

```bash
conda install -c conda-forge mpi4py openmpi
```
For further setup instructions, I recommend following the mpi4py installation documentation. Alternatively, you could follow along in another language of your choice: once you have an MPI implementation installed, you’re set regardless of which language you want to use.
The Basic Terminology
We’ll start with a quick introduction to the basic MPI terminology and some key functions. For those with experience in CUDA programming, many of these names will sound familiar and carry similar meanings here.
comm, size, rank, and root
The communicator (comm) is the object used to pass data between workers. We’ll just be using the world communicator (MPI.COMM_WORLD), which spans all processes in your environment. By default, it contains as many processes as you specify when launching the program from the command line. You can create custom communicators for finer-grained control, or configure your environment to span multiple nodes, but that is beyond the scope of this introductory lesson.
size refers to the number of processes in the communicator, and rank is the index of the current process within it (from 0 to size - 1). By convention, the root is usually rank 0. If your environment is set up correctly, the following code will print the rank of each process executing it.
```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
print(f'Hello, rank {rank} of {size}!')
```
You can launch this code (saved as mpi_hello_world.py) with the following command (omit -n 2 to run with the maximum number of processes in your environment):

```bash
mpiexec -n 2 python mpi_hello_world.py
```
It can be tricky to remember, but every process executes every line of code. Also, since you can’t control which process runs first, you will likely see the print statements in different orders if you run the code several times. Even this simple example reveals the difficulty of writing (and debugging!) highly parallelized code. In later excerpts I will omit the boilerplate seen here (the imports, including import numpy as np where arrays are used, and the comm, rank, and size definitions) for brevity.
send and recv
These are the simplest commands for sending data from one process to another. In reality, most MPI code won’t use them directly and will instead use higher-level functions that go from one-to-many, many-to-one, or many-to-many. The tag argument lets you identify exactly which message is being sent or received at any given time (since different processes might be sending different messages at different times), though tags may be redundant depending on your use case. Any picklable Python object (e.g. ints or dicts) can be sent without hassle: under the hood, mpi4py tells the receiving process the length of the payload so that the receive buffer is sized appropriately.
```python
# Sending an int
assert size == 2  # only designed for two processes
ground_truth = 5
if rank == 0:
    val = ground_truth
    comm.send(val, dest=1, tag=0)
else:
    val = comm.recv(source=0, tag=0)
assert val == ground_truth
```

```python
# Sending a dict
assert size == 2  # only designed for two processes
ground_truth = {'a': 1, 'b': 2}
if rank == 0:
    d = ground_truth
    comm.send(d, dest=1, tag=0)
else:
    d = comm.recv(source=0, tag=0)
assert d == ground_truth
```
barrier
Barriers are forced synchronization points: a barrier is a blocking call that every process in the communicator must reach before any of them can proceed. The next few functions we will see also block while they share information, but barrier is unique in that no data is exchanged. In general, it is used to prevent race conditions on shared resources, for instance by separating a write step from a read step so that no process reads a file before another has finished writing it.
```python
comm.barrier()
```
bcast and reduce
Broadcast and reduce are two of the most important MPI functions, especially when it comes to operating on matrices. Broadcast is a one-to-many function that takes data from a given rank (usually the root) and syncs the value to all other processes. Reduce is a many-to-one function that collects data from all ranks, combines it with an aggregation operation (sum/product, min/max, logical/bitwise operations, etc.), and delivers the result to the root. You can even define custom operations to apply during the reduction, but as a rule of thumb, make sure these operations are associative and commutative.
```python
# Reducing ints
val = None
if rank == 0:
    val = 0
val = comm.bcast(val, root=0)
sumRanks = comm.reduce(rank + val, op=MPI.SUM, root=0)
if rank == 0:
    assert sumRanks == sum(range(size))  # Only check on the root node
```

```python
# Reducing dicts: MPI.SUM doesn't support dicts
def dict_reduce(dict1, dict2):
    return {'v': dict1['v'] + dict2['v']}

d = {}
if rank == 0:
    d = {'v': 0}
d = comm.bcast(d, root=0)
d['v'] = rank
sumRanks = comm.reduce(d, op=dict_reduce, root=0)
if rank == 0:
    assert sumRanks['v'] == sum(range(size))  # Only check on the root node
```
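To see why that rule of thumb matters, here is a pure-Python sketch (no MPI involved) that combines the same eight per-rank values in two ways: a left-to-right fold and a pairwise tree, which is closer to how MPI libraries often organize a reduction. The MPI standard assumes reduction operations are associative, so a library is free to choose the tree shape; subtraction breaks that assumption, and the answer changes with the shape:

```python
from functools import reduce


def linear_reduce(values, op):
    """Fold left to right: (((v0 op v1) op v2) op ...)."""
    return reduce(op, values)


def tree_reduce(values, op):
    """Combine neighbours pairwise, level by level, like a tree-shaped reduction."""
    level = list(values)
    while len(level) > 1:
        paired = [op(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            paired.append(level[-1])  # an odd leftover is carried up unchanged
        level = paired
    return level[0]


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


rank_values = list(range(8))  # pretend rank r contributes the value r

print(linear_reduce(rank_values, add), tree_reduce(rank_values, add))  # 28 28
print(linear_reduce(rank_values, sub), tree_reduce(rank_values, sub))  # -28 0
```

Addition gives the same result under both shapes; subtraction does not, which is exactly the kind of silent, process-count-dependent bug a non-associative custom operation can introduce.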
gather and scatter
The other key data synchronization functions are gather and scatter. gather takes a data object from each worker and collects them into a list on the root, while scatter takes a list of data objects on the root, with length equal to the number of processes in the environment, and sends the corresponding object to each process.
```python
# Gathering and scattering ints
data_obj = rank
gathered_objs = comm.gather(data_obj, root=0)
if rank == 0:
    assert type(gathered_objs) is list
    assert sum(gathered_objs) == sum(range(size))
objs_to_scatter = None
if rank == 0:
    objs_to_scatter = [val ** 2 for val in gathered_objs]
new_data_obj = comm.scatter(objs_to_scatter, root=0)
assert new_data_obj == rank ** 2
```

```python
# Gathering and scattering dicts
data_obj = {'a': 0, 'b': rank}
gathered_objs = comm.gather(data_obj, root=0)
if rank == 0:
    assert type(gathered_objs) is list
    assert [d['b'] for d in gathered_objs] == list(range(size))
scattered_objs = None
if rank == 0:
    scattered_objs = gathered_objs.copy()
    for d in scattered_objs:
        d['b'] **= 2
new_data_obj = comm.scatter(scattered_objs, root=0)
assert new_data_obj['b'] == rank ** 2
```
all-
This prefix can be applied to reduce and gather operations to indicate that the output goes to all workers rather than just to the root node. There is no allscatter as such; the closest equivalent, in which every process scatters its own list to every other process, is called alltoall. For now, we’ll just use the example of allreduce:
```python
# allreduce on ints
val = None
if rank == 0:
    val = 0
val = comm.bcast(val, root=0)
sumRanks = comm.allreduce(rank + val, op=MPI.SUM)  # no longer pass in a root
assert sumRanks == sum(range(size))
```

```python
# allreduce on dicts: MPI.SUM doesn't support dicts
def dict_reduce(dict1, dict2):
    return {'v': dict1['v'] + dict2['v']}

d = {}
if rank == 0:
    d = {'v': 0}
d = comm.bcast(d, root=0)
d['v'] = rank
sumRanks = comm.allreduce(d, op=dict_reduce)  # no longer pass in a root
assert sumRanks['v'] == sum(range(size))
```
It’s worth considering the operational cost of all- operations. In a simple reduce, each process performs a single send and the root performs n receives, so we have O(n) operations in total. However, in a naive allreduce where every process sends its value directly to every other process, each process performs n - 1 sends and n - 1 receives, for a total of O(n^2) operations. This is fine in terms of complexity scaling, since each worker still only does a linear amount of work, but we should be aware that the number of “messages” flying through our system scales quadratically with the number of processes in the environment. In practice, however, MPI implementations optimize these all- operations (for example with tree-based or recursive-doubling algorithms) to reduce the total number of messages sent and the amount of queueing, which mitigates this effect in large environments.
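To make these message counts concrete, here is a pure-Python simulation (no MPI involved) that counts messages for the naive allreduce described above and for recursive doubling, one of the algorithms that MPI libraries such as Open MPI and MPICH can choose for allreduce. The sketch assumes a power-of-two number of processes and counts one message per process per exchange round:

```python
def naive_allreduce(values):
    """Every process sends its value straight to every other process."""
    n = len(values)
    inboxes = [[] for _ in range(n)]
    messages = 0
    for src in range(n):
        for dst in range(n):
            if dst != src:
                inboxes[dst].append(values[src])
                messages += 1
    results = [values[r] + sum(inboxes[r]) for r in range(n)]
    return results, messages


def recursive_doubling_allreduce(values):
    """log2(n) rounds of pairwise exchanges; assumes n is a power of two."""
    n = len(values)
    assert n & (n - 1) == 0, "this sketch only handles power-of-two process counts"
    partial = list(values)
    messages = 0
    distance = 1
    while distance < n:
        # Each round, process r swaps its running sum with partner r XOR distance.
        partial = [partial[r] + partial[r ^ distance] for r in range(n)]
        messages += n  # every process sends one message per round
        distance *= 2
    return partial, messages


for n in (2, 8, 64):
    values = list(range(n))
    naive, naive_msgs = naive_allreduce(values)
    rd, rd_msgs = recursive_doubling_allreduce(values)
    assert naive == rd == [sum(values)] * n  # same answer on every "process"
    print(f"n={n:3d}  naive={naive_msgs:5d}  recursive doubling={rd_msgs:4d}")
```

For 8 processes this counts 56 messages for the naive scheme against 24 for recursive doubling; at 64 processes it is 4032 against 384, i.e. n(n - 1) versus n log2(n).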
mpi4py Idiosyncrasies
We’ve now covered all of the basic functions, but mpi4py also includes vectorized versions of these functions for NumPy arrays. They are spelled with an uppercase first letter (Send, Recv, Bcast, Reduce, Gather, Scatter, Allreduce, and so on), and there is even an uppercase Barrier despite there being no data transfer. The API of the vectorized implementations is also slightly different: whereas the functions we’ve seen so far return values, the vectorized ones write into a receive buffer that you allocate beforehand. It’s also very important to pay close attention to the dtypes of these NumPy arrays, or you will hit an error when reading from the buffer. For those curious, the vectorized mpi4py API is much closer to what you would see in C/C++. As far as possible, it’s best practice to stick to either the vectorized or the non-vectorized functions throughout a code base for maintainability.
```python
# Send/Recv with matching dtypes
assert size == 2  # only designed for two processes
ground_truth = np.arange(5, dtype='i')
if rank == 0:
    send_arr = ground_truth
    comm.Send(send_arr, dest=1, tag=0)
else:
    recv_arr = np.empty(5, dtype='i')
    comm.Recv(recv_arr, source=0, tag=0)
    assert np.all(recv_arr == ground_truth)
```

```python
# Send/Recv with mismatched dtypes
assert size == 2  # only designed for two processes
ground_truth = np.arange(5, dtype='l')  # Sending array<long>
if rank == 0:
    send_arr = ground_truth
    comm.Send(send_arr, dest=1, tag=0)
else:
    recv_arr = np.empty(5, dtype='i')  # Expecting array<int>
    # Where C long is 64-bit (e.g. Linux), this hits an MPI_ERR_TRUNCATE exception
    comm.Recv(recv_arr, source=0, tag=0)
    assert np.all(recv_arr == ground_truth)
```

```python
# Bcast and Reduce
arr = np.empty(2)
if rank == 0:
    arr = np.arange(2, dtype='d')
comm.Bcast(arr, root=0)
sumRanks = np.zeros(2)
comm.Reduce(rank + arr, sumRanks, op=MPI.MAX, root=0)  # Using MAX reduce this time
if rank == 0:
    assert np.all(sumRanks == arr + size - 1)
else:
    assert np.all(sumRanks == 0)  # the receive buffer is untouched off the root
```

```python
# Gather and Scatter
data_arr = np.zeros(3) + rank
gathered_arrs = np.empty([size, 3])
comm.Gather(data_arr, gathered_arrs, root=0)
if rank == 0:
    assert len(gathered_arrs.shape) == 2
    assert np.all(gathered_arrs.T == np.arange(size))
scattered_objs = None
new_data_arr = np.empty(3)
if rank == 0:
    scattered_objs = gathered_arrs * np.arange(3)
comm.Scatter(scattered_objs, new_data_arr, root=0)
assert np.all(new_data_arr == data_arr * np.arange(3))
```

```python
# Allreduce
arr = np.empty(2)
if rank == 0:
    arr = np.arange(2, dtype='d')
comm.Bcast(arr, root=0)
sumRanks = np.zeros(2)
comm.Allreduce(rank + arr, sumRanks, op=MPI.MAX)  # Using MAX reduce this time
assert np.all(sumRanks == arr + size - 1)
```
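The MPI_ERR_TRUNCATE failure in the mismatched-dtype example comes down to byte counts. This NumPy-only sketch (no MPI needed) compares the two buffers. Note that the size of C long ('l') is platform-dependent: 8 bytes on 64-bit Linux and macOS, but 4 bytes on Windows, where that particular example would not hit the truncation error:

```python
import numpy as np

n = 5
send = np.arange(n, dtype='l')  # C long, as in the failing example
recv = np.empty(n, dtype='i')   # C int

print(send.dtype, send.nbytes)  # int64 40 on 64-bit Linux/macOS
print(recv.dtype, recv.nbytes)  # int32 20

# When the incoming message needs more room than the receive buffer offers,
# the receiving side reports a truncation error.
print(send.nbytes > recv.nbytes)

# One simple fix: allocate the receive buffer with the sender's dtype.
fixed = np.empty(n, dtype=send.dtype)
assert fixed.nbytes == send.nbytes
```

In a real program the receiver has to know the sender’s dtype in advance, which is one more reason to agree on dtypes explicitly rather than relying on NumPy defaults.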
Blocking and non-blocking calls
Until now, we’ve focused on blocking calls like Send and Recv, but there are also non-blocking functions like Isend and Irecv (and corresponding lowercase, non-NumPy functions as well) that return immediately (the I stands for “immediate”). These functions return an MPI.Request object and let the sending or receiving process keep executing while the communication happens in the background. You can call req.Wait() on such an object (or MPI.Request.Waitall on a list of them); it returns once the message transfer is complete and the buffer (on either the send or receive side) is safe to touch again. It’s also worth noting that you don’t need to pair Isend with Irecv: you can use a blocking send with a non-blocking receive (or vice versa) if that makes the most sense in your program. Here is a quick example showing how these functions operate.
```python
assert size == 2
buf_size = 1000
if rank == 0:
    arr = np.zeros(buf_size) + 10
    req = comm.Isend(arr, dest=1)
else:
    arr = np.empty(buf_size)
    req = comm.Irecv(arr, source=0)
req.Wait()
if rank == 1:
    assert np.all(arr == 10)
```
Like before, there are also non-blocking versions of the collective functions in addition to the point-to-point ones we just looked at. I’ll use Igather for this example, but you can do the same thing with the other functions we’ve seen. One practical scenario is a parallelized process doing heavy workloads that needs to sync a relatively large amount of data to other nodes (these scenarios come up frequently with large matrix multiplications). Another is when each node does a slightly different amount of work, so that each node takes a different amount of time for each stage. If a single node is always the bottleneck, there’s nothing you can do other than adding more nodes. However, if there’s just some variability in run time (which in practice happens quite frequently), you might benefit from using non-blocking communication to reduce the idle time of the other processes. We simulate that here by having each process sleep during a different iteration.
```python
import time
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()

def test(comm, func):
    rank = comm.Get_rank()
    size = comm.Get_size()
    arr_size = 10000
    res = None
    if rank == 0:
        res = np.zeros([size, size, arr_size], dtype='f')
    start = MPI.Wtime()
    requests = []
    for i in range(size):
        if rank == i:
            time.sleep(1)
        arr = np.zeros(arr_size, dtype='f') + rank * i
        if rank == 0:
            requests.append(func(arr, res[i, :, :], root=0))
        else:
            requests.append(func(arr, None, root=0))
    if requests[0] is not None:
        MPI.Request.Waitall(requests)
    end = MPI.Wtime()
    if rank == 0:
        print(f'({func.__name__}) Total time = {end - start}')
        cumsum = np.arange(size).reshape(1, -1)
        assert np.all(res == (cumsum.T @ cumsum).reshape(size, size, 1))

test(comm, comm.Igather)
test(comm, comm.Gather)
```
Running the above code with 8 nodes will take ~1s for the non-blocking routine but ~8s for the blocking version. There’s a reason that I make sure that the data is sufficiently large here, but I’ll save this for the next post.
Challenge
Before moving on to the next article, I strongly encourage you to try the following exercise. It will only take a few lines of code but will further help to cement your understanding of how to safely and successfully implement MPI code while avoiding deadlocks. Make sure your implementation works with varying core counts (from 2 to the max number supported in your environment). The challenge has 3 parts:
- Implement a scalable alltoall function WITH a deadlock while only using Send/Recv for sufficiently large messages.
- Correct the deadlock from part 1 while still only using Send/Recv.
- Using your solution to part 1, is it possible to resolve the deadlock by instead using Isend/Recv?
You may use the following scaffolding code if you’d like. The payload must be scalable in size in this example for reasons I will explain next time, so make sure to test your solutions with varying scaling_factor values. I’ll post sample solutions to these challenges at the beginning of the next blog post.
```python
# mpiexec -n 2 python challenge1.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

def alltoall(comm, values):
    rank = comm.Get_rank()
    size = comm.Get_size()
    # Enter Solution Here
    pass

scaling_factor = 10000
# values[i] is destined for rank i; new_values[i] should hold what rank i sent here
values = [np.array([rank * 10 + i] * scaling_factor, dtype='l') for i in range(size)]
new_values = alltoall(comm, values)
ground_truth = [[rank + i * 10] * scaling_factor for i in range(size)]
assert np.all([np.all(v1 == v2) for v1, v2 in zip(new_values, ground_truth)])
```
Conclusion
This is just a taste of the functions available through the MPI API. If you’re interested in more advanced and niche functions, take a look at the Open MPI or mpi4py documentation! There are plenty more buffered and specialized functions available for all of your potential use cases. Next time, we’ll go over the solutions to the challenge problems here and let them guide our dissection of the nuances of blocking and non-blocking functions. We’ll also learn a bit more about some of the caveats I mentioned today (namely, why the size of the message has a much bigger impact on code design than you might expect), as well as a few more variations of the send function. Thanks for joining me, and I hope you learned something new and are excited to dive even deeper next time!