Create a grid of plots in Matplotlib

Luc B.

When preparing data for a report, plots need to be combined into concise figures that let readers quickly survey the relevant data. This is often achieved by including multiple plots in a single figure. Here we demonstrate how to create a grid of plots.

Note that this article only addresses uniform grids of plots. To create plots that span several rows or columns, check out this article.

Code Example

Use the plt.subplots() function to create an Axes object for each plot in a grid of any size. Its first two arguments are the number of rows and the number of columns in the grid.

Note that this code uses Matplotlib's object-oriented interface. While the object-oriented interface is preferred for complex figures like this, the second code example below demonstrates how to create the same figure with the pyplot interface.

import matplotlib.pyplot as plt
import numpy as np

# Create a 2x2 grid of plots
fig, axs = plt.subplots(2, 2, constrained_layout=True)

x = np.linspace(0, 1)

# Modify top-left plot
axs[0,0].set_title("Top Left")
axs[0,0].plot(x, x)

# Modify top-right plot
axs[0,1].set_title("Top Right")
axs[0,1].plot(x, x**2)

# Modify bottom-left plot
axs[1,0].set_title("Bottom Left")
axs[1,0].plot(x, np.sin(3*x))

# Modify bottom-right plot
axs[1,1].set_title("Bottom Right")
axs[1,1].plot(x, 1/(1+x))

plt.show()

plt.subplots() returns a Figure object and a 2D NumPy array of Axes objects. The array is indexed as [row, column]: element [0, 0] is the top-left plot, incrementing the first index moves down one row, and incrementing the second index moves right one column. To create a grid of different dimensions, pass different integers to plt.subplots(). If either dimension is 1, the array is squeezed to 1D (or to a single Axes for a 1x1 grid) unless you also pass squeeze=False.
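
As a sketch of these indexing rules, the example below builds a 2x3 grid (a size chosen only for illustration) and titles each cell with its [row, column] index, then reads the titles back through axs.flat to show the row-major order. The last part shows the single-row case, where plt.subplots() returns a 1D array by default and squeeze=False keeps the 2D shape. The matplotlib.use("Agg") line is an assumption added only so the sketch runs without a display; remove it and call plt.show() to view the figures.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)

# A 2x3 grid: axs is a 2D NumPy array of Axes with shape (rows, columns)
fig, axs = plt.subplots(2, 3, constrained_layout=True)
print(axs.shape)  # (2, 3)

nrows, ncols = axs.shape
for row in range(nrows):
    for col in range(ncols):
        ax = axs[row, col]  # row 0 is the top row, col 0 is the left column
        ax.set_title(f"[{row}, {col}]")
        ax.plot(x, x ** (row * ncols + col + 1))

# axs.flat walks the grid in row-major order: the whole top row, then the next row
titles = [ax.get_title() for ax in axs.flat]
print(titles)  # ['[0, 0]', '[0, 1]', '[0, 2]', '[1, 0]', '[1, 1]', '[1, 2]']

# With a single row, plt.subplots() squeezes the result down to a 1D array...
fig2, row_axs = plt.subplots(1, 3)
print(row_axs.shape)  # (3,)

# ...unless squeeze=False is passed, which keeps [row, col] indexing available
fig3, grid_axs = plt.subplots(1, 3, squeeze=False)
print(grid_axs.shape)  # (1, 3)
```

Passing squeeze=False is a simple way to write loops that index axs[row, col] regardless of how many rows or columns the grid has.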

constrained_layout=True tells Matplotlib to automatically adjust the size and position of each subplot so that titles, axis labels, and tick labels don't overlap.
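
One way to see what constrained_layout does is to measure the space between neighboring plots after the figure is drawn. The sketch below is an illustrative check rather than part of the original example: it compares the tight bounding box of the top-left Axes (which includes its tick labels and x-axis label) with that of the bottom-left Axes (which includes its title). With constrained layout enabled, the gap between them is expected to be non-negative. The extra x-axis labels and the Agg backend are assumptions added for the demonstration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)
fig, axs = plt.subplots(2, 2, constrained_layout=True)
for ax, title in zip(axs.flat, ["Top Left", "Top Right", "Bottom Left", "Bottom Right"]):
    ax.set_title(title)
    ax.set_xlabel("x")  # an extra label makes overlaps more likely without a layout engine
    ax.plot(x, x)

# The layout engine runs when the figure is drawn, so draw before measuring
fig.canvas.draw()
renderer = fig.canvas.get_renderer()

# Tight bounding boxes include titles, axis labels, and tick labels,
# in display coordinates where y increases upward
top = axs[0, 0].get_tightbbox(renderer)
bottom = axs[1, 0].get_tightbbox(renderer)

# A non-negative gap means the top plot's labels end above the bottom plot's title
gap = top.y0 - bottom.y1
print(f"vertical gap between rows: {gap:.1f} px")
```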

More Examples

Pyplot Interface

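While the object-oriented interface used above is the preferred approach to creating a grid of plots, pyplot offers the same functionality through the plt.subplot() function. Each call to plt.subplot() makes one cell of the grid the current Axes, and subsequent pyplot calls such as plt.title() and plt.plot() apply to that cell.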

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)

# Modify top-left plot
plt.subplot(221)
plt.title("Top Left")
plt.plot(x, x)

# Modify top-right plot
plt.subplot(222)
plt.title("Top Right")
plt.plot(x, x**2)

# Modify bottom-left plot
plt.subplot(223)
plt.title("Bottom Left")
plt.plot(x, np.sin(3*x))

# Modify bottom-right plot
plt.subplot(224)
plt.title("Bottom Right")
plt.plot(x, 1/(1+x))

# Recompute the plot layout so everything fits
plt.tight_layout()

plt.show()

The argument to plt.subplot() is a three-digit number. The first digit is the number of rows, the second digit is the number of columns, and the third digit is the index of the subplot to draw on.

The grid index starts at 1 and increases from left to right across each row, then continues at the start of the next row down.

For example, passing 221 to plt.subplot() divides the figure into 2 rows and 2 columns, and all plotting operations that follow apply to the top-left plot (row 1, column 1). Passing 223 uses the same divisions, but plotting operations apply to the bottom-left plot (row 2, column 1).
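
plt.subplot() also accepts the three numbers as separate arguments, so plt.subplot(2, 2, 3) selects the same cell as plt.subplot(223). The separate-argument form is needed once the grid has more than 9 cells, since an index such as 10 can't fit in a single digit. The sketch below uses an arbitrary 3x4 grid to confirm the ordering described above by reading each Axes' position back from its SubplotSpec, whose row and column numbers are 0-based (unlike the 1-based subplot index). The Agg backend line is an assumption added only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)

# plt.subplot(2, 2, 3) is the separate-argument spelling of plt.subplot(223)
plt.figure()
ax = plt.subplot(2, 2, 3)
spec = ax.get_subplotspec()
# SubplotSpec positions are 0-based, so the bottom-left cell is row 1, column 0
bottom_left = (spec.rowspan.start, spec.colspan.start)
print(bottom_left)  # (1, 0)

# A 3x4 grid has 12 cells, so indices 10 to 12 need the separate-argument form
plt.figure()
nrows, ncols = 3, 4
positions = []
for index in range(1, nrows * ncols + 1):  # 1-based: left to right, then top to bottom
    ax = plt.subplot(nrows, ncols, index)
    ax.set_title(f"Subplot {index}")
    ax.plot(x, x ** index)
    spec = ax.get_subplotspec()
    positions.append((spec.rowspan.start, spec.colspan.start))

plt.tight_layout()
print(positions[:5])  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0)]
```

The positions show that index 5 wraps to the first column of the second row, matching the left-to-right, then top-to-bottom numbering.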