
Using Matplotlib, I want to plot a 2D heat map. My data is an n-by-n NumPy array whose elements each have a value between 0 and 1. For the (i, j) element of this array, I want to plot a square at the (i, j) coordinate in my heat map, whose color is proportional to the element's value.

How can I do this?

Karnivaurus
  • did you look at the [`matplotlib` gallery](http://matplotlib.org/gallery.html#images_contours_and_fields) at all before posting? There are some good examples using `imshow`, `pcolor` and `pcolormesh` that do what you want – tmdavison Oct 22 '15 at 13:57
  • Possible duplicate of [multi colored Heat Map error Python](http://stackoverflow.com/questions/30068049/multi-colored-heat-map-error-python) – jkalden Oct 23 '15 at 06:52

6 Answers


The imshow() function with parameters interpolation='nearest' and cmap='hot' should do what you want.

import matplotlib.pyplot as plt
import numpy as np

a = np.random.random((16, 16))
plt.imshow(a, cmap='hot', interpolation='nearest')
plt.show()

A sample color map produced by the example code
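Since the question's values lie between 0 and 1, it may help to pin the color scale to that range with `vmin`/`vmax`, so a given color means the same value in every plot rather than being stretched to the data's own min and max. A minimal sketch, assuming a random 16×16 array as stand-in data and the non-interactive Agg backend so it runs headless (the output filename is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

n = 16
a = np.random.random((n, n))  # stand-in data: values in [0, 1)

fig, ax = plt.subplots()
# vmin/vmax fix the color scale to [0, 1] instead of the data's own min/max
im = ax.imshow(a, cmap="hot", interpolation="nearest", vmin=0.0, vmax=1.0)
fig.colorbar(im, ax=ax, label="value")

# imshow draws element (i, j) at row i (vertical) and column j (horizontal),
# with row 0 at the top
ax.set_xlabel("j (column)")
ax.set_ylabel("i (row)")
fig.savefig("heatmap.png")
```

If you want i on the horizontal axis and j on the vertical axis with the origin at the bottom left, passing `a.T` together with `origin="lower"` should give that layout.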

Henry Ecker
P. Camilleri
  • I don't think specifying interpolation is necessary. – miguel.martin Mar 28 '17 at 06:12
  • @miguel.martin as per pyplot's doc: "If interpolation is None (its default value), default to rc image.interpolation". So I think it is necessary to include it. – P. Camilleri Mar 28 '17 at 07:11
  • @P.Camilleri How to scale the X and Y axes? (Change only the numbers, no zoom). – Dole Jul 17 '19 at 18:10
  • Link broken, [new link](https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html?highlight=heatmap) – mins Apr 20 '21 at 12:52
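On the question of changing only the axis numbers: `imshow` accepts an `extent=(left, right, bottom, top)` argument that maps the image onto data coordinates without resampling it, so the tick labels change while every cell stays as it is. A minimal sketch; the ranges 0 to 8 and 0 to 4 are arbitrary example values:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

a = np.random.random((16, 16))

fig, ax = plt.subplots()
# extent=(left, right, bottom, top): the 16 columns now span x in [0, 8]
# and the 16 rows span y in [0, 4]; with the default origin="upper",
# row 0 is still drawn at the top
im = ax.imshow(a, cmap="hot", interpolation="nearest",
               extent=(0, 8, 0, 4),
               aspect="auto")  # let cells stretch to fill the axes instead of staying square
fig.colorbar(im, ax=ax)
fig.savefig("heatmap_extent.png")
```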

Seaborn takes care of a lot of the manual work and automatically draws a colorbar at the side of the chart.

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

uniform_data = np.random.rand(10, 12)
ax = sns.heatmap(uniform_data, linewidth=0.5)
plt.show()

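Seaborn's `heatmap` draws its cells with matplotlib's `pcolormesh`, so a roughly similar plot can be made with matplotlib alone. A minimal sketch that does not reproduce seaborn's exact styling (tick placement, default colormap); it assumes white cell borders like seaborn's default `linecolor` and flips the y-axis so row 0 is at the top:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

uniform_data = np.random.rand(10, 12)

fig, ax = plt.subplots()
# one cell per array element; edgecolors/linewidth draw the thin separating lines
mesh = ax.pcolormesh(uniform_data, edgecolors="white", linewidth=0.5)
ax.invert_yaxis()  # put row 0 at the top, as seaborn does
fig.colorbar(mesh, ax=ax)
fig.savefig("heatmap_mesh.png")
```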

You can also plot only the upper or lower triangle of a square matrix. This suits a correlation matrix, which is square and symmetric, so plotting every value would be redundant anyway.

corr = np.corrcoef(np.random.randn(10, 200))
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True
with sns.axes_style("white"):
    ax = sns.heatmap(corr, mask=mask, vmax=.3, square=True,  cmap="YlGnBu")
    plt.show()

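The same triangle masking works without seaborn by wrapping the matrix in a NumPy masked array: `imshow` leaves masked cells undrawn, so the background shows through. A minimal sketch that mirrors the seaborn example's mask (diagonal and upper triangle hidden), colormap, and `vmax`:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

corr = np.corrcoef(np.random.randn(10, 200))

# True on and above the diagonal: these cells are hidden
mask = np.triu(np.ones_like(corr, dtype=bool))
masked = np.ma.masked_array(corr, mask=mask)

fig, ax = plt.subplots()
# masked cells are not drawn, leaving only the lower triangle visible
im = ax.imshow(masked, cmap="YlGnBu", vmax=0.3, interpolation="nearest")
fig.colorbar(im, ax=ax)
fig.savefig("corr_lower.png")
```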

user
PyRsquared
  • I'm very fond of the plot type, and the half matrix is useful. Two questions: 1) in the first plot the little squares are separated by white lines; could they be joined? 2) the white line width seems to vary; is this an artefact? – P. Camilleri Apr 28 '18 at 22:23
  • You can use the ‘linewidth’ argument I used in the first plot for any other plot (in the second plot for example), to get spaced out squares. The line widths only appear to vary in the first plot due to screenshot issues; they don't actually vary, and they stay at the constant you set. – PyRsquared Apr 29 '18 at 00:44
  • While this is true, I don't think that a response using seaborn should be considered complete for a question that specifically states matplotlib. – baxx May 01 '20 at 16:24

I would use matplotlib's pcolor/pcolormesh functions, since they allow nonuniform spacing of the data.

Example taken from the matplotlib documentation:

import matplotlib.pyplot as plt
import numpy as np

# generate 2 2d grids for the x & y bounds
y, x = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100))

z = (1 - x / 2. + x ** 5 + y ** 3) * np.exp(-x ** 2 - y ** 2)
# x and y are bounds, so z should be the value *inside* those bounds.
# Therefore, remove the last value from the z array.
z = z[:-1, :-1]
z_min, z_max = -np.abs(z).max(), np.abs(z).max()

fig, ax = plt.subplots()

c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('pcolormesh')
# set the limits of the plot to the limits of the data
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

plt.show()

pcolormesh plot output
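The nonuniform spacing mentioned above can be shown directly on an n-by-n array like the one in the question. A minimal sketch, assuming hypothetical cell edges: x edges that grow geometrically and rows whose heights are 1, 2, 3, ..., n. Each edge array needs n + 1 entries because it marks cell boundaries, not cell centers:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

n = 8
data = np.random.random((n, n))  # values in [0, 1), one per cell

# n + 1 cell edges per axis (example spacings chosen for illustration)
x_edges = np.geomspace(1, 100, n + 1)                               # cells widen left to right
y_edges = np.concatenate(([0.0], np.cumsum(np.arange(1, n + 1))))  # row heights 1, 2, ..., n

fig, ax = plt.subplots()
# data[i, j] fills the cell between y_edges[i], y_edges[i + 1]
# and x_edges[j], x_edges[j + 1]
mesh = ax.pcolormesh(x_edges, y_edges, data, cmap="RdBu",
                     vmin=0, vmax=1, shading="flat")
fig.colorbar(mesh, ax=ax)
fig.savefig("pcolormesh_nonuniform.png")
```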

Erasmus Cedernaes

For a 2D NumPy array, simply using imshow() may be enough:

import matplotlib.pyplot as plt
import numpy as np


def heatmap2d(arr: np.ndarray):
    plt.imshow(arr, cmap='viridis')
    plt.colorbar()
    plt.show()


test_array = np.arange(100 * 100).reshape(100, 100)
heatmap2d(test_array)

The heatmap of the example code

This code produces a continuous heatmap.

You can choose another built-in colormap from matplotlib's colormap reference.
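To try other colormaps, the helper can take the colormap name as a parameter, and `plt.colormaps()` returns the names of all registered colormaps. A minimal sketch; the `cmap` and `path` parameters are additions for illustration, and the figure is saved rather than shown so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np


def heatmap2d(arr: np.ndarray, cmap: str = "viridis", path: str = "heatmap2d.png"):
    """Variant of the helper above with a selectable colormap, saved to a file."""
    fig, ax = plt.subplots()
    im = ax.imshow(arr, cmap=cmap)
    fig.colorbar(im, ax=ax)
    fig.savefig(path)
    plt.close(fig)  # free the figure, since it is not shown interactively
    return im


# names of all registered colormaps, e.g. 'viridis', 'magma', 'coolwarm'
available = plt.colormaps()

test_array = np.arange(100 * 100).reshape(100, 100)
im = heatmap2d(test_array, cmap="magma")
```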

huangbiubiu

Here's how to do it from a CSV-style text file of scattered (x, y, z) points:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# Load whitespace-delimited data: one "x y z" point per line
dat = np.genfromtxt('dat.xyz')

# genfromtxt already returns a NumPy array, so the columns can be used directly
X = dat[:, 0]
Y = dat[:, 1]
Z = dat[:, 2]

# create x-y points to be used in heatmap
xi = np.linspace(X.min(), X.max(), 1000)
yi = np.linspace(Y.min(), Y.max(), 1000)

# Interpolate for plotting
zi = griddata((X, Y), Z, (xi[None,:], yi[:,None]), method='cubic')

# I control the range of my colorbar by removing data 
# outside of my range of interest
zmin = 3
zmax = 12
zi[(zi < zmin) | (zi > zmax)] = np.nan

# Create the contour plot
CS = plt.contourf(xi, yi, zi, 15, cmap=plt.cm.rainbow,
                  vmax=zmax, vmin=zmin)
plt.colorbar()  
plt.show()

where dat.xyz is in the form

x1 y1 z1
x2 y2 z2
...
kilojoules
  • Just a short heads up: I had to change the method from cubic to either nearest or linear because the cubic resulted in a lot of NaNs since I'm working with rather small values between 0..1 – Maikefer Feb 09 '18 at 18:08
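Following the comment's suggestion, here is a sketch of the same pipeline with `method='linear'` and `method='nearest'`. It uses synthetic scattered points with values in [0, 1] as a stand-in for the columns of `dat.xyz` (the test function is made up for illustration). With `griddata`, grid points outside the convex hull of the input points are filled with NaN for `'linear'` and `'cubic'`, while `'nearest'` assigns a value everywhere:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(0)
# synthetic stand-in for the three columns of dat.xyz
X = rng.uniform(0, 1, 500)
Y = rng.uniform(0, 1, 500)
Z = (np.sin(3 * X) * np.cos(3 * Y) + 1) / 2  # values in [0, 1]

xi = np.linspace(X.min(), X.max(), 200)
yi = np.linspace(Y.min(), Y.max(), 200)

# 'linear' stays within the range of the input values;
# points outside the convex hull of (X, Y) become NaN
zi = griddata((X, Y), Z, (xi[None, :], yi[:, None]), method="linear")
# 'nearest' fills every grid point, including the corners
zi_nearest = griddata((X, Y), Z, (xi[None, :], yi[:, None]), method="nearest")

fig, ax = plt.subplots()
cs = ax.contourf(xi, yi, zi, 15, cmap="rainbow", vmin=0, vmax=1)  # NaNs are left blank
fig.colorbar(cs, ax=ax)
fig.savefig("contourf_linear.png")
```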

Use matshow(), a wrapper around imshow() that sets useful defaults for displaying a matrix:

import matplotlib.pyplot as plt
import numpy as np
a = np.diag(range(15))
plt.matshow(a)


https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.matshow.html

This is just a convenience function wrapping imshow to set useful defaults for displaying a matrix. In particular:

  • Set origin='upper'.
  • Set interpolation='nearest'.
  • Set aspect='equal'.
  • Ticks are placed to the left and above.
  • Ticks are formatted to show integer indices.
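Applied to an n-by-n array with values in [0, 1] like the one in the question, a sketch might look like the following. It uses the Axes method `ax.matshow` so it draws into an existing figure (`plt.matshow` opens a new one), pins the color scale with `vmin`/`vmax`, and reads two of the defaults listed above back from the returned image. The 6×6 random array is stand-in data:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

a = np.random.random((6, 6))  # n-by-n values in [0, 1), as in the question

fig, ax = plt.subplots()
im = ax.matshow(a, cmap="viridis", vmin=0, vmax=1)
fig.colorbar(im, ax=ax)

# inspect the origin and interpolation defaults that matshow applied
print(im.origin, im.get_interpolation())
fig.savefig("matshow.png")
```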
mamaj