plotly入門

Python Plotly入門 - 複数のグラフをプロット

2021年6月19日

さて今回は、複数のサブグラフを一つのグラフとして表示する方法について解説します。

matplotlibなどでもサブグラフを作成することはもちろんできますが、plotlyはインタラクティブなグラフを作成するので、複数のグラフを作成する際も軸の共有を行うことでインタラクティブな複数グラフになります

最終的には以下のようなグラフを作成します。

subplots間の間隔を調整する方法など細かい設定も紹介したいと思います。

では、見ていきましょう。

複数グラフの作成方法

以下のようなデータフレームがあるとしましょう。

Date	     スノーピーク    楽天    ぐるなび	 ANA
2020-01-06     1045.84	    920     970.54	3450.50
2020-01-07     1067.69	    934	    1006.30	3499.25
2020-01-08     1042.86	    918	    993.39	3462.20
2020-01-09     1085.57	    924	    1005.31	3515.82
2020-01-10     1090.53	    920	    991.40	3498.27
...               ...	    ...	     ...	          ...

複数のグラフを一つにまとめるにはplotly.subplotsmake_subplotsを使います。

では、まずmake_subplotsとgraph_objectsをインポートしましょう。

import plotly.graph_objects as go
from plotly.subplots import make_subplots

複数のグラフもadd_traceを使うと非常にシンプルです。

コードは、以下のようになります。

fig = make_subplots(rows=2, cols=2)
fig.add_trace(go.Scatter(x=df['Date'],
                         y=df['スノーピーク'],
                         mode='lines',
                         name='スノーピーク',
                        ),
                         row=1,
                         col=1,
              )

fig.add_trace(go.Scatter(x=df['Date'],
                         y=df['楽天'],
                         mode='lines',
                         name='楽天',
                        ),
               row=1,
               col=2,
               )

fig.add_trace(go.Scatter(x=df['Date'],
                         y=df['ぐるなび'],
                         mode='lines',
                         name='ぐるなび',
                        ),
               row=2,
               col=1,
               )

fig.add_trace(go.Scatter(x=df['Date'],
                         y=df['ANA'],
                         mode='lines',
                         name='ANA',
                        ),
               row=2,
               col=2,
               )

まず、1行目でfig=make_subplots(...)でグラフ領域を作成しています。

この際に、行数と列数をそれぞれrowsとcolsで指定します。

あとは、fig.add_traceで個別のサブプロットの設定をしていきます。

線グラフなのでgo.Scatterを第一引数として渡しています。

そして、colとrowでどの位置に設定するか?を指定します。

これにより出来上がるグラフは以下のようになります。

では、少し肉付けしていきましょう。

サブグラフごとのタイトルを設定する

サブグラフごとにタイトルを設定するには、最初のmake_subplotsでsubplot_titlesに対してリストを渡してやります。

fig = make_subplots(rows=2, cols=2, 
       subplot_titles=['スノーピーク', '楽天', 'ぐるなび', 'ANA'])

あとは、update_layoutで全体のタイトルを設定し、凡例はタイトルに含まれているのでshowlegend=Falseとして消してしまいましょう。

fig.update_layout(title='4社の株価推移',
                  showlegend=False)

これにより以下のようになります。

軸を共有する

株価であれば水準が銘柄によって違うのでよいですが、リターンをプロットする場合、どのサブプロットにも同じメモリの軸があるとすっきりしません。

そういった場合は軸を共有します。

また、plotlyはインタラクティブなグラフなので、軸の共有もインタラクティブになります

軸を共有するにはmake_subplotsの引数でx軸、y軸それぞれshared_xaxesshared_yaxesを設定します。

デフォルトはFalseで、Trueを設定することでx軸は縦方向に共有、y軸は横方向に共有します。

True、False以外にも以下のような設定が可能です。

  • all
    すべてのサブプットで軸を共有
  • row
    行のサブプロットで軸を共有
  • col
    列方向のサブプロットで軸を共有

コードは以下になります。とりあえずallを設定し、すべてのサブプットで軸を共有します。

window = 20
companies = ['スノーピーク', '楽天', 'ぐるなび', 'ANA']
fig = make_subplots(rows=2, cols=2, 
                    subplot_titles=companies,
                    shared_yaxes='all',
                    shared_xaxes='all')
for i, company in enumerate(companies):
  row, col = divmod(i, 2)
  row +=1
  col += 1
  fig.add_trace(go.Scatter(x=df['Date'],
                          y=df[f'{company}_リターン'].rolling(window).mean(),
                          mode='lines',
                          name=company,
                          ),
                          row=row,
                          col=col,
                )

fig.update_layout(title='4社の株価リターン推移',
                  showlegend=False,)

すると以下のようなインタラクティブな軸の共有が行われます。

x軸もy軸も共有されていますね。

軸ラベルを設定する

軸ラベルを設定するには、軸ラベルを設定したいサブプロットを指定し、設定する必要があります。

そして、fig.update_xaxesfig.update_yaxestitleという引数に軸ラベルを設定します。

例えば、以下の例では、y軸については1行目1列のサブプロットと2行目1列のサブプロットをrow, colで指定しています。

x軸も同様でfig.update_xaxesで軸ラベルを設定しています。

yaxis_title = 'リターン(%)'
fig.update_yaxes(title=yaxis_title, row=1, col=1)
fig.update_yaxes(title=yaxis_title, row=2, col=1)

xaxis_title = 'Date'
fig.update_xaxes(title=xaxis_title, row=2, col=1)
fig.update_xaxes(title=xaxis_title, row=2, col=2)

サブプロット間の間隔を調整する

サブプロット間の間隔を調整したい場合は、make_subplotsの引数でhorizontal_spacing(横の余白)もしくはvertical_spacing(縦の余白)を設定します。

デフォルトはそれぞれ0.2と0.3です。

以下のように設定します。

make_subplots(...,
              horizontal_spacing=0.15,
              vertical_spacing=0.15,
              ...
              )

サブプロットのサイズを変更する

複数のサブプロットを結合する方法です。

例えば、以下のように2行2列のサブプロットで、2行目のサブプロットは一つに結合したいような場合です。

この場合は、make_subplots()のspecsという引数で設定します。

specsはリスト型の変数を取ります。

2x2のサブプロットであれば、[[・, ・], [・, ・]]というような2×2のリストを渡します。

リスト内の1つ目のリストが1行目のサブプロットを表し、2つ目のリストが2行目のサブプロットを表します。

そして、各リストの値は辞書型変数もしくはNoneを取ります

辞書型変数に{"colspan": 2}とすることで、その位置にあるサブプロットは2列使います。

colspanではなく"rowspan": 2と指定すると2行使います。

また、辞書型変数を{}という形で何も設定しなければ、特に何も変わりません。

Noneを指定するとそこには何も表示されません。

ででは、実際のコードを見てみましょう。

window = 20
companies = ['スノーピーク', '楽天', 'ぐるなび', 'ANA']
fig = make_subplots(rows=3, cols=2, 
                    subplot_titles=companies + ['月ごとのリターン平均'],
                    shared_yaxes='all',
                    shared_xaxes='all',
                    vertical_spacing=0.15,
                    specs=[[{}, {}],
                           [{}, {}], 
                          [{"colspan": 2}, None]]
                    )
for i, company in enumerate(companies):
  row, col = divmod(i, 2)
  row +=1
  col += 1
  fig.add_trace(go.Scatter(x=df['Date'],
                          y=df[f'{company}_リターン'].rolling(window).mean()*100,
                          mode='lines',
                          name=company,
                          ),
                          row=row,
                          col=col,
                )
fig.update_traces(hovertemplate='Date: %{x} <br>リターン: %{y:0.1f}%')

fig.add_trace(go.Bar(x=[f'{month}月' for month in df_mean['month']],
                     y=df_mean['average']*100,
                     width=0.5,
                     textposition='outside',
                     texttemplate='%{y:0.1f}%',
                     hovertemplate='Month: %{x} <br>リターン: %{y:0.1f}%',
                     ),
                     row=3,
                     col=1,
               )


# Update yaxis properties
yaxis_title = 'リターン(%)'
fig.update_yaxes(title=yaxis_title, row=1, col=1)
fig.update_yaxes(title=yaxis_title, row=2, col=1)


fig.update_layout(title='株価リターン推移と月ごとのリターンの平均',
                  showlegend=False,
                  width=800,
                  height=700)

3行3列のサブプロットを作成しています。

8行目、9行目、10行目がサブプットのサイズを設定している箇所です。

specs=[[{}, {}, {}], [{}, {}, {}], [{"colspan": 3}, None, None]]とすることで、1行目と2行目の1列目・2列目・3列目のサブプットはそのまま、3行目の1列目のサブプットは3列使い、2列目、3列目は1列目に吸収されるのでNoneとしています。

他にも多少の肉付けはしていますが、このようなグラフが出来上がります。

もちろんもっと複雑なサブプロットにも拡張できますので、必要に応じて変えていただければと思います。

まとめ

今回は、1つのグラフに複数のサブプロットを描画する方法を見てきました。

サブプロット間でもインタラクティブに軸を共有でき、分析には非常に便利ですので、皆さまにも是非使っていただければと思います。

では!

-plotly入門
-, ,