## Create a grid of plots in Matplotlib

By Luc B.

When preparing data for a report, plots need to be coalesced into concise figures that let readers quickly survey the relevant data. This is often achieved by including multiple plots in a single figure. Here we demonstrate how to create a grid of plots.

### Code Example

Use the `plt.subplots()` function to create a Figure together with an Axes object for each plot in a grid of any size. Its first two arguments, `nrows` and `ncols`, set the number of rows and columns in the grid.

Note that this code uses Matplotlib's object-oriented interface. The object-oriented interface is preferred for multi-plot figures like this one, but the second code example below shows how to create the same figure with the pyplot interface.

```python
import matplotlib.pyplot as plt
import numpy as np

# Create a 2x2 grid of plots
fig, axs = plt.subplots(2, 2, constrained_layout=True)

x = np.linspace(0, 1)

# Modify top-left plot
axs[0, 0].set_title("Top Left")
axs[0, 0].plot(x, x)

# Modify top-right plot
axs[0, 1].set_title("Top Right")
axs[0, 1].plot(x, x**2)

# Modify bottom-left plot
axs[1, 0].set_title("Bottom Left")
axs[1, 0].plot(x, np.sin(3*x))

# Modify bottom-right plot
axs[1, 1].set_title("Bottom Right")
axs[1, 1].plot(x, 1/(1+x))

plt.show()
```

`plt.subplots()` returns a Figure object and a 2D NumPy array of Axes objects. The array is indexed as `axs[row, column]`: element `[0, 0]` is the top-left plot, incrementing the first index moves down one row, and incrementing the second index moves right one column. To create a grid with different dimensions, pass different integers to `plt.subplots()`.
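Because `axs` is a NumPy array, a larger grid can be filled in a loop instead of indexing each cell by hand. The sketch below assumes a 2x3 grid with made-up sine curves purely for illustration; `axs.flat` visits the Axes row by row (left to right, then top to bottom). It also prints the array shapes, including a single-row grid, where the default `squeeze=True` returns a 1D array unless `squeeze=False` is passed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)

# 2 rows x 3 columns -> axs is a 2D NumPy array of Axes with shape (2, 3)
fig, axs = plt.subplots(2, 3, constrained_layout=True)

# axs.flat walks the grid row by row: [0, 0], [0, 1], [0, 2], [1, 0], ...
for i, ax in enumerate(axs.flat):
    ax.set_title(f"Plot {i}")
    ax.plot(x, np.sin((i + 1) * x))  # illustrative data only

# With a single row, the default squeeze=True returns a 1D array...
fig_row, axs_row = plt.subplots(1, 3)
# ...while squeeze=False always returns a 2D array
fig_2d, axs_2d = plt.subplots(1, 3, squeeze=False)

print(axs.shape, axs_row.shape, axs_2d.shape)  # (2, 3) (3,) (1, 3)
```

In an interactive session, drop the `matplotlib.use("Agg")` line and finish with `plt.show()` as in the main example.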

`constrained_layout=True` tells Matplotlib to automatically adjust the size and position of the subplots so that titles, axis labels, and tick labels do not overlap.
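In recent Matplotlib versions (3.6 and later), the same behaviour can also be requested with `layout="constrained"`; the Figure documentation now discourages the `constrained_layout` parameter in favour of this spelling, although `constrained_layout=True` still works. The sketch below assumes such a version, uses deliberately long titles and labels chosen only for illustration, and checks which layout engine is attached to the figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.layout_engine import ConstrainedLayoutEngine  # Matplotlib >= 3.6

# layout="constrained" requests the same engine as constrained_layout=True
fig, axs = plt.subplots(2, 2, layout="constrained")

for ax in axs.flat:
    # Long text that would likely collide with neighbouring subplots without a layout engine
    ax.set_title("A fairly long subplot title")
    ax.set_xlabel("x axis label")
    ax.set_ylabel("y axis label")

# Constrained layout is computed at draw time, so force a draw here
fig.canvas.draw()

engine = fig.get_layout_engine()
print(type(engine).__name__)
```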

### More Examples

#### Pyplot Interface

While the object-oriented interface used above is the preferred approach to creating a grid of plots, pyplot offers the same functionality through the `plt.subplot()` function (note the singular *subplot*).

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1)

# Modify top-left plot
plt.subplot(221)
plt.title("Top Left")
plt.plot(x, x)

# Modify top-right plot
plt.subplot(222)
plt.title("Top Right")
plt.plot(x, x**2)

# Modify bottom-left plot
plt.subplot(223)
plt.title("Bottom Left")
plt.plot(x, np.sin(3*x))

# Modify bottom-right plot
plt.subplot(224)
plt.title("Bottom Right")
plt.plot(x, 1/(1+x))

# Recompute the plot layout so nothing overlaps
plt.tight_layout()

plt.show()
```

The argument to `plt.subplot()` is a three-digit integer. The first digit is the number of rows, the second digit is the number of columns, and the third digit is the index of the subplot to select in that grid.
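`plt.subplot()` also accepts the rows, columns, and index as three separate integers, so `plt.subplot(223)` and `plt.subplot(2, 2, 3)` select the same cell. The three-digit shorthand only works while every value is a single digit, so a grid with more than nine cells needs the three-argument form. The sketch below checks this by reading back each Axes' `SubplotSpec`, whose `rowspan` and `colspan` are zero-based `range` objects.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

# Shorthand and three-argument forms, each in its own figure
plt.figure()
ax_short = plt.subplot(223)
plt.figure()
ax_long = plt.subplot(2, 2, 3)

spec_short = ax_short.get_subplotspec()
spec_long = ax_long.get_subplotspec()

# Zero-based ranges: cell 3 of a 2x2 grid is row 1, column 0
print(spec_short.rowspan, spec_short.colspan)  # range(1, 2) range(0, 1)
print(spec_long.rowspan, spec_long.colspan)    # range(1, 2) range(0, 1)

# Index 10 cannot be written as a three-digit code, so use three arguments
plt.figure()
ax_ten = plt.subplot(3, 4, 10)
spec_ten = ax_ten.get_subplotspec()
print(spec_ten.rowspan, spec_ten.colspan)  # third row, second column
```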

The grid index starts at 1 in the top-left corner and increases from left to right along each row, continuing at the left end of the next row.
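Since the numbering is row-major, a 1-based index can be converted to a (row, column) position with integer division and remainder. The helper below, `subplot_position`, is hypothetical (it is not part of Matplotlib) and only spells out that arithmetic for a three-digit code.

```python
def subplot_position(code: int) -> tuple[int, int]:
    """Hypothetical helper: map a three-digit subplot code such as 223
    to a 1-based (row, column) position in the grid."""
    nrows, ncols, index = code // 100, (code // 10) % 10, code % 10
    if not 1 <= index <= nrows * ncols:
        raise ValueError(f"index {index} is outside a {nrows}x{ncols} grid")
    # Indices fill each row left to right before moving down a row
    row = (index - 1) // ncols + 1
    col = (index - 1) % ncols + 1
    return row, col


for code in (221, 222, 223, 224):
    print(code, subplot_position(code))
```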

For example, passing `221` to `plt.subplot()` divides the figure into 2 rows and 2 columns, and all plotting operations that follow apply to the plot at position (1, 1), the first row and first column. Passing `223` uses the same divisions, but plotting operations then modify the plot at position (2, 1), the second row and first column.
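Each `plt.subplot()` call returns the Axes it selects and makes it pyplot's *current* Axes, which is the one that `plt.title()`, `plt.plot()`, and similar functions act on. The sketch below makes that visible with `plt.gca()` ("get current Axes"); the titles are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

plt.figure()
ax_tl = plt.subplot(221)   # top-left becomes the current Axes
plt.title("Top Left")      # so this title is set on ax_tl
is_current_tl = plt.gca() is ax_tl

ax_bl = plt.subplot(223)   # bottom-left is now the current Axes
plt.title("Bottom Left")   # this title goes to ax_bl, not ax_tl
is_current_bl = plt.gca() is ax_bl

print(ax_tl.get_title(), "|", ax_bl.get_title())  # Top Left | Bottom Left
print(is_current_tl, is_current_bl)               # True True
```

Because `plt.subplot()` returns the Axes, you can also call methods such as `ax_tl.set_title()` on it directly, which bridges to the object-oriented style used in the first example.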