Streamlit入門の8回目です。matplotlibはグラフを描画するライブラリです。いろいろなグラフを描画することができます。さまざまな設定項目があり、かなりの自由度でカスタマイズが可能で、出版レベルのグラフを描画することができます。ここではmatplotlibやそれを使ったライブラリにおいて、グラフを描画するための基本的な内容について説明します。
この記事はFuture Coders独自教材からの抜粋です。
Streamlit入門 – 8)Matplotlibの基礎 目次
いろいろなAPI
matplotliubはグラフを描画するためのライブラリです。
Matplotlib — Visualization with Python
グラフを描画するときに、直接matplotlibの命令を呼び出して描画することもできますが、Pandas(DataFrame)やSearborn、あるいはそれらの組み合わせで描画することも可能です。
以下はmatplotlibを使って折れ線グラフを描画する例です。
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
plt.plot(x, y)
plt.show()
以下はpandasを使って折れ線グラフを描画する例です。データフレームを作り、そのplotメソッドを呼び出します。
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame({"x":[1, 2, 3, 4, 5], "y":[5, 4, 3, 8, 7]})
df.plot()
plt.show()
以下はseaborn(とpandas)を使って折れ線グラフを描画する例です。データフレームを作り、seabornの関数lineplotに渡します。
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
df = pd.DataFrame({"x":[1, 2, 3, 4, 5], "y":[5, 4, 3, 8, 7]})
sns.lineplot(df)
plt.show()
結果は以下の通りです。
matplotlibで描画した場合、引数xとyで表される座標を順番に線で結んでいます((1,5)→(2,4)→(3,3)…)。一方、pandasやseabornはリストのインデックスがX軸となり、リストの値がY軸の値として描画されていることが分かります。このようにライブラリによって、引数の解釈や描画に対するアプローチに違いがある場合もありますが、よく似たグラフが描画されていることが分かります。
FigureとAxesSubplot
matplotlibは「1つの図に1つ以上のグラフを描画できるよう」設計されています。以下左図は1つの図に1つのグラフを、右図は1つの図に2つのグラフを描画する様子を示しています。
ここで図全体とグラフは以下のようなオブジェクトとして実装されています。
- Figure = 図全体
- AxesSubplot = 個々のグラフ
このように図とグラフで役割が分担されていることを意識すると良いでしょう。図全体のタイトルをつける場合にはFigureオブジェクトのメソッド、個々のグラフにタイトルをつける場合にはAxesSubplotオブジェクトのメソッドを使用するという使い方になります。
以下は、図全体のタイトルをFigureオブジェクトのsuptitleメソッドで、左図のタイトルをaxes[0]オブジェクト(AxesSubplot型)のset_titleメソッドで、右図のタイトルをaxes[1]オブジェクト(AxesSubplot型)のset_titleで設定している様子を表しています。
グラフの目盛りを間隔を変えたり、単位を変えたりするにはAxesSubplotオブジェクトのメソッドを使うことになります。
FigureとAxesSubplotの取得方法
それでは、このFigureやAxesSubplotをどのように取得するか見てゆきましょう。
matplotlibの関数で指定する場合
- plt.subplot() = 1つのFigureで1つのグラフを描画する場合
- plt.subplots(n行, m列) = 1つのFigureをn行xm列の区域に分割する場合
subplotとsubplots, 間違えやすいので注意してください。
subplotを使った場合は戻り値としてAxesSubplotオブジェクトが得られます。
ax = plt.subplot() # axはAxesSubplotオブジェクト
このオブジェクトのfigureプロパティからFigureオブジェクトへの参照が取得できます。
一方、subplotsを使った場合、戻り値はFigureとAxesSubplotsの2つとなります。
fig, ax = plt.subplots(2, 3) # figはFigureオブジェクト, axはAxesSubplotオブジェクト
ここで、戻り値のaxには注意が必要です。
- 1行1列の時は、AxesSubplotオブジェクトが戻る
- 1行n列、もしくはm行1列のときは、AxesSubplotオブジェクトのリストが戻る
- n行m列のときは、AxesSubplotオブジェクトのリストのリストが戻る
ax軸の複数形がaxesです。なので複数の図を扱う場合には変数名をaxesとすることが多いようです。このaxesから個々要素を取得する場合 axes[行][列] と記述することもできますが、より直感的な記述方法 axes[行, 列] という書き方もサポートされています。
実際にsubplotとsubplotsの戻り値を確認してみましょう。
import matplotlib.pyplot as plt
ax = plt.subplot()
print(f"ax={ax} fig={ax.figure}")
fig, ax = plt.subplots(1,1)
print(f"1. ax={ax}") # axはAxesSubplotオブジェクト
fig, axes = plt.subplots(1,3) # 3つの要素のリスト
print(f"2. len(axes)={len(axes)} axes[0]={axes[0]}")
fig, axes = plt.subplots(3,1) # 3つの要素のリスト
print(f"3. len(axes)={len(axes)} axes[0]={axes[0]}")
fig, axes = plt.subplots(4,3) # 4行3列のリストのリスト
print(f"4. len(axes)={len(axes)} axes[0,0]={axes[0,0]}")
以下は上記プログラムの出力です。
ax=AxesSubplot(0.125,0.11;0.775x0.77) fig=Figure(640x480)
1. ax=AxesSubplot(0.125,0.11;0.775x0.77)
2. len(axes)=3 axes[0]=AxesSubplot(0.125,0.11;0.227941x0.77)
3. len(axes)=3 axes[0]=AxesSubplot(0.125,0.653529;0.775x0.226471)
4. len(axes)=4 axes[0,0]=AxesSubplot(0.125,0.712609;0.227941x0.167391)
DataFrameを使って描画する場合
DataFrameのメソッドを使ってグラフを描画した場合、その戻り値はAxesSubplotとなります。そのオブジェクトのfigureプロパティからFigureオブジェクトへの参照を取得することができます。plt.subplot() と同じような使い方です。
ax = df.plot() # axはAxesSubplotオブジェクト
fig = ax.figure # figはFigureオブジェクト
データとグラフの関係
どのAPIを使用するにしても(matplotlib, pandas, seaborn…)、描画するグラフの種類によって必要な情報が変わって来ます。逆に言うと、データの種類によって適したグラフが変わってきます。
例えば、以下のようなイメージです。
- 折れ線グラフ = データの値がどのように変化しているか線で描画します。そのため、数値のリストが必要になります。
例)顧客数の変化 = [19, 23, 25, 23, 31, …] - ヒストグラム = どの値が何回出現するかを度数で描画するため、数値のリストが必要になります。
例)身長の分布 = [187, 268, 155, 175, 156, 178, 176, …] - 散布図 = パラメタ間の相関を描画するための図なので、X軸を何にするか、Y軸を何にするか、を指定する必要があります。
- X軸:気温 = [25℃, 30℃, 22℃, , .]
- Y軸:売上 = [15万円, 25万円, 12万円, .]
「なんらかのデータが与えられたときに、データを表現するにはどんなグラフが適しているのか、そのためにはどのような引数で指定する必要があるのか」このようなことを意識しておくと良いでしょう。
DataFrameオブジェクトによるグラフ描画
データフレームオブジェクトのメソッドを使うとmatplotlibの詳細を意識せずに簡単にグラフを描画することができます。ただし、細かい設定が必要な場合はmatplotlibを使うか、seabornを使うと良いでしょう。まずは、DataFrameオブジェクトを使ってグラフを描画してみましょう。
折れ線グラフ
数値のリストが与えられたときに、順番にその値をプロットして、その値を線でつなぎます。よって、数値のリストが1つあれば1本の線を描画することができます。以下はstreamlitを使わずに単に線を描画するサンプルです。
import pandas as pd
import matplotlib.pyplot as plt
vals = [x*x for x in range(-50, 60, 10)]
df = pd.DataFrame({"v":vals})
df.plot()
plt.show()
valsが数値のリストです。リスト内包表記で作成していますが、実際には以下のような値となります。
[2500, 1600, 900, 400, 100, 0, 100, 400, 900, 1600, 2500]
X軸が要素のインデックス(0から始まる順番)、Y軸がリストの数値です。この値をプロットして線で結んでいます。
これをstreamlit対応にするには以下のように修正します。
import streamlit as st
import pandas as pd
vals = [x*x for x in range(-50, 60, 10)]
df = pd.DataFrame({"v":vals})
ax = df.plot()
st.pyplot(ax.figure)
データフレームのplotメソッドは戻り値としてAxesSubplot型のオブジェクトを返します。これは各グラフを表すオブジェクトです。そのfigureプロパティから、図全体のFigureオブジェクトを取得できます。それをst.pyplotに渡すことでStreamlitでグラフを描画することができます。
st.pyplot(df.plot().figure)
と記述しても同じ結果となります。
棒グラフ
棒グラフは数値に応じた長さの箱を描画するグラフです。
import streamlit as st
import pandas as pd
x = [100, 50, 60, 70, 80]
y = [80, 60, 90, 100, 80]
df = pd.DataFrame({"x": x, "y":y})
col1, col2 = st.columns(2)
with col1:
st.pyplot(df.plot.bar().figure)
with col2:
st.pyplot(df.plot.bar(stacked=True).figure)
st.columns(2)で2列のレイアウトとし、左に通常の棒グラフ、右に積み上げ棒グラフを描画しています。stacked=Trueを引数で指定すると積み上げグラフになります。
barの代わりにbarhとすると横方向の棒グラフになります。
ヒストグラム
ヒストグラムは度数分布といって、どの範囲にデータが何個含まれるかを棒グラフ上に描画するグラフです。
import streamlit as st
import pandas as pd
from random import gauss
col1, col2 = st.columns(2)
with col1:
mu = st.slider("平均", min_value=0, max_value=100, value=50)
with col2:
sigma = st.slider("標準偏差", min_value=1, max_value=30, value=10)
score = [gauss(mu, sigma) for _ in range(1000)]
df = pd.DataFrame({"score": score})
ax = df.plot.hist(bins=30)
ax.set_xlim(0, 200)
st.pyplot(ax.figure)
gauss関数は平均muと標準偏差sigmaを使って乱数を生成します。平均50、標準偏差10とすると、いわゆる偏差値の分布となります。これらの数値はst.sliderから取得しています。
今回はgauss関数を使って、1000個のランダムな数を作成しています。ヒストグラムはdf.plot.hist関数で描画します。引数のbinsはいくつの範囲に分けるかを指定します。
今回平均が移動する様子をみたかったので、x軸の範囲を0~200に固定しました。データフレームでグラフを描画したときの戻り値はAxeSubplotオブジェクトです。これは各グラフを表します。そのset_xlimメソッドで範囲を明示的に設定し、図全体をax.figureで取得して、st.pyplot関数に渡すことでグラフを描画しています。
散布図
散布図は2つの変数の相関具合を把握するためのグラフです。よって、X軸とY軸の列名を指定する必要があります。
import streamlit as st
import seaborn as sns
df = sns.load_dataset("tips")
ax = df.plot.scatter(x="total_bill", y="tip",
c="size", colormap="cool")
st.pyplot(ax.figure)
scatterでは、引数にX軸、Y軸の列名を指定します。さらに、c引数を使って、描画する点の色の種類を、色のスタイルをcolormap引数で指定することも可能です。
Seabornを使ったグラフ描画
ここまで、データフレームからグラフ描画用のメソッドを使ってグラフを描画する方法をみてきました。描画に必要なデータはデータフレームに含まれています。一方、seabornを使って描画するときは必要なデータは引数として指定する必要があります。
両方とも似たようなグラフを描画できますが、考え方が異なるために結果が違ってくることもあります。一般的にseabornを使った方がいろいろ細かい設定が可能です。
ちなみに、現時点でsnsでサポートされているグラフの種類は26個あるようです。
[m for m in dir(sns) if m.endswith("plot")]
出力は以下の通りです。
['barplot', 'boxenplot', 'boxplot', 'catplot', 'countplot', 'displot', 'distplot', 'dogplot', 'ecdfplot', 'histplot', 'jointplot', 'kdeplot', 'lineplot', 'lmplot', 'miscplot', 'pairplot', 'palplot', 'pointplot', 'regplot', 'relplot', 'residplot', 'rugplot', 'scatterplot', 'stripplot', 'swarmplot', 'violinplot']
折れ線グラフ
DataFrameを作成して、sns.lineplot関数で描画します。DataFrameから描画したときと同じ結果となります。
import streamlit as st
import pandas as pd
import seaborn as sns
vals = [x*x for x in range(-50, 60, 10)]
df = pd.DataFrame({"v":vals})
ax = sns.lineplot(df)
st.pyplot(ax.figure)
棒グラフ
データフレームのdf.plot.barは数値に応じた長さの箱を描画しています。一方、seabornの棒グラフbarplotは「あるデータの集合の平均値を箱の長さで描画する」というように思想が異なります。
左がseabornのグラフで右がデータフレームのグラフです。
import streamlit as st
import pandas as pd
import seaborn as sns
x = [80, 60, 70, 90, 50]
y = [10, 20, 30, 40, 50]
df = pd.DataFrame({"x": x, "y":y})
col1, col2 = st.columns(2)
with col1:
st.pyplot(sns.barplot(df).figure)
with col2:
st.pyplot(df.plot.bar().figure)
同じデータを与えても思想が異なるので、違った結果となっていることが分かります。seabornの棒の上にある黒い線は信頼区間を表します。
ヒストグラム
ヒストグラムの描画にはsns.hitplotを使用します。
import streamlit as st
import pandas as pd
import seaborn as sns
from random import gauss
col1, col2 = st.columns(2)
with col1:
mu = st.slider("平均", min_value=0, max_value=100, value=50)
with col2:
sigma = st.slider("標準偏差", min_value=1, max_value=30, value=10)
score = [gauss(mu, sigma) for _ in range(1000)]
df = pd.DataFrame({"x":score})
ax = sns.histplot(df)
ax.set_xlim(0, 200)
st.pyplot(ax.figure)
使い方はDataFrameの場合と殆ど同じです。
散布図
散布図はsns.scatterplotで描画します。グラフはDataFrameのときとほぼ同じです。
import streamlit as st
import seaborn as sns
df = sns.load_dataset("tips")
ax = sns.scatterplot(df, x="total_bill", y="tip")
st.pyplot(ax.figure)
seabornでは相関をみるためのグラフが充実しています。
- jointplot = 各パラメタの分布をグラフの上と右に描画
- pairplot = 複数の列の組み合わせを行列形式で描画
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt
st.set_page_config(layout="wide")
df = sns.load_dataset("iris")
col1, col2 = st.columns(2)
with col1:
st.subheader("jointplot")
st.write("2つのパラメタの相関を詳しくみるときに使用します")
ax = sns.jointplot(data=df, x="sepal_width", y="sepal_length")
st.pyplot(ax)
with col2:
st.subheader("pairplot")
st.write("複数のパラメタの相関をみるときに使用します")
ax = sns.pairplot(data=df)
st.pyplot(ax)