Greg McCord

An Introduction to MPI in Python - Part 1

Created on March 18, 2025   ·   20 min read

2025   ·   python   mpi   ·   code

  • Overview
  • Getting Set Up
  • The Basic Terminology
    • comm, size, rank, and root
    • send and recv
    • barrier
    • bcast and reduce
    • gather and scatter
    • all-
  • mpi4py Idiosyncrasies
  • Blocking and non-blocking calls
  • Challenge
  • Conclusion

Overview

MPI (Message Passing Interface) is a standard for communication between parallel processes and processors. It’s an abstraction that sits on top of the physical layer and enables engineers to write code that scales to a large number of processes, far beyond the core count of even the largest servers. Regardless of whether the cores in use live on the same physical machine or have to exchange data over a network (e.g. over TCP), an engineer can keep writing their code in the same way. As with all high-level APIs that abstract away the finer details, however, this can cause performance problems and significant latency spikes if used incorrectly. Throughout this blog series, we will move from the basic terminology to basic code before finally writing some scalable functions to calculate various statistics of large data sets. We’ll look at different approaches to these problems as we build an understanding of how these bottlenecks occur and how to write more resilient code in the future.

As an aside, while this blog series is focused on MPI code, the principles discussed here also form the backbone of writing CUDA and ROCm code for GPUs. In fact, MPI is often used in tandem with GPU-level code when training LLMs due to the massive size of these neural networks. In these systems, different clusters of GPUs will often communicate with each other via MPI to sync data or state throughout the process.

Getting Set Up

I’ll be using Python throughout this series. If you don’t already have it set up on your machine, I highly recommend using conda (or your preferred package manager) to install mpi4py along with the MPI implementation of your choice (I’m using Open MPI).

conda install -c conda-forge mpi4py openmpi

For further setup instructions, I recommend following the mpi4py installation documentation. Alternatively, you could follow along in your language of choice: once you have an MPI implementation installed, you’re set regardless of which language you want to use.
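
Before launching anything with mpiexec, it can help to confirm that mpi4py is visible to the Python interpreter you plan to use. The sketch below is a hypothetical helper (mpi4py_status is not part of mpi4py) that relies only on the standard library and deliberately avoids importing mpi4py.MPI, since that import initializes MPI.

```python
import importlib.util
from importlib import metadata


def mpi4py_status():
    """Return the installed mpi4py version string, or None if it is not importable.

    Hypothetical helper: it only inspects the environment and never imports
    mpi4py.MPI, so no MPI library is loaded or initialized.
    """
    if importlib.util.find_spec("mpi4py") is None:
        return None
    try:
        return metadata.version("mpi4py")
    except metadata.PackageNotFoundError:
        # Importable (e.g. from a source tree on sys.path) but without package metadata.
        return "unknown"


if __name__ == "__main__":
    status = mpi4py_status()
    print("mpi4py not found" if status is None else f"mpi4py {status} found")
```

If mpi4py is found, the mpi4py documentation suggests running `mpiexec -n 4 python -m mpi4py.bench helloworld` as a quick end-to-end check that the MPI launcher and library work together.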

The Basic Terminology

We’ll start with a quick introduction to the basic MPI terminology and some key functions. For those with experience in CUDA programming, many of these names will sound familiar and carry similar meanings here.

comm, size, rank, and root

The communicator is the object through which processes exchange data. We’ll just be using the WORLD communicator (MPI.COMM_WORLD), which spans all processes in your environment. By default, it contains as many processes as you request when launching the program from the command line (e.g. with mpiexec -n). You can create custom communicators for more fine-grained control, or configure your environment to span multiple nodes, but that is beyond the scope of this introductory lesson.

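size refers to the number of processes in the communicator, and rank is the index of the calling process, from 0 to size - 1. The root is the process that collective operations send data from or collect data onto, and by convention it is usually rank 0. If your environment is set up correctly, the following code will print the rank of each process along with the total number of processes.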

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

print(f'Hello, rank {rank} of {size}!')

You can launch this code with the following command (omit -n 2 to run with the maximum number of processes in your environment):

mpiexec -n 2 python mpi_hello_world.py

It can be tricky to remember, but every process executes every line of code. Also, since you can’t control which process runs first, you will likely see the print statements in a different order each time you run the code. Even this simple example reveals the difficulty of writing (and debugging!) highly parallelized code. For simplicity, future excerpts omit the redundant header seen here (the imports and the comm, rank, and size definitions), as well as import numpy as np in the numpy-based examples.
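
Because every process runs the same script, rank and size are usually what let each process pick out its own share of the work. Here is a minimal, MPI-free sketch of one common way to do that, a block partition of n items; chunk_bounds is a hypothetical helper name, not an mpi4py function. Inside an MPI program you would call it as chunk_bounds(n, comm.Get_size(), comm.Get_rank()).

```python
def chunk_bounds(n_items, size, rank):
    """Return the half-open range [start, end) of items owned by `rank`.

    The first n_items % size ranks receive one extra item, so chunk
    lengths differ by at most one and every item is owned exactly once.
    """
    base, extra = divmod(n_items, size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


# Simulate what each of 4 ranks would compute for 10 items.
for r in range(4):
    print(r, chunk_bounds(10, 4, r))
```

For 10 items and 4 ranks this yields (0, 3), (3, 6), (6, 8), and (8, 10): no two ranks overlap and no item is skipped, without any communication at all.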

send and recv

send and recv are the simplest commands for sending data from one process to another. In practice, most MPI code won’t use these functions directly and will instead use higher-level functions that go from one-to-many, from many-to-one, or from many-to-many. The tag argument lets the receiver know exactly which message it is receiving at any given time (since different processes might send different messages at different times), but it may be redundant depending on your use case. Any picklable Python object (e.g. ints or dicts) can be sent without hassle: under the hood, mpi4py informs the receiving process of the payload’s size so that the corresponding buffer is sized appropriately.

Using an int:

assert size == 2 # only designed for two processes

ground_truth = 5
if rank == 0:
    val = ground_truth
    comm.send(val, dest=1, tag=0)
else:
    val = comm.recv(source=0, tag=0)
    assert val == ground_truth

Using a dict:

assert size == 2 # only designed for two processes

ground_truth = {'a' : 1, 'b' : 2}
if rank == 0:
    d = ground_truth
    comm.send(d, dest=1, tag=0)
else:
    d = comm.recv(source=0, tag=0)
    assert d == ground_truth
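
Since the lowercase functions serialize objects with Python’s pickle module by default, a quick way to check whether something can be sent is to try pickling it locally. is_sendable below is a hypothetical helper, not part of mpi4py; it only mirrors the serialization step, not the transfer itself.

```python
import pickle
import threading


def is_sendable(obj):
    """Return True if `obj` can be pickled, as lowercase send/recv require."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


print(is_sendable(5))                 # True
print(is_sendable({'a': 1, 'b': 2}))  # True
print(is_sendable(lambda x: x))       # False: lambdas can't be pickled
print(is_sendable(threading.Lock()))  # False: locks can't be pickled
```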
    

barrier

Barriers are forced synchronization points: every process in the communicator must reach the barrier before any of them can proceed. Like the next few functions we will see, barrier is a “blocking” call, but it is unique in that no data is exchanged. In general, it is used to prevent race conditions on shared resources, for instance by separating a step where processes write to a shared file from a later step where they read it.

comm.barrier()
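
MPI barriers coordinate separate processes, but Python’s threading.Barrier has the same all-must-arrive semantics, which makes for an easy local illustration. This is an analogy using threads, not MPI: no worker logs its read step until every worker has logged its write step.

```python
import threading


def run_write_then_read(n_workers):
    """Have each worker log a 'write', wait at a barrier, then log a 'read'."""
    barrier = threading.Barrier(n_workers)
    lock = threading.Lock()
    log = []

    def worker(rank):
        with lock:
            log.append(("write", rank))
        barrier.wait()  # like comm.barrier(): blocks until all n_workers arrive
        with lock:
            log.append(("read", rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


# The first four entries are always writes, although their order may vary.
print(run_write_then_read(4))
```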

bcast and reduce

Broadcast and reduce are two of the most important MPI functions, especially when it comes to operating on matrices. Broadcast is a one-to-many function that takes data from a given rank (usually the root) and copies it to all other processes. Reduce is a many-to-one function that collects data from all ranks and combines it on the root with an aggregation operation (sum/product, min/max, logical/bitwise operations, etc.). You can even define custom operations to apply during the reduction, but as a rule of thumb, make sure they are associative and commutative, since the implementation is free to group (and, for commutative operations, reorder) the partial results.

Using an int:

val = None
if rank == 0:
    val = 0
val = comm.bcast(val, root=0)
sumRanks = comm.reduce(rank + val, op=MPI.SUM, root=0)
if rank == 0:
    assert sumRanks == sum(range(size)) # Only check on the root node

Using a dict:

# MPI.SUM doesn't support dicts
def dict_reduce(dict1, dict2):
    return {'v' : dict1['v'] + dict2['v']}

d = {}
if rank == 0:
    d = {'v' : 0}
d = comm.bcast(d, root=0)
d['v'] = rank
sumRanks = comm.reduce(d, op=dict_reduce, root=0)
if rank == 0:
    assert sumRanks['v'] == sum(range(size)) # Only check on the root node
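
To see why associativity matters, the sketch below combines the same values in two ways: one at a time in rank order, and pairwise as a binary tree, which is one grouping an MPI implementation might use. These are plain-Python stand-ins, not mpi4py internals. Addition gives the same result either way; subtraction does not.

```python
import operator
from functools import reduce


def linear_reduce(op, values):
    # ((v0 op v1) op v2) op v3 ...: combine in rank order, one value at a time
    return reduce(op, values)


def tree_reduce(op, values):
    # Combine neighbouring pairs level by level, like a binary reduction tree
    level = list(values)
    while len(level) > 1:
        nxt = [op(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            nxt.append(level[-1])  # an odd element carries over to the next level
        level = nxt
    return level[0]


vals = [0, 1, 2, 3]
print(linear_reduce(operator.add, vals), tree_reduce(operator.add, vals))  # 6 6
print(linear_reduce(operator.sub, vals), tree_reduce(operator.sub, vals))  # -6 0
```

A custom operation like dict_reduce above is safe because it just adds the 'v' fields, so any grouping or ordering produces the same sum.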
    

gather and scatter

The other key data synchronization functions are gather and scatter. gather collects one data object from each process onto the root, as a list ordered by rank, while scatter takes a list on the root whose length equals the number of processes and sends the i-th object to rank i.

Using an int:

data_obj = rank
gathered_objs = comm.gather(data_obj, root=0)
if rank == 0:
    assert type(gathered_objs) is list
    assert sum(gathered_objs) == sum(range(size))

objs_to_scatter = None
if rank == 0:
    objs_to_scatter = [val ** 2 for val in gathered_objs]
new_data_obj = comm.scatter(objs_to_scatter, root=0)
assert new_data_obj == rank ** 2

Using a dict:

data_obj = {'a' : 0, 'b' : rank}
gathered_objs = comm.gather(data_obj, root=0)
if rank == 0:
    assert type(gathered_objs) is list
    assert [d['b'] for d in gathered_objs] == list(range(size))

scattered_objs = None
if rank == 0:
    # Build new dicts rather than mutating the gathered ones in place
    scattered_objs = [{**d, 'b' : d['b'] ** 2} for d in gathered_objs]
new_data_obj = comm.scatter(scattered_objs, root=0)
assert new_data_obj['b'] == rank ** 2
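
scatter needs exactly one object per process, so data whose length isn’t a multiple of size has to be split unevenly first. One way to do that on the root is numpy’s array_split, whose pieces differ in length by at most one. prepare_scatter below is a hypothetical helper, and the scatter/gather round trip is simulated locally with plain lists rather than a real communicator.

```python
import numpy as np


def prepare_scatter(data, size):
    """Split `data` into exactly `size` lists, suitable for comm.scatter on the root."""
    return [piece.tolist() for piece in np.array_split(np.asarray(data), size)]


size = 4
pieces = prepare_scatter(list(range(10)), size)
print([len(p) for p in pieces])  # [3, 3, 2, 2]

# Simulate each rank squaring its piece, then the root gathering results in rank order.
gathered = [[x ** 2 for x in piece] for piece in pieces]
flat = [x for piece in gathered for x in piece]
print(flat)
```

The lowercase scatter accepts pieces of different lengths because each object is pickled separately; the capitalized Scatter expects equal-sized blocks, and Scatterv is the variant for uneven ones.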
    

all-

This prefix can be applied to reduce and gather (giving allreduce and allgather) to indicate that the result goes to every process rather than just to the root. There is no allscatter as such; its counterpart is alltoall, in which every process sends a distinct piece of data to every other process. For now, we’ll just look at allreduce:

Using an int:

val = None
if rank == 0:
    val = 0
val = comm.bcast(val, root=0)
sumRanks = comm.allreduce(rank + val, op=MPI.SUM) # no longer pass in a root
assert sumRanks == sum(range(size))

Using a dict:

# MPI.SUM doesn't support dicts
def dict_reduce(dict1, dict2):
    return {'v' : dict1['v'] + dict2['v']}

d = {}
if rank == 0:
    d = {'v' : 0}
d = comm.bcast(d, root=0)
d['v'] = rank
sumRanks = comm.allreduce(d, op=dict_reduce) # no longer pass in a root
assert sumRanks['v'] == sum(range(size))
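
The semantics of alltoall are easiest to see as a transpose: if send_lists[r][d] is what rank r sends to rank d, then rank d ends up with send_lists[r][d] from every r. The function below is a hypothetical single-process simulation of those semantics only (it says nothing about how to avoid deadlocks); mpi4py’s real lowercase comm.alltoall takes a sequence of size objects on each rank and returns the list that rank received.

```python
def simulate_alltoall(send_lists):
    """Return recv_lists with recv_lists[dst][src] == send_lists[src][dst]."""
    size = len(send_lists)
    assert all(len(row) == size for row in send_lists), "each rank needs one item per peer"
    return [[send_lists[src][dst] for src in range(size)] for dst in range(size)]


# Same pattern as the challenge below: rank r sends r * 10 + i to rank i.
send = [[r * 10 + i for i in range(3)] for r in range(3)]
print(simulate_alltoall(send))  # [[0, 10, 20], [1, 11, 21], [2, 12, 22]]
```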
    

It’s worth considering the communication cost of all- operations. In a naive reduce, each process performs a single send and the root performs n receives, for O(n) messages in total. In a naive allreduce, however, each process sends its value to every other process and receives a value from each of them, giving O(n^2) messages in total. This is fine in terms of per-process complexity, since each worker still does only a linear amount of work, but we should be aware that the number of “messages” flying through the system scales quadratically with the number of processes. In practice, MPI implementations optimize all- operations (for example with tree- or ring-based algorithms) to reduce the total number of messages and the amount of queueing, which mitigates this effect in large environments.
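
To put numbers on this, the sketch below counts messages for a naive allreduce, where every process sends its value to every other process, and for recursive doubling, one classic allreduce algorithm in which each process exchanges its partial result with one partner per round. Both are single-process simulations with hypothetical function names; which algorithm a given MPI library actually uses depends on the implementation, message size, and process count. The recursive doubling sketch assumes a power-of-two process count.

```python
def naive_allreduce(values):
    """Every rank sends its value to every other rank and sums what it receives."""
    n = len(values)
    messages = n * (n - 1)
    return [sum(values)] * n, messages


def recursive_doubling_allreduce(values):
    """Sum across ranks in log2(n) rounds; each rank exchanges with rank ^ step."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-two process count"
    partial = list(values)
    messages = 0
    step = 1
    while step < n:
        # Each rank sends its partial sum to its partner and adds the one it receives.
        partial = [partial[rank] + partial[rank ^ step] for rank in range(n)]
        messages += n  # one message sent per rank per round
        step *= 2
    return partial, messages


for n in (2, 8, 64):
    vals = list(range(n))
    naive, naive_msgs = naive_allreduce(vals)
    rd, rd_msgs = recursive_doubling_allreduce(vals)
    print(f"n={n}: naive {naive_msgs} messages, recursive doubling {rd_msgs}, same result: {naive == rd}")
```

For 8 processes that is 56 messages versus 24, and the gap widens as n grows (n(n - 1) versus n log2 n).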

mpi4py Idiosyncrasies

We’ve now covered all of the basic functions, but mpi4py also includes vectorized (capitalized) versions of them, such as Send, Bcast, and Gather, that operate on buffer-like objects such as numpy arrays instead of pickling Python objects (there is even a capitalized Barrier, despite there being no data transfer). The vectorized API also differs in that, whereas the lowercase data transfer functions return values, the vectorized implementations write into a buffer that you pass in. It’s also very important to pay close attention to the dtypes of these numpy arrays, or you can hit an error when the receiving side reads from its buffer (see the Send/Recv Error example). For those curious, the vectorized mpi4py API is closer to what you would see in C or C++. To the degree possible, it’s best practice to use either only the vectorized or only the non-vectorized functions for code sustainability.

Send/Recv:

assert size == 2 # only designed for two processes

ground_truth = np.arange(5, dtype='i')
if rank == 0:
    send_arr = ground_truth
    comm.Send(send_arr, dest=1, tag=0)
else:
    recv_arr = np.empty(5, dtype='i')
    comm.Recv(recv_arr, source=0, tag=0)
    assert np.all(recv_arr == ground_truth)

Send/Recv Error:

assert size == 2 # only designed for two processes

ground_truth = np.arange(5, dtype='l')   # Sending array<long> (64-bit on Linux/macOS)
if rank == 0:
    send_arr = ground_truth
    comm.Send(send_arr, dest=1, tag=0)
else:
    recv_arr = np.empty(5, dtype='i')    # Expecting array<int>
    comm.Recv(recv_arr, source=0, tag=0) # Hits an MPI_ERR_TRUNCATE exception
    assert np.all(recv_arr == ground_truth)

Bcast/Reduce:

arr = np.empty(2)
if rank == 0:
    arr = np.arange(2, dtype='d')
comm.Bcast(arr, root=0)
maxRanks = np.zeros(2)
comm.Reduce(rank + arr, maxRanks, op=MPI.MAX, root=0) # Using MAX reduce this time
if rank == 0:
    assert np.all(maxRanks == arr + size - 1)
else:
    assert np.all(maxRanks == 0) # the receive buffer is only written on the root

Gather/Scatter:

data_arr = np.zeros(3) + rank
gathered_arrs = np.empty([size, 3])
comm.Gather(data_arr, gathered_arrs, root=0)
if rank == 0:
    assert len(gathered_arrs.shape) == 2
    assert np.all(gathered_arrs.T == np.arange(size))

scattered_objs = None
new_data_arr = np.empty(3)
if rank == 0:
    scattered_objs = gathered_arrs * np.arange(3)
comm.Scatter(scattered_objs, new_data_arr, root=0)
assert np.all(new_data_arr == data_arr * np.arange(3))

Allreduce:

arr = np.empty(2)
if rank == 0:
    arr = np.arange(2, dtype='d')
comm.Bcast(arr, root=0)
maxRanks = np.zeros(2)
comm.Allreduce(rank + arr, maxRanks, op=MPI.MAX) # Using MAX reduce this time
assert np.all(maxRanks == arr + size - 1)
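
The Send/Recv Error example fails because the receive buffer is smaller in bytes than the incoming message: 5 C longs are 40 bytes on a typical 64-bit Linux or macOS machine, while 5 C ints are 20. A cheap local sanity check is to compare dtypes and byte sizes before communicating. check_recv_buffer is a hypothetical helper, not part of mpi4py, and the example uses explicit fixed-width dtypes so the result doesn’t depend on the platform’s size of long.

```python
import numpy as np


def check_recv_buffer(send_arr, recv_arr):
    """Return a list of problems that would make recv_arr unsuitable for send_arr."""
    problems = []
    if send_arr.dtype != recv_arr.dtype:
        problems.append(f"dtype mismatch: {send_arr.dtype} vs {recv_arr.dtype}")
    if recv_arr.nbytes < send_arr.nbytes:
        problems.append(f"buffer too small: {recv_arr.nbytes} < {send_arr.nbytes} bytes")
    return problems


send_arr = np.arange(5, dtype=np.int64)
print(check_recv_buffer(send_arr, np.empty(5, dtype=np.int32)))  # both problems reported
print(check_recv_buffer(send_arr, np.empty(5, dtype=np.int64)))  # []
```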
    

Blocking and non-blocking calls

Until now, we’ve focused on blocking calls like Send and Recv, but there are also non-blocking functions like Isend and Irecv (and corresponding pickle-based functions, isend and irecv) that return immediately (the I stands for “immediate”). These functions return an MPI.Request object, and the communication proceeds in the background while the calling process continues with other work. Calling req.Wait() on a request (or MPI.Request.Waitall on a list of them) returns once the transfer is complete and the buffer (on either the send or receive side) is safe to use again. You also don’t need to pair Isend with Irecv: you can combine a blocking send with a non-blocking receive (or vice versa) if that makes the most sense for your program. Here is a quick example of a non-blocking exchange between two processes.

assert size == 2

buf_size = 1000

if rank == 0:
    arr = np.zeros(buf_size) + 10
    req = comm.Isend(arr, dest=1)
else:
    arr = np.empty(buf_size)
    req = comm.Irecv(arr, source=0)

req.Wait()

if rank == 1:
    assert np.all(arr == 10)

Just as with the point-to-point functions we just looked at, there are also non-blocking versions of the collective functions. I’ll use Igather for this example, but you can do the same thing with the other collectives we’ve seen. One practical scenario is a parallelized process doing heavy work that also needs to sync a relatively large amount of data with the other nodes (this comes up frequently with large matrix multiplications), where the communication can overlap with the computation. Another is when each node does a slightly different amount of work, so that nodes take different amounts of time to finish each stage. If a single node is consistently the bottleneck, non-blocking calls won’t help much and there’s little you can do other than adding more nodes. However, if there’s some variability in run time (which happens quite frequently in practice), non-blocking communication can reduce the idle time of the other processes. We simulate that here by having each process sleep during a different iteration.

import time
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()

def test(comm, func):
    rank = comm.Get_rank()
    size = comm.Get_size()

    arr_size = 10000

    res = None
    if rank == 0:
        res = np.zeros([size,size,arr_size], dtype='f')

    start = MPI.Wtime()

    requests = []
    for i in range(size):
        if rank == i:
            time.sleep(1)
        arr = np.zeros(arr_size, dtype='f') + rank * i
        if rank == 0:
            requests.append(func(arr, res[i,:,:], root=0))
        else:
            requests.append(func(arr, None, root=0))

    if requests[0] is not None:
        MPI.Request.Waitall(requests)

    end = MPI.Wtime()

    if rank == 0:
        print(f'({func.__name__}) Total time = {end - start}')
        ranks = np.arange(size).reshape(1,-1)
        assert np.all(res == (ranks.T @ ranks).reshape(size,size,1)) # res[i, r] == i * r

test(comm, comm.Igather)
test(comm, comm.Gather)

Running the above code with 8 processes takes ~1s for the non-blocking routine but ~8s for the blocking version. There’s a reason I make sure the data is sufficiently large here, but I’ll save that for the next post.
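
One way to reason about these timings is an idealized model that ignores communication cost entirely. With a blocking Gather, the root can’t finish iteration i until the slowest process has contributed, so its time is roughly the sum over iterations of the largest delay. With Igather, each process only pays for its own delays and everyone waits once at the end, so the time is roughly the largest per-process total. This is a back-of-the-envelope sketch, not a measurement, and estimate_times is a hypothetical name.

```python
def estimate_times(delays):
    """delays[rank][i] is how long `rank` works before contributing in iteration i.

    Returns (blocking, non_blocking) estimates of the root's total time,
    ignoring the cost of the messages themselves.
    """
    n_iters = len(delays[0])
    blocking = sum(max(row[i] for row in delays) for i in range(n_iters))
    non_blocking = max(sum(row) for row in delays)
    return blocking, non_blocking


# As in the example above: with 8 processes, rank r sleeps 1s only in iteration r.
size = 8
delays = [[1.0 if i == r else 0.0 for i in range(size)] for r in range(size)]
print(estimate_times(delays))  # (8.0, 1.0)
```

The model also captures the limit mentioned earlier: if the same process is slow in every iteration, both estimates grow together and non-blocking calls buy very little.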

Challenge

Before moving on to the next article, I strongly encourage you to try the following exercise. It will only take a few lines of code but will further help to cement your understanding of how to safely and successfully implement MPI code while avoiding deadlocks. Make sure your implementation works with varying core counts (from 2 to the max number supported in your environment). The challenge has 3 parts:

  1. Implement a scalable alltoall function WITH a deadlock while only using Send/Recv for sufficiently large messages
  2. Correct the deadlock from part 1 while still only using Send/Recv
  3. Using your solution to part 1, is it possible to resolve the deadlock by instead using Isend/Recv?

You may use the following scaffolding code if you’d like. The payload size must be scalable in this example for reasons I will explain next time, so make sure to test your solutions with varying scaling_factor values. I’ll post sample solutions to these challenges at the beginning of the next blog post.

# mpiexec -n 2 python challenge1.py

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

def alltoall(comm, values):
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Enter Solution Here
    pass

scaling_factor = 10000
values = [np.array([rank * 10 + i] * scaling_factor, dtype='l') for i in range(size)]
new_values = alltoall(comm, values)
ground_truth = [[rank + i * 10] * scaling_factor for i in range(size)]
assert np.all([np.all(v1 == v2) for v1,v2 in zip(new_values, ground_truth)])

Conclusion

This is just a taste of the functions available through the MPI API. If you’re interested in learning about more advanced and niche functions, take a look at the Open MPI or mpi4py documentation! There are plenty more buffered and specialized functions available for all of your potential use cases. Next time, we’ll go over the solutions to the challenge problems and let them guide our dissection of the nuances of blocking and non-blocking functions. We’ll also learn more about some of the caveats I mentioned today (namely why the size of the message has a much bigger impact on code design than you might expect) as well as a few more variations of the send function. Thanks for joining me, and I hope you learned something new and are excited to dive even deeper next time!




© Copyright 2025 Greg McCord. Powered by Jekyll with al-folio. Hosted by GitHub Pages.