Everything you need to know about Seaborn for Data Visualization
Seaborn is a powerful Python Data Visualization library that is built on top of Matplotlib. It offers a plethora of statistical graphs for almost every type of data. You might wonder why you should learn Seaborn when Matplotlib already works just fine. There are two major reasons:
- Seaborn offers many of the same graphs as Matplotlib plus a range of statistical plots, with convenient high-level options (such as hue) for customization, and graphs plotted with Seaborn look more polished and aesthetically pleasing out of the box.
- Some graphs take fairly convoluted code to plot in Matplotlib; Seaborn reduces them to a single function call, which eases the learning curve.
Note: Covering all the parameters of every plot function is beyond the scope of this article. Refer to the official documentation for the full details.
Before we dive into the various plot functions and graphs, let's make sure we understand the hierarchy. In this article, Seaborn graphs are grouped into seven categories:
- Relational Plot
- Distribution Plot
- Categorical Plot
- Regression Plot
- Matrix Plot
- Residual Plot
- Multi Plot
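Seaborn provides a figure-level function and one or more axes-level functions for most kinds of graph, and we'll cover each of them as we progress through the article. Axes-level functions draw onto a single Matplotlib Axes and return it, while figure-level functions create an entire figure (managed by a Seaborn grid object) and can split the plot into several subplots. Although both kinds of function work very similarly, Seaborn advises its users to use the figure-level functions more often.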
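The difference in return values is easy to check. This minimal sketch uses a small made-up DataFrame instead of the tips dataset so it runs offline, and Matplotlib's Agg backend so it runs without a display; the variable names are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Small made-up data frame (hypothetical values, not from any real dataset)
df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [2, 4, 5, 8]})

# Figure-level: builds its own figure and returns a FacetGrid that wraps it
grid = sns.relplot(data=df, x="x", y="y", kind="scatter")

# Axes-level: draws onto one Matplotlib Axes (here, one we create) and returns that Axes
fig, ax = plt.subplots()
returned_ax = sns.scatterplot(data=df, x="x", y="y", ax=ax)
```

Because axes-level functions accept an ax argument, they slot into any Matplotlib layout you build yourself, whereas figure-level functions own the whole figure.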
Importing libraries and loading datasets
For the scope of this article, we'll use two datasets that come bundled with the Seaborn library and one that comes with Plotly. First, let's import all the required libraries.
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
Note: Here, we’ve imported Plotly just to use a dataset that is baked into it. We won’t cover anything from Plotly in this article.
If you're wondering which datasets are available in the Seaborn library, here's a quick code snippet that fetches the names of all of them (it needs an internet connection, since Seaborn's example datasets are hosted online).
sns.get_dataset_names()
We'll use three well-known datasets: tips and iris from Seaborn, and gapminder from Plotly.
tips = sns.load_dataset('tips')
iris = sns.load_dataset('iris')
gap = px.data.gapminder()
Let's quickly look at the head of each dataset to get a rough idea of its features, which will help us decide which graphs to plot for univariate analysis, bivariate analysis, and so on.
Tips dataset: It appears to come from a restaurant. It records the total bill amount, the tip amount, the customer's sex, whether the customer is a smoker, the day and time of the visit, and the size of the party.
Iris dataset: The first four features (columns 0 to 3, i.e. iris.iloc[:, 0:4]) describe the sepal and petal lengths and widths, and the last feature, species, records the species of each flower.
Gapminder dataset: The first three features (country, continent, and year) are self-explanatory. The remaining features denote life expectancy, population, GDP per capita, the ISO alpha-3 country code, and the ISO numeric country code, respectively (ISO = International Organization for Standardization).
Now that we’re done with all the prerequisites, let’s dive straight into plotting graphs with Seaborn. Shall we?
Relational Plot
Relational Plots are mostly plotted for bivariate analysis. By plotting these graphs, we try to understand the statistical relation between two or more features, i.e. how two or more features in our datasets are related.
Two plots fall under the category of Relational Plot.
- Scatter Plot
- Line Plot
The figure-level function for relational plots is sns.relplot(), and the respective axes-level functions for scatter plots and line plots are sns.scatterplot() and sns.lineplot().
Scatter Plot
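Although scatter plots are generally used for bivariate analysis, they can also help with multivariate analysis, for example by mapping extra variables to the color (hue), size, or style of the points; 3D scatter plots can also come in handy. Here, we'll use the tips dataset to explore scatter plots further.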
#Figure-level function
sns.relplot(data=tips,x='total_bill',y='tip',kind='scatter')
plt.title('Total Bill vs Tip');
#Axes-level function
g = sns.scatterplot(data=tips,x='total_bill',y='tip')
g.set_title('Total Bill vs Tip');
Note: By default, a graph plotted with a figure-level function is square (each subplot has height=5 and aspect=1), while an axes-level function draws on Matplotlib's current figure, whose default size of 6.4 × 4.8 inches makes the graph rectangular.
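Both shapes can be changed. In this sketch (hypothetical data, Agg backend), the figure-level function takes height and aspect, while the axes-level function draws on a Matplotlib figure whose size we choose with figsize.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [2, 4, 5, 8]})  # hypothetical data

# Figure-level: each subplot is `height` inches tall and `height * aspect` inches wide
grid = sns.relplot(data=df, x="x", y="y", height=4, aspect=1.5)
figure_level_size = grid.figure.get_size_inches()

# Axes-level: the size comes from the Matplotlib figure we draw on
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(data=df, x="x", y="y", ax=ax)
axes_level_size = fig.get_size_inches()
```

With a single subplot and no legend, the figure-level figure here should come out 6 × 4 inches.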
Line Plot
A line plot connects the data points in order of the x variable. When several observations share the same x value, sns.lineplot() aggregates them (by default, into their mean with a 95% confidence interval band). Line plots are particularly useful for visualizing how one variable changes in response to changes in another, such as a quantity over time.
#Filter the gapminder dataset for only India as country
india = gap[(gap['country'] == 'India')]
#Figure-level function
sns.relplot(data=india, y='lifeExp',x='year',kind='line')
plt.title('YoY lifeExp of India');
#Axes-level function
g = sns.lineplot(data=india, y='lifeExp',x='year')
g.set_title('YoY lifeExp of India');
Distribution Plot
As the name suggests, distribution plots are primarily used to understand the distribution of a single feature in univariate analysis, though some of them, such as 2D histograms, also serve bivariate analysis. With some variations and techniques, distribution plots can be used for categorical data as well.
There are three types of distribution plots:
- Histogram
- KDE Plot
- Rug Plot
The figure-level function for distribution plots is sns.displot(), and the respective axes-level functions for histograms, KDE plots, and rug plots are sns.histplot(), sns.kdeplot(), and sns.rugplot().
Histogram
The histogram is the most commonly used distribution plot, mostly for univariate analysis: it splits the range of a feature into bins and shows how many observations fall into each. Later in this section, we'll plot a KDE plot together with a histogram, and a KDE plot together with a rug plot.
#Figure-level function
sns.displot(data=tips,x='tip',kind='hist');
#Axes-level function
sns.histplot(data=tips,x='tip',hue='sex');
Note: Here, I've added a parameter called hue and set it to the sex column to compare the distribution of the tip column across customers' sexes. The hue parameter is available in both the figure-level and axes-level functions.
KDE Plot
KDE stands for Kernel Density Estimation. A KDE plot estimates the probability density function of a continuous variable. It provides a smoothed representation of the data's distribution, which is especially useful for continuous data that discrete histogram bins might not represent well.
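To make the idea concrete, here is a minimal NumPy-only sketch of a Gaussian KDE: every observation contributes a small bell curve, and the curves are averaged. It assumes a Gaussian kernel and Scott's rule for the bandwidth; as far as I know, this matches what sns.kdeplot() does by default (via SciPy, scaled by its bw_adjust parameter), but treat it as an illustration rather than Seaborn's implementation. The function name gaussian_kde_1d and the sample values are hypothetical.

```python
import numpy as np

def gaussian_kde_1d(data, grid, bandwidth):
    """Evaluate a Gaussian kernel density estimate of `data` at each point of `grid`."""
    data = np.asarray(data, dtype=float)
    grid = np.asarray(grid, dtype=float)
    # Standardized distance between every grid point (rows) and every observation (columns)
    u = (grid[:, None] - data[None, :]) / bandwidth
    # Standard normal kernel evaluated at each distance
    kernel = np.exp(-0.5 * u**2) / np.sqrt(2 * np.pi)
    # Average the kernels and divide by the bandwidth so the estimate integrates to 1
    return kernel.sum(axis=1) / (data.size * bandwidth)

# Hypothetical sample standing in for the 'tip' column
sample = np.array([1.0, 1.5, 2.0, 2.0, 3.0, 3.5, 5.0])

# Scott's rule of thumb for a 1-D sample: sample standard deviation * n ** (-1/5)
bandwidth = sample.std(ddof=1) * sample.size ** (-1 / 5)

grid = np.linspace(-5, 12, 2001)
density = gaussian_kde_1d(sample, grid, bandwidth)

# Riemann-sum check that the estimated density integrates to (almost exactly) 1
area = density.sum() * (grid[1] - grid[0])
```

A larger bandwidth gives a smoother but blurrier curve, and a smaller one a spikier curve; that trade-off is what bw_adjust exposes in Seaborn.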
#Figure-level function
sns.displot(data=tips,x='tip',kind='kde',hue='sex');
#Axes-level function
sns.kdeplot(data=tips,x='tip',hue='sex');
Note: KDE Plots are sometimes plotted with a histogram to take advantage of both graphs within a single figure.
sns.displot(data=tips,x='tip',kind='hist',kde=True);
Rug Plot
A rug plot is a distribution plot often used in conjunction with other plots, such as histograms or KDE plots, to display individual data points (small tick marks, resembling a “rug”) along a single axis. A rug plot is primarily used to give a visual representation of where individual data points lie on the axis, providing insight into the density and distribution of the data.
Note: Although you can draw a rug plot with the sns.rugplot() function, a stand-alone rug plot is rarely informative. Hence, we'll plot it on top of a KDE plot.
sns.displot(data=tips,x='tip',kind='kde',rug=True);
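Note: The rug=True parameter belongs to the figure-level function sns.displot(). With axes-level functions, you can get the same result by calling sns.kdeplot() and then sns.rugplot() on the same axes. A stand-alone rug plot can't be drawn with sns.displot() at all, since it has no kind='rug'; for that, you'll have to use sns.rugplot().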
Categorical Plot
Categorical plots in Seaborn let us visualize the distribution of data across the levels of a categorical variable. They are particularly useful for exploring the relationship between one categorical variable and one or more numerical variables, which is a form of bivariate (or multivariate) analysis.
In this article, we'll cover seven types of categorical plots, all of which can be drawn with the figure-level function sns.catplot() by setting its kind parameter:
- Bar Plot
- Count Plot
- Point Plot
- Box Plot
- Violin Plot
- Swarm Plot
- Strip Plot
Bar Plot
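A bar plot presents data as rectangular bars whose length or height corresponds to the value of the category it represents. In Seaborn's sns.barplot(), that value is an estimate (the mean, by default) of a numeric variable computed for each category, and an error bar shows the uncertainty around the estimate.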
#Axes-level Function
g = sns.barplot(data=tips,x='day',y='total_bill',estimator=np.median)
g.set_title('Daywise Total Bill');
#Figure-level
sns.catplot(kind = 'bar',data=tips,x='day',y='total_bill',estimator=np.median,errorbar=None)
plt.title('Daywise Total bill');
Note: By default, the estimator parameter is set to the mean. To hide the whisker (the error bar), we set errorbar to None. The errorbar parameter was introduced in Seaborn 0.12; older versions use ci=None instead.
Count Plot
A count plot displays the count or frequency of occurrences of categorical data. It is particularly useful for understanding the distribution of categorical variables and comparing their frequencies.
#Axes-level Function
g = sns.countplot(data=tips,x='day')
g.set_title('Daywise Number of Customers');
#Figure-level Function
sns.catplot(kind='count',data=tips,x='day')
plt.title('Daywise Number of Customers');
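A count plot is essentially a bar plot of category frequencies, so its bar heights should match a pandas group count. This sketch checks that on a tiny hypothetical DataFrame (not the tips data), using the Agg backend so it runs headless and reading the bars from ax.patches.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({"day": ["Thur", "Fri", "Fri", "Sat", "Sat", "Sat"]})  # hypothetical visits

fig, ax = plt.subplots()
sns.countplot(data=df, x="day", ax=ax)  # one bar per category, height = number of rows
bar_heights = [bar.get_height() for bar in ax.patches]

# pandas equivalent; sort=False keeps first-appearance order, matching the bar order
counts = df.groupby("day", sort=False).size().tolist()
```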
Point Plot
A point plot displays the central tendency (typically the mean) and variability of a continuous variable across different categories or groups. It is similar to a bar plot, but instead of drawing bars it marks each estimate with a point and connects the points with lines, and it uses the same error-bar whiskers we saw earlier in the bar plot.
#Axes-level Function
g = sns.pointplot(data=tips,x='day',y='tip')
g.set_title('Point Plot on Daywise Tips');
#Figure-level Function
sns.catplot(data=tips,x='day',y='tip',hue='sex',kind='point')
plt.title('Point Plot on Daywise Tips');
Box Plot
A box plot, also known as a box-and-whisker plot, is a data visualization technique that displays the distribution and summary statistics of a dataset. It provides a visual summary of the central tendency, spread, and potential outliers within the data. Box plots are particularly useful for comparing the distributions of several groups side by side.
A box plot is built on the five-number summary of a dataset: minimum, first quartile (Q1), median (second quartile, Q2), third quartile (Q3), and maximum. In Seaborn, as in Matplotlib, the whiskers extend only to the most extreme data points within 1.5 × IQR (the interquartile range, Q3 - Q1) of the box, and any point beyond that is drawn individually, which makes box plots especially helpful for detecting outliers.
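The five-number summary and the 1.5 × IQR rule are easy to compute by hand. The helper below, box_plot_stats (a hypothetical name), is a NumPy sketch of those statistics, assuming NumPy's default linear interpolation for the quartiles; treat it as an illustration of the rule rather than Seaborn's own code.

```python
import numpy as np

def box_plot_stats(values, whis=1.5):
    """Five-number summary plus whisker ends and outliers under the whis * IQR rule."""
    values = np.sort(np.asarray(values, dtype=float))
    q1, median, q3 = np.percentile(values, [25, 50, 75])  # linear interpolation by default
    iqr = q3 - q1
    low_fence, high_fence = q1 - whis * iqr, q3 + whis * iqr
    inside = values[(values >= low_fence) & (values <= high_fence)]
    return {
        "min": values.min(), "q1": q1, "median": median, "q3": q3, "max": values.max(),
        # Whiskers stop at the most extreme points that are still inside the fences
        "whisker_low": inside.min(), "whisker_high": inside.max(),
        # Anything beyond the fences would be drawn as an individual outlier point
        "outliers": values[(values < low_fence) | (values > high_fence)].tolist(),
    }

stats = box_plot_stats([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
```

Here Q1 = 3.25 and Q3 = 7.75, so the upper fence is 14.5: the top whisker stops at 9 and the value 100 is flagged as an outlier, even though it is the maximum.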
#Figure-level Function
sns.catplot(kind='box',data=tips,y='tip')
plt.title('Distribution of Tips');
#Axes-level Function
g = sns.boxplot(data=tips,y='tip')
g.set_title('Distribution of Tips');
Violin Plot
A violin plot combines aspects of a box plot and a KDE plot. It shows the distribution of a continuous variable across different categories or groups, revealing its central tendency, spread, and whether it has multiple peaks, while also showing the summary statistics found in box plots.
#Axes-level Function
g = sns.violinplot(data=tips,x='day',y='tip')
g.set_title('Daywise Spread of Tip');
#Violin Plot using Figure level function
sns.catplot(data=tips,y='tip',x='day',kind='violin');
plt.title('Spread of Tip Column');
Swarm Plot
A swarm plot displays every observation of a continuous variable across categories, shifting points along the categorical axis so that none of them overlap, unlike in a traditional scatter plot. It works best with a relatively small number of data points, since large datasets make the swarms too wide to fit.
#Axes-level Function
g = sns.swarmplot(data=tips,x='day',y='tip',hue='sex')
g.set_title('Daywise Tip');
#Figure-level Function
sns.catplot(kind='swarm',data=tips,x='day',y='tip',hue='sex')
plt.title('Daywise Tip based on sex');
Strip Plot
A strip plot displays individual data points along a single axis. It is particularly useful for visualizing the distribution of a continuous variable within different categories or groups. Strip plots and swarm plots are similar in nature and purpose, but they arrange the points differently: a strip plot adds a small random offset (jitter) along the categorical axis, so some points may still overlap, while a swarm plot positions each point so that none overlap.
#Figure-level Function
sns.catplot(kind='strip',data=tips,x='day',y='tip',hue='sex')
plt.title('Strip Plot on Daywise Tips Count');
#Axes-level Function
g = sns.stripplot(data=tips,x='day',y='tip',hue='sex')
g.set_title('Strip Plot on Daywise Tips Count');
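The jitter is what separates a strip plot's layout from a swarm plot's. This sketch, on hypothetical single-day data, draws the same points with jitter=False and jitter=True and reads back the horizontal positions Matplotlib stored; inspecting ax.collections this way relies on Matplotlib internals rather than a Seaborn API, so treat it as illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical tips for a single day, including repeated values
df = pd.DataFrame({"day": ["Sat"] * 6, "tip": [2.0, 2.0, 2.0, 3.0, 3.0, 4.5]})

fig, (ax_plain, ax_jitter) = plt.subplots(ncols=2)
# jitter=False: every point sits exactly on the category's position, so repeats overlap
sns.stripplot(data=df, x="day", y="tip", jitter=False, ax=ax_plain)
# jitter=True (the default): small random offsets along the category axis spread points out
sns.stripplot(data=df, x="day", y="tip", jitter=True, ax=ax_jitter)

plain_x = ax_plain.collections[0].get_offsets()[:, 0]
jitter_x = ax_jitter.collections[0].get_offsets()[:, 0]
```

Because the jitter is random, repeated values can still land close together, which is exactly the overlap a swarm plot avoids.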
Regression Plot
A regression plot, also known as a scatter plot with a fitted regression line, shows the relationship between two continuous variables. It is mainly used to visualize and understand the linear relationship between an independent and a dependent variable; by default, Seaborn also shades a confidence interval around the fitted line.
#Axes-level Function
g = sns.regplot(data=tips,x='total_bill',y='tip')
g.set_title('Regression Plot plotted b/w Total Bill vs Tips');
#Figure-level Function
sns.lmplot(data=tips,x='total_bill',y='tip')
plt.title('Regression Plot plotted b/w Total Bill vs Tips');
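By default, the fitted line in sns.regplot() is an ordinary least squares fit. This sketch on hypothetical noisy linear data (not the tips dataset) compares the line Seaborn draws with np.polyfit; ci=None skips the confidence band, and reading the line back from ax.lines is an inspection trick, not a Seaborn API.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Hypothetical data scattered around y = 0.15 * x + 1
rng = np.random.default_rng(0)
x = np.linspace(5, 50, 40)
y = 0.15 * x + 1 + rng.normal(scale=0.5, size=x.size)

fig, ax = plt.subplots()
sns.regplot(x=x, y=y, ci=None, ax=ax)  # scatter plus fitted line, no confidence band
line_x, line_y = ax.lines[0].get_xdata(), ax.lines[0].get_ydata()

# Ordinary least squares fit of the same data with NumPy
slope, intercept = np.polyfit(x, y, deg=1)
```

Leaving ci at its default adds the shaded band, which Seaborn estimates by bootstrapping and which therefore takes a little longer on large datasets.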
Matrix Plot
A matrix plot represents the relationships between multiple variables in the form of a grid of colored cells. It is primarily used for displaying patterns and trends within a dataset, especially when dealing with a large amount of data or when we’re trying to identify correlations between variables.
There are two types of matrix plots:
- Heatmap
- Clustermap
Heatmap
A heatmap is used to display a matrix-like representation of data, where values in a rectangular grid are represented using colors. Heatmaps are particularly effective for representing the relationships between two categorical variables or for showing patterns and trends in large datasets.
#Heat map: life expectancy per country (rows) and year (columns)
temp = gap.pivot_table(index='country',columns='year',values='lifeExp')
plt.figure(figsize=(15,20)) #heatmap is an axes-level function, so set the figure size with Matplotlib first
sns.heatmap(temp);
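Since matrix plots are often used to spot correlations, a common pattern is to pass a correlation matrix to sns.heatmap(). This is a minimal sketch on a hypothetical three-column DataFrame; annot, fmt, vmin, vmax, and cmap are standard sns.heatmap() parameters, while the column names and the coolwarm palette are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical numeric data; 'b' is built to correlate strongly with 'a'
rng = np.random.default_rng(42)
a = rng.normal(size=100)
df = pd.DataFrame({"a": a, "b": 2 * a + rng.normal(scale=0.1, size=100), "c": rng.normal(size=100)})

corr = df.corr()  # 3 x 3 Pearson correlation matrix
fig, ax = plt.subplots()
# annot writes each value in its cell; vmin/vmax pin the color scale to the full [-1, 1] range
sns.heatmap(corr, annot=True, fmt=".2f", vmin=-1, vmax=1, cmap="coolwarm", ax=ax)
n_annotations = len(ax.texts)
```

Pinning vmin and vmax keeps the colors comparable across heatmaps; otherwise the scale stretches to whatever range the current matrix happens to have.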
Cluster Map
A clustermap is a specific kind of heatmap that combines a heatmap with hierarchical clustering (it requires SciPy to be installed). It is a powerful tool for exploring patterns and relationships in complex datasets by organizing both rows and columns based on their similarities. Clustermaps are particularly useful when we're trying to reveal hidden structures or groupings within our data.
#Cluster map
sns.clustermap(iris.iloc[:,[0,1,2,3]]);
Residual Plot
A residual plot is used to assess the goodness of fit and assumptions of a regression model. It displays the differences (residuals) between the observed values and the predicted values from the regression model. Residual plots are particularly valuable for evaluating whether the model’s assumptions are met and for identifying any patterns or outliers in the residuals, which can provide insights into the model’s accuracy and potential improvements.
#Axes level Function
sns.residplot(data=tips,x='tip',y='total_bill');
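Under the hood, a residual is just the observed value minus the value predicted by the fitted line. This sketch, on hypothetical data, computes the residuals of a straight-line fit with np.polyfit and compares them with the points sns.residplot() draws (read back from the Axes, which relies on Matplotlib internals rather than a Seaborn API).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Hypothetical data with a roughly linear trend
rng = np.random.default_rng(1)
x = np.linspace(1, 10, 30)
y = 3.0 * x + 2.0 + rng.normal(scale=1.0, size=x.size)

# Residual = observed y minus the y predicted by a straight-line fit
slope, intercept = np.polyfit(x, y, deg=1)
residuals = y - (slope * x + intercept)

fig, ax = plt.subplots()
sns.residplot(x=x, y=y, ax=ax)  # scatter of residuals against x, plus a dotted line at 0
plotted_residuals = ax.collections[0].get_offsets()[:, 1]
```

For a good linear fit, the points should scatter randomly around the zero line; a curve or funnel shape hints that a straight line is the wrong model.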
Multi Plot
In Seaborn, a multi-plot is a collection of multiple individual plots that are plotted together within a single figure. Multi-plots are used for presenting different aspects of the data and for comparing multiple visualizations side by side. Seaborn has various functions and tools to create multi-plots, allowing us to arrange multiple plots within a single figure.
There are two types of multi-plot:
- Joint Plot
- Pair Plot
Joint Plot
A joint plot combines two different types of plots to display the relationship between two continuous variables: a bivariate plot of their joint distribution in the center, and the univariate distribution of each variable in the margins.
### Jointplot
sns.jointplot(data=tips,x='tip',y='total_bill',kind='scatter',hue='sex');
Note: JointGrid also draws joint plots; it is the class that sns.jointplot() is built on, and it offers a plethora of customizations that sns.jointplot() doesn't.
### JointGrid
g = sns.JointGrid(data=tips,x='tip',y='total_bill',hue='sex')
g.plot(sns.scatterplot,sns.kdeplot); #scatter plot in the joint (center) axes, KDE plots in the marginal axes
Pair Plot
A pair plot displays pairwise relationships and distributions between multiple variables in a dataset. It creates a grid with a scatter plot for every pair of variables and univariate distributions (histograms, or KDE plots when hue is set) on the diagonal, allowing us to quickly visualize the interactions between pairs of variables. Pair plots are primarily used for exploring potential patterns and correlations within multivariate datasets.
### Pairplot
sns.pairplot(iris,hue='species');
Note: PairGrid also draws pair plots; it is the class that sns.pairplot() is built on, and it offers a plethora of customizations that sns.pairplot() doesn't.
### PairGrid
g = sns.PairGrid(data=iris,hue='species')
g.map_diag(sns.histplot)
g.map_upper(sns.scatterplot)
g.map_lower(sns.lineplot);
Facet Grid
FacetGrid is a class in Seaborn that lets us create a grid of subplots, each displaying a different subset of our data based on one or more categorical variables. It's particularly useful for visualizing relationships and distributions within different categories or groups of our data.
### Facet Grid
g = sns.FacetGrid(tips, row='sex',col='time')
g.map_dataframe(sns.scatterplot, x='tip', y='total_bill'); #each subplot receives only its own subset of rows
Note: We can create a facet grid (using the row and col parameters) with almost every Seaborn plot through its respective figure-level function. We can also use the col_wrap parameter to limit the number of columns in each row.
sns.relplot(data=tips,x='total_bill',y='tip',hue='sex',col='day',col_wrap=2);
That's all for this article. I hope I've managed to cover most of the core concepts in Seaborn. The idea behind writing this article was to make it comprehensive yet very beginner friendly, and I tried to keep it that way throughout. I hope you find it helpful as you learn data visualization in Python. If you think I've missed something, please feel free to mention it in the comments.
Feel free to connect with me on LinkedIn.