Saturday 17 August 2024

Web Tutorial: Python Matplotlib Heatmap (Part 2/2)

Here comes the fun - or at least, the visual - part. We work on the heatMap() function.

We start off by setting the size of the plot.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))
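For reference, figsize is (width, height) in inches, not pixels; the size in pixels is those inches multiplied by the figure's dpi. A quick sketch to check this, assuming you run it as a script with matplotlib's non-interactive Agg backend (so no window is needed):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 10))  # (width, height) in inches

width_in, height_in = fig.get_size_inches()
# Pixel size = inches * dots per inch
width_px = width_in * fig.dpi
height_px = height_in * fig.dpi
print(width_in, height_in, fig.dpi, width_px, height_px)

plt.close(fig)
```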


Then we feed the values, vals, into the imshow() method. It's a two-dimensional list, remember.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)
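To see what imshow() does with a two-dimensional list, here is a small sketch with made-up numbers (the real vals come from Part 1). Each inner list becomes one row of cells, and by default (origin="upper") the first inner list is drawn at the top.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt

# Made-up data: one inner list per season (row), one value per player (column)
vals = [
    [1, 0, 4, 2],
    [3, 1, 0, 5],
    [2, 6, 1, 0],
]

plt.figure(figsize=(10, 10))
image = plt.imshow(vals)  # returns an AxesImage

# The image keeps the grid shape: 3 rows (seasons) x 4 columns (players)
rows, cols = image.get_array().shape
print(rows, cols)  # 3 4

plt.close()
```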


Finish off with a call to the show() method, which displays the heatmap.
import matplotlib.pyplot as plt

def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)

  plt.show()
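plt.show() needs an interactive backend to open a window. If you run the script somewhere without a display, one alternative (not part of the original tutorial) is plt.savefig(). This sketch, with made-up data, renders the figure to an in-memory PNG instead of showing it:

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: no display available
import matplotlib.pyplot as plt

vals = [[1, 0, 4], [3, 1, 0]]  # made-up sample data

plt.figure(figsize=(10, 10))
plt.imshow(vals)

# Instead of plt.show(), write the figure as PNG bytes into a buffer
buffer = io.BytesIO()
plt.savefig(buffer, format="png")
plt.close()

png_bytes = buffer.getvalue()
print(len(png_bytes), "bytes")
```

Passing a filename such as "heatmap.png" to plt.savefig() writes a file instead.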


This is the default view. You can't make much sense of it just yet, because the labels aren't clear.


Now, use the xticks() method. If we use numpy's arange() function and pass in 0 and the length of players, we'll get an array of values from 0 up to the length of players minus 1, one per column. Pass this into the call to xticks(). This changes little, if anything, because the ticks are still just numbers; it only ensures there is exactly one tick for every column.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)
  plt.xticks(np.arange(0, len(players)))

  plt.show()
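Here is a quick check of what np.arange(0, len(players)) produces and where xticks() then puts the ticks. The player names and values are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt
import numpy as np

players = ["Player A", "Player B", "Player C", "Player D"]  # made-up names
vals = [[1, 0, 4, 2], [3, 1, 0, 5]]                        # made-up values

plt.figure(figsize=(10, 10))
plt.imshow(vals)

# Stops before len(players), just like range(0, len(players))
positions = np.arange(0, len(players))
print(positions)  # [0 1 2 3]

plt.xticks(positions)  # one tick per column, still labelled with numbers
ticks = plt.gca().get_xticks()
print(ticks)
plt.close()
```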


But if you pass in players as an argument after that...
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)
  plt.xticks(np.arange(0, len(players)), players)

  plt.show()
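The second argument to xticks() is paired with the first by position: the first label goes on the first tick, and so on. This sketch (made-up names and values) reads the labels back to confirm the pairing; the canvas is drawn first so the labels are fully updated.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt
import numpy as np

players = ["Player A", "Player B", "Player C", "Player D"]  # made-up names
vals = [[1, 0, 4, 2], [3, 1, 0, 5]]                        # made-up values

plt.figure(figsize=(10, 10))
plt.imshow(vals)
plt.xticks(np.arange(0, len(players)), players)  # one label per tick position

plt.gcf().canvas.draw()  # make sure tick labels are up to date
labels = [label.get_text() for label in plt.gca().get_xticklabels()]
print(labels)
plt.close()
```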


...you see the x-axis is populated with values! However, the names can be a little long (looking right at you, Alex Oxlade-Chamberlain!) so let's modify the code a bit.


Add a rotation value as an argument.
plt.xticks(np.arange(0, len(players)), players, rotation=90)
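Extra keyword arguments to xticks(), such as rotation, are applied to the tick label text objects. A small sketch with made-up data that reads the rotation back from each label:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt
import numpy as np

players = ["Player A", "Player B", "Player C", "Player D"]  # made-up names
vals = [[1, 0, 4, 2], [3, 1, 0, 5]]                        # made-up values

plt.figure(figsize=(10, 10))
plt.imshow(vals)
plt.xticks(np.arange(0, len(players)), players, rotation=90)

plt.gcf().canvas.draw()
rotations = [label.get_rotation() for label in plt.gca().get_xticklabels()]
print(rotations)  # each label is turned 90 degrees
plt.close()
```

Passing rotation="vertical" gives the same 90-degree result.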


Looking good!


We do something similar for the yticks() method, passing in seasons for the labels.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.show()
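One detail worth knowing about the y-axis: imshow() flips it so that row 0 (the first season in vals) sits at the top. This sketch, with made-up seasons and values, shows the y limits running from the bottom row index down to the top one.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt
import numpy as np

seasons = ["2016/2017", "2017/2018", "2018/2019"]  # made-up seasons
players = ["Player A", "Player B"]                  # made-up names
vals = [[1, 0], [3, 1], [2, 6]]                     # one row per season

plt.figure(figsize=(10, 10))
plt.imshow(vals)
plt.xticks(np.arange(0, len(players)), players, rotation=90)
plt.yticks(np.arange(0, len(seasons)), seasons)
plt.gcf().canvas.draw()

ax = plt.gca()
ylabels = [label.get_text() for label in ax.get_yticklabels()]
bottom, top = ax.get_ylim()
print(ylabels)
print(bottom, top)  # bottom > top: the y-axis runs downwards
plt.close()
```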


And now our heatmap takes shape. We have labels; now we just need to know what those colors mean.


Let's use the colorbar() method to facilitate this...
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals)
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.colorbar()

  plt.show()
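By default, the colorbar spans the smallest to the largest value in vals, so the bottom color means "the minimum", which is 0 only when the data contains a 0. A sketch with made-up values that reads that range back with get_clim():

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt

vals = [[1, 0, 4], [3, 1, 7]]  # made-up values

plt.figure(figsize=(10, 10))
image = plt.imshow(vals)
plt.colorbar()  # attaches to the most recent image

low, high = image.get_clim()  # data range mapped onto the colormap
print(low, high)
plt.close()
```

If you want the scale pinned regardless of the data, imshow() also accepts vmin and vmax, e.g. plt.imshow(vals, vmin=0).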


Now, with the bar appearing on the right, you can see that yellow represents the higher numbers and deep blue (almost purple) the lowest, which here is 0.


Let's go for a better color palette. It's Liverpool, so pass in "Reds" as the cmap argument.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals, cmap="Reds")
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.colorbar()

  plt.show()
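A colormap is just a function from 0.0 (lowest value) to 1.0 (highest value) onto a color. This sketch samples both ends of "Reds" to show it runs from near-white to dark red:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt

reds = plt.get_cmap("Reds")

low_rgba = reds(0.0)   # color used for the smallest value
high_rgba = reds(1.0)  # color used for the largest value
print(low_rgba)
print(high_rgba)

# The low end is much brighter than the high end
print(sum(low_rgba[:3]) > sum(high_rgba[:3]))  # True
```

Appending "_r" (as in cmap="Reds_r") reverses the mapping, if you ever want dark for low values instead.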


Now you can see the color scheme has changed, along with the color bar!


Let's do something simple and add a title, using the title() method and the value of the stat parameter.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals, cmap="Reds")
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.colorbar()

  plt.title("Liverpool FC Player " + stat)
  plt.show()
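String concatenation only works if stat is a string. An f-string produces the same title and also copes with non-string values. A sketch, where "Goals" is a made-up stat name:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt

stat = "Goals"           # made-up stat name
vals = [[1, 0], [3, 1]]  # made-up values

plt.figure(figsize=(10, 10))
plt.imshow(vals, cmap="Reds")
plt.title(f"Liverpool FC Player {stat}")  # same text as "Liverpool FC Player " + stat

title = plt.gca().get_title()
print(title)  # Liverpool FC Player Goals
plt.close()
```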


There it is... a small piece of the puzzle, but it adds so much clarity.


Next, we will add labels to the colored squares. It's quite straightforward. First we need a nested For loop to traverse the two-dimensional list that is vals. For both the inner and outer For loop we need the index value as well as the item, so we use enumerate(), first on vals, then on each list within vals.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals, cmap="Reds")
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.colorbar()
  
  for colindex, lst in enumerate(vals):
    for rowindex, val in enumerate(lst):
      pass  # placeholder: the loop body comes in the next step


  plt.title("Liverpool FC Player " + stat)
  plt.show()
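Here is what the nested enumerate() loops actually visit, using a made-up 2 x 3 grid. Despite the names, colindex counts the inner lists (the rows, i.e. the y position), and rowindex counts positions within a list (the columns, i.e. the x position). That happens to be exactly the (x, y) order text() expects later.

```python
# Made-up grid: 2 seasons (rows) x 3 players (columns)
vals = [
    [1, 0, 4],
    [3, 1, 7],
]

cells = []
for colindex, lst in enumerate(vals):     # outer index: which inner list (row, y)
    for rowindex, val in enumerate(lst):  # inner index: position in that list (column, x)
        cells.append((rowindex, colindex, val))

print(cells)
# [(0, 0, 1), (1, 0, 0), (2, 0, 4), (0, 1, 3), (1, 1, 1), (2, 1, 7)]
```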


Then in the inner loop, we use the index values (rowindex and colindex) and the value itself, val, in the text() method.
for colindex, lst in enumerate(vals):
  for rowindex, val in enumerate(lst):
    plt.text(rowindex, colindex, val)
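By default text() anchors the string by its left edge and baseline, so each number starts at the cell center rather than sitting in the middle of it. An optional tweak (not in the original) is ha="center", va="center". A sketch with made-up values:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen
import matplotlib.pyplot as plt

vals = [[1, 0, 4], [3, 1, 7]]  # made-up values

plt.figure(figsize=(10, 10))
plt.imshow(vals, cmap="Reds")

for colindex, lst in enumerate(vals):
    for rowindex, val in enumerate(lst):
        # (x, y) in data coordinates = (column, row); center the label on the cell
        plt.text(rowindex, colindex, val, ha="center", va="center")

texts = plt.gca().texts  # one Text object per cell, in the order added
print(len(texts))  # 6
print(texts[0].get_text(), texts[-1].get_text())  # 1 7
plt.close()
```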



Pretty straightforward... though you may have noticed that the default text color is black, and that Mohamed Salah's goal tally for the 2017/2018 season is almost invisible here because there is so little contrast against the dark square.


Let's fix that! Just before the nested For loop, declare vals_avg and set it using numpy's nanmean() function, passing in vals as an argument. This gives us the average value of the entire dataset in the heatmap, ignoring any NaN (missing) values.
def heatMap(seasons, players, vals, stat):
  plt.figure(figsize = (10, 10))

  plt.imshow(vals, cmap="Reds")
  plt.xticks(np.arange(0, len(players)), players, rotation=90)
  plt.yticks(np.arange(0, len(seasons)), seasons)

  plt.colorbar()
  
  vals_avg = np.nanmean(vals)
  for colindex, lst in enumerate(vals):
    for rowindex, val in enumerate(lst):
      plt.text(rowindex, colindex, val)

  plt.title("Liverpool FC Player " + stat)
  plt.show()
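nanmean() differs from an ordinary mean only when the data contains NaN: a single NaN makes mean() return NaN, while nanmean() skips those cells. A sketch with a made-up grid containing one missing value:

```python
import numpy as np

# Made-up grid with one missing value
vals = [
    [1, np.nan, 4],
    [3, 1, 7],
]

print(np.mean(vals))     # nan: one NaN spoils the ordinary mean
print(np.nanmean(vals))  # 3.2: (1 + 4 + 3 + 1 + 7) / 5, NaN skipped
```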


Then in the inner For loop, declare rgb as 0. If val is greater than vals_avg, set it to 1. So now, if val is at or below average, rgb is 0; if it is above average, rgb is 1.
vals_avg = np.nanmean(vals)
for colindex, lst in enumerate(vals):
  for rowindex, val in enumerate(lst):
    rgb = 0
    if (val > vals_avg): rgb = 1

    plt.text(rowindex, colindex, val)
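The rule above can be pulled out into a tiny function to see how it behaves, including the case where val equals the average. The name text_shade is hypothetical, purely for illustration:

```python
def text_shade(val, vals_avg):
    """Hypothetical helper: 0 (black) at or below average, 1 (white) above it."""
    rgb = 0
    if val > vals_avg:
        rgb = 1
    return rgb


vals_avg = 3.2  # made-up average
print([text_shade(v, vals_avg) for v in [1, 3.2, 4, 7]])  # [0, 0, 1, 1]
```

A one-line equivalent is rgb = 1 if val > vals_avg else 0.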


And then we modify the call to the text() method to include a color argument.
vals_avg = np.nanmean(vals)
for colindex, lst in enumerate(vals):
  for rowindex, val in enumerate(lst):
    rgb = 0
    if (val > vals_avg): rgb = 1
    plt.text(rowindex, colindex, val, color=())  # placeholder: () is not a valid color yet
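The color argument accepts several notations, and an (r, g, b) tuple of values between 0 and 1 is one of them. This sketch uses matplotlib's to_rgb() to show that tuples, names and hex strings describe the same colors, and that an empty tuple is rejected, so color=() only works as a placeholder:

```python
from matplotlib.colors import to_rgb

# The same white, written three ways
white = [to_rgb((1, 1, 1)), to_rgb("white"), to_rgb("#ffffff")]
# The same black, written three ways
black = [to_rgb((0, 0, 0)), to_rgb("black"), to_rgb("#000000")]
print(white)
print(black)

# An empty tuple is not a valid color
try:
    to_rgb(())
    empty_is_valid = True
except ValueError:
    empty_is_valid = False
print(empty_is_valid)  # False
```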


And we will use the value of rgb in that argument! The color (1, 1, 1) is white and (0, 0, 0) is black, so the text is white if val is above average, and black otherwise.
vals_avg = np.nanmean(vals)
for colindex, lst in enumerate(vals):
  for rowindex, val in enumerate(lst):
    rgb = 0
    if (val > vals_avg): rgb = 1
    plt.text(rowindex, colindex, val, color=(rgb, rgb, rgb))


See what we've done? Now in cases where the square color is dark (which means the value is above average), the text is white. And in cases where the square color is white or light (which means the value is below average), the text is black!
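Putting all the steps together, here is a minimal sketch of the finished function that runs off-screen. Assumptions beyond the tutorial: the Agg backend, returning the figure instead of calling plt.show(), and the sample data, which is made up and not real statistics.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen rather than opening a window
import matplotlib.pyplot as plt
import numpy as np


def heatMap(seasons, players, vals, stat):
    plt.figure(figsize=(10, 10))

    # Draw the grid and label both axes
    plt.imshow(vals, cmap="Reds")
    plt.xticks(np.arange(0, len(players)), players, rotation=90)
    plt.yticks(np.arange(0, len(seasons)), seasons)

    plt.colorbar()

    # Average of every cell, ignoring NaN values
    vals_avg = np.nanmean(vals)
    for colindex, lst in enumerate(vals):
        for rowindex, val in enumerate(lst):
            # White text on above-average (dark) cells, black on the rest
            rgb = 1 if val > vals_avg else 0
            plt.text(rowindex, colindex, val, color=(rgb, rgb, rgb))

    plt.title("Liverpool FC Player " + stat)
    # Assumption: return the figure instead of calling plt.show()
    return plt.gcf()


# Made-up sample data, not real statistics
seasons = ["2017/2018", "2018/2019"]
players = ["Player A", "Player B", "Player C"]
vals = [[5, 2, 3], [1, 4, 0]]

fig = heatMap(seasons, players, vals, "Goals")
```

From there, fig.savefig("heatmap.png") would write the image to disk.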


Well, this was fun...

Heatmaps are a nice visualization tool, provided your data has a reasonably even spread of possible values along two dimensions.

Warm regards,
T___T
