Whether you're learning Python just for fun or have a career as a data scientist or business analyst, it's a good idea to familiarize yourself with matplotlib.
Matplotlib is a Python plotting library that lets you visualize data, which makes the data easier to understand and analyze and helps you use it to support business goals.
One of the best things about matplotlib is that it is quite beginner-friendly. However, you may have trouble understanding how it works without simple examples, especially if you're new to programming.
But you have nothing to worry about because, in this brief post, we discuss five tips for mastering matplotlib along with code samples.
#1 Changing Plot Size
It might surprise you, but how to change a plot's size is one of the most frequently asked matplotlib questions on Stack Overflow. It's important for Python programmers to know how to do this because the right plot size depends on the volume and complexity of the data being shown.
You can change a plot's size with pyplot's figure() function. It's as simple as passing the plot's width and height, in inches and in that order, as a tuple to the figsize argument.
Let's see the figure() function in action:
```python
import random

import numpy as np
import matplotlib.pyplot as plt

# Create a figure 5 inches wide and 8 inches tall: figsize=(width, height)
fig = plt.figure(figsize=(5, 8))

# Generate noisy data that increases roughly linearly
X = list(range(10))
Y = [x + (x * random.random()) for x in X]

plt.plot(X, Y)
plt.title('Line Plot')
plt.xlabel('x-axis')
plt.ylabel('y-axis')
plt.show()
```
Try changing the figsize values and rerunning the code: you will notice that the same line looks quite different in figures of different dimensions. Also note that the code uses random.random(), so the plotted values change on every run.
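The short sketch below is not from the original post; it simply confirms that figsize is ordered (width, height) and shows two related options: Figure.set_size_inches() for resizing a figure that already exists, and the figure.figsize rcParam for changing the default size of new figures. It assumes the non-interactive "Agg" backend so it can run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

# figsize is (width, height) in inches
fig = plt.figure(figsize=(5, 8), dpi=100)
width_in, height_in = fig.get_size_inches()
print(width_in, height_in)  # 5.0 8.0

# The rendered size in pixels is the size in inches multiplied by dpi
print(int(width_in * fig.dpi), int(height_in * fig.dpi))  # 500 800

# An existing figure can be resized after it is created
fig.set_size_inches(10, 4)
new_width, new_height = fig.get_size_inches()
print(new_width, new_height)  # 10.0 4.0

# Change the default size for every figure created afterwards
plt.rcParams["figure.figsize"] = (8, 4)
plt.close(fig)
```

Setting the rcParam is handy when every plot in a notebook or script should share one size; figsize on an individual figure still overrides it.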
#2 Grouping Bar Plots
Bar plots make it easy to compare data across categories, and grouped bar plots extend this to comparing several groups within each category. However, it can be difficult for programmers new to Python to create these plots correctly.
You can group two bar series by drawing them side by side: call plt.bar() once per series and shift the second series' x positions by the bar width. You must set the width and position of the bars carefully so that each pair sits together without overlapping.
Here's some sample code visualizing the temperatures of various cities in two seasons:
```python
import random

import matplotlib.pyplot as plt

# Randomly generated temperatures for five cities
temp_summer = [random.uniform(20, 40) for i in range(5)]
temp_winter = [random.uniform(0, 10) for i in range(5)]

fig = plt.figure(figsize=(10, 6))
city = ['City A', 'City B', 'City C', 'City D', 'City E']

# Bars are 0.4 wide, so shift each winter bar 0.4 to the right of its summer bar
x_pos_summer = list(range(1, 6))
x_pos_winter = [i + 0.4 for i in x_pos_summer]

graph_summer = plt.bar(x_pos_summer, temp_summer, color='tomato', label='Summer', width=0.4)
graph_winter = plt.bar(x_pos_winter, temp_winter, color='dodgerblue', label='Winter', width=0.4)

# Place each city label midway between its two bars
plt.xticks([i + 0.2 for i in x_pos_summer], city)
plt.title('City Temperature')
plt.ylabel(r'Temperature ($^\circ$C)')  # raw string keeps the LaTeX backslash intact

# Annotate each bar with its value
for summer_bar, winter_bar, ts, tw in zip(graph_summer, graph_winter, temp_summer, temp_winter):
    plt.text(summer_bar.get_x() + summer_bar.get_width() / 2.0, summer_bar.get_height(),
             r'%.2f$^\circ$C' % ts, ha='center', va='bottom')
    plt.text(winter_bar.get_x() + winter_bar.get_width() / 2.0, winter_bar.get_height(),
             r'%.2f$^\circ$C' % tw, ha='center', va='bottom')

plt.legend()
plt.show()
```
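The same offset idea generalizes to any number of groups. The sketch below is an illustration rather than the post's own code: it uses hypothetical temperature values, computes each series' offset with NumPy so the group stays centred on its tick, and labels the bars with Axes.bar_label(), which assumes matplotlib 3.4 or newer.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.pyplot as plt

cities = ["City A", "City B", "City C"]
# Hypothetical example data: temperatures per season
temps = {
    "Summer": [31.2, 27.5, 35.0],
    "Winter": [4.1, 8.3, 1.7],
    "Spring": [18.0, 16.4, 20.2],
}

n_groups = len(temps)
width = 0.8 / n_groups          # split a total group width of 0.8 evenly
x = np.arange(len(cities))      # one tick position per city

fig, ax = plt.subplots(figsize=(8, 5))
containers = []
for i, (season, values) in enumerate(temps.items()):
    # Shift each series relative to the middle one so the group is centred on x
    offset = (i - (n_groups - 1) / 2) * width
    bars = ax.bar(x + offset, values, width=width, label=season)
    ax.bar_label(bars, fmt="%.1f")  # write each bar's height above it
    containers.append(bars)

ax.set_xticks(x)
ax.set_xticklabels(cities)
ax.set_ylabel(r"Temperature ($^\circ$C)")
ax.legend()
plt.close(fig)
```

With three series, the middle one sits exactly on the tick and the others are shifted one bar width to either side.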
#3 Saving a Plot as an Image
We've discussed two attractive-looking plots so far. However, we haven't yet discussed how to preserve the plots we create.
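It's easy to save your plot as an image with the savefig() method (plt.savefig(), or fig.savefig() on a Figure object), which saves the current figure to a file. The file format is inferred from the file extension, and parameters such as dpi, bbox_inches, and transparent let you customize the output.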
You will find the details of all the available parameters of the savefig() method in the official matplotlib documentation.
Note that you should call savefig() before show(), never after it. In a script, show() displays the figure and, once you close the window, the figure is closed as well. A subsequent plt.savefig() call then operates on a new, empty figure and saves a blank image.
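Here is a minimal sketch of that order of operations. It is not from the original post: the file names and the temporary output directory are arbitrary choices, and it assumes the "Agg" backend so it runs headless. It saves the same figure as a PNG at a fixed dpi and as a PDF trimmed with bbox_inches="tight", then reads the PNG back to check its pixel size.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2, 3], [0, 1, 4, 9])
ax.set_title("Saved plot")

out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "plot.png")
pdf_path = os.path.join(out_dir, "plot.pdf")

# Save BEFORE plt.show(); the format is inferred from the extension
fig.savefig(png_path, dpi=100)              # 4 x 3 inches at 100 dpi -> 400 x 300 pixels
fig.savefig(pdf_path, bbox_inches="tight")  # vector output with surrounding whitespace trimmed

# In an interactive session, plt.show() would go here
plt.close(fig)

img = mpimg.imread(png_path)
print(img.shape)  # (300, 400, 4): height, width, RGBA channels
```

Because bbox_inches="tight" crops the output, the saved dimensions then no longer match figsize × dpi exactly, which is why the pixel check above uses the untrimmed PNG.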
#4 Annotating Plots
Annotations are comments added to specific points in a plot to help the reader understand the data. These comments are an excellent tool to mark positions in a plot.
You can annotate a plot using either the text() method or pyplot's annotate() method.
Let's first discuss the text() method. Suppose you want to write each bar's value above it in a bar plot. You can call text() once per bar, passing the x and y coordinates where the text should go, followed by the text itself; the ha and va arguments control how the text is aligned relative to that point.
Here's a code sample that you can run to see this in action:
```python
import random

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 6))
temp = [random.uniform(20, 40) for i in range(5)]
city = ['City A', 'City B', 'City C', 'City D', 'City E']
x_pos = list(range(1, 6))

graph = plt.bar(x_pos, temp, color='violet')
plt.xticks(x_pos, city)
plt.title('City Temperature')
plt.xlabel('Cities')
plt.ylabel(r'Temperature ($^\circ$C)')

# Write each value just above the horizontal centre of its bar
for bar, t in zip(graph, temp):
    plt.text(bar.get_x() + bar.get_width() / 2.0, bar.get_height(),
             r'%.2f $^\circ$C' % t, ha='center', va='bottom')
plt.show()
```
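By default, the x and y you pass to text() are data coordinates, so the label moves with the data. The sketch below, an addition using hypothetical bar values, contrasts that with transform=ax.transAxes, where (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes. This is useful for notes such as a data source that should stay in a fixed corner regardless of the data limits.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar([1, 2, 3], [24.5, 31.0, 28.2], color="violet")  # hypothetical values

# Default: (2, 31.0) is a point in data coordinates, on top of the middle bar
data_label = ax.text(2, 31.0, "peak", ha="center", va="bottom")

# Axes coordinates: (0.02, 0.95) is near the top-left corner of the axes,
# independent of the data range
corner_label = ax.text(0.02, 0.95, "Source: sample data",
                       transform=ax.transAxes, ha="left", va="top")

plt.close(fig)
```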
The annotate() method places a label at one position (xytext) and can draw an arrow from it to the point being annotated (xy), styled through the arrowprops dictionary. Let's see it in action:
```python
import numpy as np
import matplotlib.pyplot as plt

X = list(range(10))

fig = plt.figure(figsize=(8, 6))
plt.plot(X, np.exp(X))
plt.title('Annotating Exponential Plot using plt.annotate()')
plt.xlabel('x-axis')
plt.ylabel('y-axis')

# xy is the point on the curve; xytext is where the label is drawn
plt.annotate('Point 1', xy=(6, np.exp(6)), xytext=(4, 600),
             arrowprops=dict(arrowstyle='->'))
plt.annotate('Point 2', xy=(7, np.exp(7)), xytext=(4.5, 2000),
             arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=-.2'))
plt.annotate('Point 3', xy=(8, np.exp(8)), xytext=(8.5, 2200),
             arrowprops=dict(arrowstyle='-|>', connectionstyle='angle,angleA=90,angleB=0'))
plt.show()
```
Run the code above to see how the annotate() method draws arrows. This arrow support is what sets it apart from the text() method, which can only place text.
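Rather than hard-coding the annotated point, you can compute it from the data. The sketch below is an illustrative addition (the sine curve and label offsets are arbitrary choices): it finds the maximum with np.argmax() and uses textcoords="offset points" so the label sits a fixed distance in points from the arrow's target, whatever the axis scale.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.pyplot as plt

x = np.linspace(0, 2 * np.pi, 200)
y = np.sin(x)

# Locate the highest point of the curve
i_max = int(np.argmax(y))
x_max, y_max = x[i_max], y[i_max]

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, y)
ann = ax.annotate(
    f"max = {y_max:.2f}",
    xy=(x_max, y_max),            # arrow tip: the data point
    xytext=(30, -40),             # label position, offset from xy ...
    textcoords="offset points",   # ... measured in points, not data units
    arrowprops=dict(arrowstyle="->"),
)
plt.close(fig)
```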
#5 Generating Subplots
A figure can hold several plots, each of which is called a subplot. There are three common ways to create subplots: the subplot(), subplots(), and subplot2grid() methods.
The easiest way is the subplot() method. You pass three values as arguments: the number of rows, the number of columns, and the index of the plot (counting from 1, left to right and top to bottom). The method returns an axes object that you can use for plotting.
Let's see subplot() in action:
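```python
import random

import matplotlib.pyplot as plt

X = list(range(10))

fig = plt.figure(figsize=(16, 6))

# subplot(nrows, ncols, index): first cell of a 1 x 2 grid
ax1 = plt.subplot(1, 2, 1)
ax1.plot(X, [x + (x * random.random()) for x in X])
ax1.set_title('Plot 1')

# Second cell of the same grid
ax2 = plt.subplot(1, 2, 2)
ax2.plot(X, [x - (x * random.random()) for x in X])
ax2.set_title('Plot 2')
plt.show()
```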
If you run this code, you will see how smoothly subplot() places two plots side by side. However, you have to call it separately for every plot, which quickly becomes tedious and repetitive when you need many subplots.
So, in use cases with several subplots, you can use the subplots() method, which creates the whole grid in one call and returns the figure together with a NumPy array of axes objects instead of a single axes object. You pass the number of rows and columns as arguments to subplots().
```python
import random

import matplotlib.pyplot as plt

X = list(range(10))

# Create a 2 x 2 grid of subplots in a single call
fig, axes = plt.subplots(2, 2, figsize=(16, 8))
print(axes.shape)  # (2, 2): a NumPy array of axes objects

axes[0, 0].plot(X, [x - (x * random.random()) for x in X])
axes[0, 0].set_title('Plot 1')
axes[0, 1].plot(X, [x + (x * random.random()) for x in X])
axes[0, 1].set_title('Plot 2')
axes[1, 0].plot(X, [x - (x * random.random()) for x in X])
axes[1, 0].set_title('Plot 3')
axes[1, 1].plot(X, [x + (x * random.random()) for x in X])
axes[1, 1].set_title('Plot 4')
plt.show()
```
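Because subplots() returns an array, the four near-identical blocks can be collapsed into a loop. The sketch below is one way to do that, not the post's own code: it iterates over axes.flat (row by row) and passes sharex=True and sharey=True, an optional choice that keeps the axis limits identical across the grid.

```python
import random

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.pyplot as plt

X = list(range(10))

# sharex/sharey make all four plots use the same axis limits
fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True, sharey=True)

# axes.flat visits the 2 x 2 array row by row, so one loop replaces four blocks
for i, ax in enumerate(axes.flat, start=1):
    sign = 1 if i % 2 == 0 else -1  # Plots 1 and 3 subtract noise, Plots 2 and 4 add it
    ax.plot(X, [x + sign * (x * random.random()) for x in X])
    ax.set_title(f"Plot {i}")

fig.tight_layout()
plt.close(fig)
```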
After running this code, you might wonder whether the plots can have different sizes, since subplots() makes them all equal. This is where the subplot2grid() method comes in.
The method treats the figure as a grid and lets each plot occupy one or more of its cells. You pass the grid's shape (rows, columns), the location (row, column) of the plot's top-left cell, and optionally rowspan and colspan to stretch the plot across several rows or columns. The method returns an axes object of that size at that location in the grid.
Here's how you can use the method:
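```python
import random

import matplotlib.pyplot as plt

X = list(range(10))

fig = plt.figure(figsize=(8, 8))

# 4 x 4 grid. Top row: a plot three columns wide, then a single-cell plot
ax1 = plt.subplot2grid((4, 4), (0, 0), colspan=3)
ax1.plot(X, [x + (x * random.random()) for x in X])
ax1.set_title('Plot 1 : (0,0)')

ax2 = plt.subplot2grid((4, 4), (0, 3))
ax2.plot(X, [x - (x * random.random()) for x in X])
ax2.set_title('Plot 2 : (0,3)')

# Remaining three rows: a 3 x 3 plot and a 3 x 1 plot beside it
ax3 = plt.subplot2grid((4, 4), (1, 0), rowspan=3, colspan=3)
ax3.plot(X, [x - (x * random.random()) for x in X])
ax3.set_title('Plot 3 : (1,0)')

ax4 = plt.subplot2grid((4, 4), (1, 3), rowspan=3, colspan=1)
ax4.plot(X, [x + (x * random.random()) for x in X])
ax4.set_title('Plot 4 : (1,3)')

fig.tight_layout()  # prevent titles and tick labels from overlapping
plt.show()
```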
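The same layout can also be expressed with a GridSpec, where NumPy-style slices replace the rowspan and colspan arguments. This is an alternative technique rather than part of the original tip; the sketch assumes a matplotlib version with Figure.add_gridspec() (3.0 or newer) and plots placeholder data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running headless
import matplotlib.pyplot as plt

X = list(range(10))

fig = plt.figure(figsize=(8, 8))
gs = fig.add_gridspec(4, 4)

# gs[row_slice, col_slice] selects the cells a plot occupies
ax1 = fig.add_subplot(gs[0, :3])   # same as subplot2grid((4, 4), (0, 0), colspan=3)
ax2 = fig.add_subplot(gs[0, 3])    # same as subplot2grid((4, 4), (0, 3))
ax3 = fig.add_subplot(gs[1:, :3])  # same as (1, 0) with rowspan=3, colspan=3
ax4 = fig.add_subplot(gs[1:, 3])   # same as (1, 3) with rowspan=3, colspan=1

for n, ax in enumerate([ax1, ax2, ax3, ax4], start=1):
    ax.plot(X, X)  # placeholder data
    ax.set_title(f"Plot {n}")

fig.tight_layout()
plt.close(fig)
```

Slices can make larger layouts easier to read, since each plot's footprint is visible at a glance.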