Subplots in Matplotlib
Learn how to add subplots to a Matplotlib figure in Python.

Topics Covered
- What are subplots in Matplotlib?
- Syntax of subplots in Matplotlib
- How to add subplots in Matplotlib?
- How to access subplots in Matplotlib?
- How to have shared axis for Matplotlib subplots?
What are subplots in Matplotlib?
Subplots in Matplotlib are multiple plots placed inside a single Matplotlib figure. Each subplot is its own Axes object, and placing subplots side by side lets you compare different aspects of a problem in one view.
To build subplots, use the subplots() function from the matplotlib.pyplot module (conventionally imported as plt).
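As a minimal sketch, calling subplots() with its default arguments creates a figure containing a single subplot and returns the Figure together with its one Axes object. The sine data and the title are illustrative only.

```python
import matplotlib.pyplot as plt
import numpy as np

# With the defaults (nrows=1, ncols=1), subplots() returns a Figure and a single Axes
fig, ax = plt.subplots()

# Illustrative data: one period of a sine wave
x = np.linspace(0, 2 * np.pi, 100)
ax.plot(x, np.sin(x))
ax.set_title('A single subplot')
plt.show()
```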
Syntax of subplots in Matplotlib
matplotlib.pyplot.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)
Parameters
- nrows, ncols: The number of rows and columns of the subplot grid. Both default to 1.
- sharex, sharey: Control sharing of the x-axis (sharex) or y-axis (sharey) among subplots. They accept a boolean or one of 'none', 'all', 'row', 'col'; True is equivalent to 'all' and False to 'none'.
- squeeze: An optional boolean, True by default. If True, extra dimensions are squeezed out of the returned array of Axes: a 1x1 grid returns a single Axes object, and a grid with a single row or column returns a 1D array. If False, a 2D array is always returned.
- num: A pyplot.figure keyword (passed through **fig_kw) that sets the figure number or label.
- subplot_kw: A dict of keywords passed to the add_subplot call used to create each subplot.
- gridspec_kw: A dict of keywords passed to the GridSpec constructor used to create the grid the subplots are placed on.
- **fig_kw: All additional keyword arguments, such as figsize, are passed to the pyplot.figure call.
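The sketch below exercises subplot_kw, gridspec_kw and an extra figure keyword together. The specific values (a polar projection for every subplot, a 2:1 column width ratio, a 6x3 inch figure) are arbitrary choices made for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(
    nrows=1, ncols=2,
    subplot_kw={'projection': 'polar'},    # forwarded to add_subplot for every subplot
    gridspec_kw={'width_ratios': [2, 1]},  # forwarded to GridSpec: left column twice as wide
    figsize=(6, 3),                        # extra keywords are forwarded to pyplot.figure
)

theta = np.linspace(0, 2 * np.pi, 200)
axs[0].plot(theta, 1 + np.cos(theta))      # a cardioid in polar coordinates
axs[1].plot(theta, np.ones_like(theta))    # a unit circle
plt.show()
```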
Returns: This function returns the following values.
- fig: The Figure object.
- ax: A single Axes object, or an array of Axes objects if more than one subplot was created. The dimensions of the array are controlled by squeeze.
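To make the effect of squeeze concrete, this sketch creates a few grids and prints the shape of the returned array of Axes.

```python
import matplotlib.pyplot as plt

# 1x1 grid: squeeze=True (the default) returns a single Axes object, not an array
fig1, ax_single = plt.subplots()

# 1x3 grid: the row dimension is squeezed out, leaving a 1D array
fig2, ax_row = plt.subplots(nrows=1, ncols=3)

# 2x2 grid: nothing to squeeze, so a 2D array is returned
fig3, ax_grid = plt.subplots(nrows=2, ncols=2)

# squeeze=False: always a 2D array, even for a single subplot
fig4, ax_2d = plt.subplots(squeeze=False)

print(ax_row.shape, ax_grid.shape, ax_2d.shape)  # prints (3,) (2, 2) (1, 1)
```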
How to add subplots in Matplotlib?
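Here is a simple example that builds a figure with three subplots in a single row. Each subplot shows the same fruit data as a bar chart, a scatter plot and a line plot, and sharey=True makes them share one y-axis.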
import matplotlib.pyplot as plt
# Sample data: fruit names and their counts
data = {'apple': 10, 'orange': 15, 'lemon': 5, 'lime': 20}
names = list(data.keys())
values = list(data.values())
# One row of three subplots sharing the same y-axis; axs is a 1D array of Axes
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(9, 3), sharey=True)
axs[0].bar(names, values)      # bar chart
axs[1].scatter(names, values)  # scatter plot
axs[2].plot(names, values)     # line plot
fig.suptitle('Plotting different aspects of fruits')
plt.show()
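Because axs in a single-row grid is a 1D array of Axes, it can also be unpacked directly into named variables. Here is a sketch of the same figure written that way; the names ax_bar, ax_scatter and ax_line, and the per-subplot titles, are arbitrary additions for illustration.

```python
import matplotlib.pyplot as plt

data = {'apple': 10, 'orange': 15, 'lemon': 5, 'lime': 20}
names = list(data.keys())
values = list(data.values())

# Unpack the 1D array of three Axes into one variable per subplot
fig, (ax_bar, ax_scatter, ax_line) = plt.subplots(nrows=1, ncols=3, figsize=(9, 3), sharey=True)
ax_bar.bar(names, values)
ax_scatter.scatter(names, values)
ax_line.plot(names, values)

# Per-subplot titles make the three views easier to tell apart
ax_bar.set_title('Bar')
ax_scatter.set_title('Scatter')
ax_line.set_title('Line')
fig.suptitle('Plotting different aspects of fruits')
plt.show()
```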

The above example arranges the subplots in a single row, but you can arrange them in any grid you want.
Let's place Matplotlib subplots across 2 rows and 2 columns.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 1000)
# axs is a 2x2 NumPy array of Axes objects, indexed as axs[row, column]
fig, axs = plt.subplots(nrows=2, ncols=2)
axs[0,0].plot(x, np.sin(x), '-b', label='Sine')
axs[0,0].plot(x, np.cos(x), '--r', label='Cosine')
axs[0,0].axis('equal')
axs[0,0].legend(loc='upper left')
axs[0,1].plot(x, np.sin(x), '-b', label='Sine')
axs[0,1].plot(x, np.cos(x), '--r', label='Cosine')
axs[0,1].axis('equal')
axs[0,1].legend(loc='upper right')
axs[1,0].plot(x, np.sin(x), '-b', label='Sine')
axs[1,0].plot(x, np.cos(x), '--r', label='Cosine')
axs[1,0].axis('equal')
axs[1,0].legend(loc='upper left')
axs[1,1].plot(x, np.sin(x), '-b', label='Sine')
axs[1,1].plot(x, np.cos(x), '--r', label='Cosine')
axs[1,1].axis('equal')
axs[1,1].legend(loc='upper right')
plt.show()
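The four blocks in the 2x2 example differ only in the legend location. A more compact sketch iterates over axs.flat, which yields the Axes in row-major order; pairing it with a list of legend locations is just one way to vary a per-subplot setting.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, axs = plt.subplots(nrows=2, ncols=2)

# axs.flat visits axs[0,0], axs[0,1], axs[1,0], axs[1,1] in that order
legend_locs = ['upper left', 'upper right', 'upper left', 'upper right']
for ax, loc in zip(axs.flat, legend_locs):
    ax.plot(x, np.sin(x), '-b', label='Sine')
    ax.plot(x, np.cos(x), '--r', label='Cosine')
    ax.axis('equal')
    ax.legend(loc=loc)

plt.show()
```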

How to access subplots in Matplotlib?
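Accessing subplots is similar to accessing elements of a 2D array. axs[0][0] refers to the first row (index 0) and the first plot in that row (index 0). axs[1][1] refers to the second row (index 1) and the second plot in that row (index 1). Because axs is a NumPy array, the same subplots can also be written as axs[0, 0] and axs[1, 1], which is the form used in the code examples.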
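A quick sketch confirming this indexing on a 2x2 grid, including the equivalent axs[0, 0] form used in the code examples:

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=2)

print(axs.shape)                 # (2, 2): two rows, two columns
print(axs[0][0] is axs[0, 0])    # True: both forms select the first row, first column
print(axs[1][1] is axs[-1, -1])  # True: negative indices count from the end
```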
Let's use this indexing to draw only on the first (top-left) and last (bottom-right) subplots of the 2x2 example above. The other two subplots are left empty.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 1000)
fig, axs = plt.subplots(nrows=2, ncols=2)
# Only axs[0,0] and axs[1,1] are drawn on; axs[0,1] and axs[1,0] stay empty
axs[0,0].plot(x, np.sin(x), '-b', label='Sine')
axs[0,0].plot(x, np.cos(x), '--r', label='Cosine')
axs[0,0].axis('equal')
axs[0,0].legend(loc='upper left')
axs[1,1].plot(x, np.sin(x), '-b', label='Sine')
axs[1,1].plot(x, np.cos(x), '--r', label='Cosine')
axs[1,1].axis('equal')
axs[1,1].legend(loc='upper right')
plt.show()
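In the example above, the top-right and bottom-left subplots are still drawn as empty frames. If they are not needed, one option is to remove those Axes from the figure with Axes.remove(), as in this sketch; hiding them with set_axis_off() is an alternative that keeps them in the figure.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, axs = plt.subplots(nrows=2, ncols=2)

# Draw only on the first (top-left) and last (bottom-right) subplots
for ax, loc in [(axs[0, 0], 'upper left'), (axs[1, 1], 'upper right')]:
    ax.plot(x, np.sin(x), '-b', label='Sine')
    ax.plot(x, np.cos(x), '--r', label='Cosine')
    ax.axis('equal')
    ax.legend(loc=loc)

# Remove the two unused Axes so no empty frames are drawn
axs[0, 1].remove()
axs[1, 0].remove()
plt.show()
```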

How to have shared axis for Matplotlib subplots?
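When subplots display related data, it is often convenient for their axes to be aligned and use the same limits. The subplots() function accepts two arguments for this purpose: sharex and sharey. In the example below, sharex=True makes all four subplots share one x-axis, so only the bottom row shows x tick labels.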
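Besides True and False, sharex and sharey accept 'row' and 'col' for partial sharing. This sketch shares x-axes within each column and y-axes within each row; the two x ranges (0 to 10 and 0 to 100) are arbitrary and chosen only so that the columns differ. Changing the x-limits of one subplot then also changes the other subplot in its column, but not the subplots in the other column.

```python
import matplotlib.pyplot as plt
import numpy as np

x_short = np.linspace(0, 10, 200)
x_long = np.linspace(0, 100, 200)

# 'col': subplots in the same column share x; 'row': subplots in the same row share y
fig, axs = plt.subplots(nrows=2, ncols=2, sharex='col', sharey='row')
axs[0, 0].plot(x_short, np.sin(x_short))
axs[0, 1].plot(x_long, np.sin(x_long))
axs[1, 0].plot(x_short, 10 * np.cos(x_short))
axs[1, 1].plot(x_long, 10 * np.cos(x_long))

# Zooming the top-left subplot also zooms the bottom-left one (same column)
axs[0, 0].set_xlim(0, 5)
plt.show()
```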
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 1000)
# sharex=True: all four subplots share one x-axis
fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True)
axs[0,0].plot(x, np.sin(x), '-b', label='Sine')
axs[0,0].plot(x, np.cos(x), '--r', label='Cosine')
axs[0,0].axis('equal')
axs[0,0].legend(loc='upper left')
axs[0,1].plot(x, np.sin(x), '-b', label='Sine')
axs[0,1].plot(x, np.cos(x), '--r', label='Cosine')
axs[0,1].axis('equal')
axs[0,1].legend(loc='upper right')
axs[1,0].plot(x, np.sin(x), '-b', label='Sine')
axs[1,0].plot(x, np.cos(x), '--r', label='Cosine')
axs[1,0].axis('equal')
axs[1,0].legend(loc='lower left')
axs[1,1].plot(x, np.sin(x), '-b', label='Sine')
axs[1,1].plot(x, np.cos(x), '--r', label='Cosine')
axs[1,1].axis('equal')
axs[1,1].legend(loc='lower right')
plt.show()

Pylenin has a dedicated YouTube playlist of Matplotlib tutorials. Check out the entire Matplotlib playlist on YouTube.