
If I have the following code:

import seaborn as sns
import matplotlib.pyplot as plt
flights = sns.load_dataset("flights")
flights = flights.pivot(index="month", columns="year", values="passengers")
f,(ax1,ax2,ax3) = plt.subplots(1,3,sharey=True)
g1 = sns.heatmap(flights,cmap="YlGnBu",cbar=False,ax=ax1)
g1.set_ylabel('')
g1.set_xlabel('')
g2 = sns.heatmap(flights,cmap="YlGnBu",cbar=False,ax=ax2)
g2.set_ylabel('')
g2.set_xlabel('')
g3 = sns.heatmap(flights,cmap="YlGnBu",ax=ax3)
g3.set_ylabel('')
g3.set_xlabel('')

This produces three heatmaps, but the third one is narrower than the first two.

How can I adjust the subplots so that the g3 axes is the same width as the g1 and g2 axes? Since I have not added the colorbar to the first two axes, seaborn shrinks the third axes to make room for the colorbar. This is understandable.
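
One way to confirm that the colorbar is what makes the third panel narrower is to rebuild the layout off-screen and compare the panel widths. This is a minimal sketch, assuming the non-interactive Agg backend and using random 12×12 data instead of the flights dataset (so nothing is downloaded); panel_widths is a hypothetical helper name.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is opened
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def panel_widths():
    """Rebuild the question's 1x3 layout and return each heatmap's width."""
    data = np.random.default_rng(0).random((12, 12))  # stand-in for flights
    fig, axes = plt.subplots(1, 3, sharey=True)
    sns.heatmap(data, cbar=False, ax=axes[0])
    sns.heatmap(data, cbar=False, ax=axes[1])
    # No cbar_ax given: matplotlib takes the colorbar's space from this axes
    sns.heatmap(data, ax=axes[2])
    # Widths in figure coordinates (fractions of the figure width)
    widths = [ax.get_position().width for ax in axes]
    plt.close(fig)
    return widths


widths = panel_widths()
print(widths)
```

With the default colorbar placement, widths[2] comes out smaller than the other two, which matches the figure described above.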

I want all three heatmaps to have the same width, with the colorbar placed next to the third one.

Perhaps I need to make a four-panel subplot, with the fourth panel containing only the colorbar?

jwillis0720

1 Answer

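A way to go is indeed to create four axes, where the fourth one contains the colorbar. You can use the cbar_ax argument to tell heatmap in which axes to draw the colorbar, and the gridspec_kw argument to subplots to give the axes good proportions. The problem is then that sharey=True would also share the y scaling with the colorbar axes, so we need to leave sharey off and manually share only the first three axes, e.g. with ax2.sharey(ax1) and ax3.sharey(ax1) (Axes.sharey is available since matplotlib 3.3; the older ax1.get_shared_y_axes().join(ax2, ax3) no longer works in recent versions). Because the axes are no longer shared through subplots, the inner y tick labels are not hidden automatically, so they need to be turned off. Use tick_params for this rather than set_yticks([]): sharey also shares the tick locator, so set_yticks([]) on one axes would remove the ticks from all three.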
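
The manual sharing, and why the inner tick labels should be hidden rather than removed, can be checked in isolation. This is a minimal sketch, assuming matplotlib 3.3 or newer (for Axes.sharey) and the Agg backend; it uses plain axes instead of heatmaps, and the tick positions are arbitrary illustration values.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is needed
import matplotlib.pyplot as plt

fig, (ax1, ax2, axcb) = plt.subplots(
    1, 3, gridspec_kw={"width_ratios": [1, 1, 0.08]})
# Link only ax2 to ax1; axcb (the future colorbar axes) keeps its own y scale
ax2.sharey(ax1)

ax1.set_ylim(12, 0)                # inverted limits, like a 12-row heatmap
shared_limits = ax2.get_ylim()     # follows ax1
colorbar_limits = axcb.get_ylim()  # untouched default limits

ax1.set_yticks([0.5, 1.5, 2.5])
# Hides ax2's tick marks and labels without touching the shared locator
ax2.tick_params(axis="y", left=False, labelleft=False)
ticks_after_hiding = list(ax1.get_yticks())

# sharey() also shares the tick locator, so this clears ax1's ticks as well
ax2.set_yticks([])
ticks_after_clearing = list(ax1.get_yticks())
plt.close(fig)

print(shared_limits, colorbar_limits)
print(len(ticks_after_hiding), len(ticks_after_clearing))
```

ax1 keeps its three ticks after tick_params, but loses them once set_yticks([]) is called on ax2.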

import seaborn as sns
import matplotlib.pyplot as plt
flights = sns.load_dataset("flights")
flights = flights.pivot(index="month", columns="year", values="passengers")
# three equally wide heatmap panels plus a narrow fourth panel for the colorbar
f,(ax1,ax2,ax3, axcb) = plt.subplots(1,4,
            gridspec_kw={'width_ratios':[1,1,1,0.08]})
# share the y axis among the three heatmaps only, not with the colorbar axes
ax2.sharey(ax1)
ax3.sharey(ax1)
g1 = sns.heatmap(flights,cmap="YlGnBu",cbar=False,ax=ax1)
g1.set_ylabel('')
g1.set_xlabel('')
g2 = sns.heatmap(flights,cmap="YlGnBu",cbar=False,ax=ax2)
g2.set_ylabel('')
g2.set_xlabel('')
# hide the repeated month labels and tick marks on the inner heatmaps
g2.tick_params(axis='y', left=False, labelleft=False)
g3 = sns.heatmap(flights,cmap="YlGnBu",ax=ax3, cbar_ax=axcb)
g3.set_ylabel('')
g3.set_xlabel('')
g3.tick_params(axis='y', left=False, labelleft=False)

# may be needed to rotate the ticklabels correctly:
for ax in [g1,g2,g3]:
    ax.tick_params(axis='x', labelrotation=90)
    ax.tick_params(axis='y', labelrotation=0)

plt.show()
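
To check that the four-panel layout really gives the heatmaps equal widths, the same construction can be measured off-screen. This is a minimal sketch, assuming matplotlib 3.3 or newer and the Agg backend, with random 12×12 data standing in for the pivoted flights table.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is opened
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

data = np.random.default_rng(0).random((12, 12))  # stand-in for flights
f, (ax1, ax2, ax3, axcb) = plt.subplots(
    1, 4, gridspec_kw={"width_ratios": [1, 1, 1, 0.08]})
ax2.sharey(ax1)
ax3.sharey(ax1)
sns.heatmap(data, cmap="YlGnBu", cbar=False, ax=ax1)
sns.heatmap(data, cmap="YlGnBu", cbar=False, ax=ax2)
# The colorbar goes into axcb, so no space is taken from ax3
sns.heatmap(data, cmap="YlGnBu", ax=ax3, cbar_ax=axcb)

# Widths in figure coordinates (fractions of the figure width)
widths = [ax.get_position().width for ax in (ax1, ax2, ax3)]
cb_width = axcb.get_position().width
plt.close(f)
print(widths, cb_width)
```

Because cbar_ax is supplied, the three heatmap widths agree up to floating-point rounding, and the colorbar panel is the narrow fourth column.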

ImportanceOfBeingErnest