PythonでNOTEARS・ベイジアンネットによる因果グラフ推定 -causalnexの紹介-
今回の記事は因果探索についてです。
因果探索は、データを与えることで、そのデータの変数間に潜む因果構造を推定しようという手法です。メジャーな手法としては、離散変数に対してベイジアンネットワークや、非ガウス連続変数に対してのLiNGAMなどの手法があります。
その中でも機械学習系のトップカンファレンスNeurIPSの2018年論文「DAGs with NO TEARS: Continuous Optimization from Structure Learning」で提案されたNOTEARSという手法について紹介したいです!NOTEARSは、実質損失関数を加えるのみで、今までの機械学習ライクなシンプルなアプローチで因果グラフを推定できる手法なので紹介してみます。
また、Python実装がNOTEARS実装がcausalnexでQuantumBlackから公開されているので使用方法も含め紹介していきます。
DAGと因果グラフ
DAGとは
ベイジアンネットやLiNGAMといった因果探索手法では、推定する因果グラフにDAG構造を想定することが多いです。NOTEARSも例外ではなく、推定される因果グラフにDAGを想定しています。
DAGとはDirect Acyclic Graph、日本語では有向非巡回グラフというもので、つまりは矢印がついていて巡回していないグラフを指します。
例を出してみます。下の左図は、有向ではあるものの、巡回している(どこか任意のノードから見て自分に返ってくるパスがある)ため、DAGではないです。右図に関しては、有向かつ巡回していないため、DAGの例となります。
因果グラフの意味合い
誤解を恐れずにいうと、因果グラフは、NOTEARSやLiNGAMの場合は、データ生成過程とみなすとわかりやすいかなと思います。例えば、上の右図の因果グラフで、$x_1$に関しては、$x_1 = f_{31}(x_3) + e_1$という風にデータが生成され、$x_2$に関しては、$x_2 = f_{12}(x_1) + f_{32} (x_3) + e_2$というように$x_1, x_3$を基にデータが生成されるとみなします。
では、因果グラフにDAGを仮定するのはどういうことかというと、もちろん制約として入れることで問題を解きやすくするというのもあるのでしょうが、データが巡回して生成されないということなので、自然な仮定なのかなと思われます。
NOTEARSとは
以下ではNOTEARSの基本的な考え方を紹介します。NO TEARS はNon-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning の略です。
手法名めちゃめちゃ長いですが、NOTEARSの重要なポイントとしては、
の2点かなと思います。
そもそもDAGを、データの生成過程と捉えると、データ$X$は、下図のように
$X = WX$のように表すことが可能になります。この図では、$W$について一旦線形を仮定してます。また、$W$の対角成分はDAGの制約(巡回なし)から数値が入らないです。
NOTEARSではこの$W$についてDAGであることを定量的に表すことで、最適化問題の制約として、それを加えることで、因果グラフ(DAG)を推定するということを考えます。
結論をいうと、DAGであることの定量評価(DAG-ness)としては、
$$h(W) = \text{tr}\left( e^{W\circ W} \right) -d = 0$$
が提案されています。ここで$W$は$d\times d$の重み行列で、$\circ$はアダマール積(各要素を対応する要素で積をとるもの)になります。$W$がDAGである場合には、$h(W) = 0$となります。つまり、$h(W)$が0になるような$W$を推定できればそれはDAGであることになります。*1
最終的には、これを制約として加えて、$W$を以下の問題を推定することを考えます。
\begin{align}& \min_{W\in\mathbb{R}^{d\times d}} & \dfrac{1}{2n}\| X - WX \|_{F}^2 + \lambda \| W\|_1\\\ & \text{subject to } & \text{tr}\left( e^{W\circ W} \right)-d = 0\end{align}
この式は、単純に$X$の推定式である$WX$と$W$の2乗誤差(と$\ell_1-$penalty)となっており、制約として$W$がDAGであることが加わっています。単純にそれぞれの変数に対して他の変数からのみ(自身を含めずに)回帰する、かつ制約としてその前回の関係式はDAGであるという最適化問題になります。
これによって推定される$W$の意味合いとしては、そもそも$W$はデータ生成仮定を表す行列である上、DAGであることを制約として加えているため、この$W$が因果グラフとなります。
式自体は線形の式を紹介しましたが、もちろんこれは非線形にも$WX$を変えるだけで容易に対応できます。causalnexでもニューラルネットを用いて非線形にしたバージョンなどが実装されています。
Causalnexを用いてPythonでサンプルデータ実験
ここからcausalnexを使ってNOTEARSで因果グラフを、その推定結果を元にベイジアンネットを構築し、予測まで実行していきます。コードはこちらで公開しています。
サンプルデータ作成
実験用のサンプルデータを生成します。
サンプルデータとしては、以下の因果グラフを従う5変数($y, x_1, x_2, c_1, c_2$)を作成して、それだけを入力とした時に正しく因果グラフが推定できるかどうかを確認していきます。
def generate_sample_data(n_samples=1000):
"""
サンプルデータ生成
| c1 --> x1
| c2 --> x1
| c1 --> y
| c2 --> y
| x1 --> y
| x2 は独立
"""
np.random.seed(0)
c1 = np.random.normal(0, 0.5, size=n_samples)
c2 = np.random.choice(2, size=n_samples)
np.random.seed(1)
x1 = np.random.normal(1, 2, size=n_samples) + c1 + c2/10
np.random.seed(2)
y = np.random.uniform(-1,2, size=n_samples) + x1 + c1 + c2 /10
np.random.seed(3)
x2 = np.random.uniform(-2, 2, size=n_samples)
raw_data = pd.DataFrame({"x1": x1, "y":y, "x2": x2, "c1": c1, "c2":c2})
return raw_data
# データ作成
n_samples = 2000
struct_data = generate_sample_data(n_samples)
上の因果グラフに従う2000サンプルのデータを作成しました。
結果を表示するための関数を準備
causalnexでは、plotツールが準備されているのですが、pygraphvizなど環境をつくるのが面倒なインストールする必要があります。それを避けるために自作のプロット関数を作っておきます。
causalnexで推定した因果グラフはnetworkxを元に作成されているため、networkxのグラフから因果グラフをplotするutilを作りました。こちらも同様にこちらのリンクに入っています (util/networkx_util.py
)。
# -*- coding: utf-8 -*-
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
class NetworkxUtil:
def plot_structure_model(self, structure_model, layout_method="spring", layout_seed=1, figsize=(5,5),
node_shape="o", node_size=1500, node_color="#1abc9c", font_size=18,
edge_color="#34495e", min_edge_width=2, max_edge_width =3.5, arrowsize=25,
plot_edge_weights=True, edge_weights_color="#e74c3c", edge_weights_fontsize=12, alpha=0.8):
# plot settings
fig, ax = plt.subplots(figsize=figsize)
pos = self._load_layout(structure_model=structure_model, layout_method=layout_method, layout_seed=layout_seed)
# edgeの重み係数をplotするか否かのフラグ
if plot_edge_weights:
nx.draw_networkx_edge_labels(
structure_model, pos,
edge_labels={(u, v): round(d["weight"], 4) for (u,v,d) in structure_model.edges(data=True)},
font_color=edge_weights_color,
font_size = edge_weights_fontsize
)
# edgeの重みに応じて太さを変更、最低でもmin_edge_widthに設定
edge_width = [
np.min([np.max([d["weight"], min_edge_width]), max_edge_width] )
for (u, v, d) in structure_model.edges(data=True)
]
# networkxでの推定したモデルのplot
nx.draw_networkx(
structure_model,
ax = ax,
pos = pos,
node_shape = node_shape,
node_color = node_color,
node_size = node_size,
edge_color = edge_color,
width = edge_width,
alpha = alpha,
font_size = font_size,
arrowsize = arrowsize,
with_labels = True,
# connectionstyle="arc3, rad=0.1" # if curve
)
plt.axis("off")
fig.show()
def _load_layout(self, structure_model, layout_method, layout_seed=0):
"""layout_methodを指定して、pos: position keyed by nodeを取得"""
if layout_method == "circular":
pos = nx.circular_layout(structure_model)
elif layout_method == "spring":
pos = nx.spring_layout(structure_model, seed=layout_seed)
elif layout_method == "planar":
pos = nx.planar_layout(structure_model)
elif layout_method == "shell":
pos = nx.shell_layout(structure_model)
elif layout_method == "random":
pos = nx.random_layout(structure_model, seed=layout_seed)
else:
raise ValueError(f"Method {layout_method} is not expected.")
return pos
@staticmethod
def print_weights(structure_model, cutoff=0):
"""エッジとその重みをprint"""
for (u, v, d) in structure_model.edges(data=True):
if np.abs(d["weight"]) >= cutoff:
print(f"[ {u} ] ---> [ {v} ]\t\tWeight\t{d['weight']:.5f}")
NOTEARSで因果グラフ推定
では早速、NOTEARSを実行していきます。 NOTEARSはfrom_pandas
というmethodでデータを入力することで簡単に実行できます。
from causalnex.structure.notears import from_pandas
from util.networkx_util import NetworkxUtil
nx_util = NetworkxUtil() # networkxの図示, printのためのutil
# NOTEARSを実行, from_pandas_lassoでL1 penaltyをつけて推定することが可能
sm = from_pandas(
struct_data,
tabu_edges = [],
tabu_parent_nodes = None,
tabu_child_nodes = None,
)
sm.threshold_till_dag() # DAGになるように閾値をあげる
sm.remove_edges_below_threshold(0.2) # 係数の閾値を0.2にする
nx_util.plot_structure_model(structure_model=sm, layout_method="circular") # この条件で正しく推定できていることがわかる。
from_pandas
の結果のみからでは誤差の影響で完全なDAGにならないケースがあります。そのためthresholdをDAGになるまで上げる(threshold_till_dag
)や、小さすぎる重みを持つエッジを削除する(remove_edges_below_threshold
)などある程度操作する必要があります。
そうして出たプロット結果をみてみると、ただ正しく因果グラフが推定されていることがわかります。
また今回の$x2$のように他のノードとの関連がないノードが含まれてしまっている場合など、因果グラフをきれいにするため最大のサブグラフに限定して抽出することもできます。
# largest subgraphを抽出
sm = sm.get_largest_subgraph()
nx_util.plot_structure_model(sm, layout_method="random", figsize=(9,6))
# 係数をprint
nx_util.print_weights(sm)
# 出力
[ x1 ] ---> [ y ] Weight 1.05582
[ c1 ] ---> [ x1 ] Weight 0.93608
[ c1 ] ---> [ y ] Weight 0.91724
[ c2 ] ---> [ x1 ] Weight 1.05973
[ c2 ] ---> [ y ] Weight 0.49489
以上のようにNOTEARSによって正しく因果グラフ(ここでは、データ生成過程)を推定できることがわかりました。
NOTEARSでTabuを指定して因果グラフを推定
NOTEARSでは、あらかじめ「この矢印はないだろう」といった因果グラフに関しての人間の知見をtabu_edges
という形で取り込むことができます。
他にも、このノードは親ノードにはなり得ない、子ノードにはなり得ないといった知見を加えることができるので、実際にやってみて因果グラフを推定してみます。
# struct data
sm_with_tabu = from_pandas(
struct_data,
tabu_edges = [("c1", "x1")], # c1 --> x1 に線が引かれなくなる。
tabu_parent_nodes = ["c2"], # 親ノードにならなくなる。矢印元にならなくなる。
tabu_child_nodes = ["y"], # 子ノードにならなくなる。矢印が刺されなくなる。
)
sm_with_tabu.threshold_till_dag()
sm_with_tabu.remove_edges_below_threshold(0.2)
nx_util.plot_structure_model(structure_model=sm_with_tabu, layout_method="circular") #
nx_util.print_weights(sm_with_tabu) # 係数を print
プロット結果を見ると、指定した$c_1$から$y$までのパスはtabu_edges
を指定するまではあったのですが、消えていることがわかります。また他にもtabu_child_nodes
で指定した$y$が矢印がささっておらず子ノードになっていないこと、tabu_parents_nodes
で指定した$c2$から矢印が出ておらず親ノードになっていないことが確認できます。
以下ではtabuを指定せず、正しく因果グラフを推定できたグラフを元にベイジアンネットを構築します。
推定した因果グラフ構造を元にベイジアンネットを推定
causalnexではNOTEARSで推定した因果グラフを元にベイジアンネットワーク(pgmpy実装を内部で利用している)を推定することができます。
ただし、ベイジアンネットワークは離散化した変数に対して適用する必要があるため、NOTEARSを推定したデータそのものではなく、適宜離散化する必要がありますcausalnexではDiscretiser
という離散化のための便利ツールが用意されているので、今回はこれを使っていきます。
from causalnex.discretiser import Discretiser
# ベイジアンネットを推定するために離散化データを作成する。
discretised_data = struct_data.copy().drop(["x2"], axis=1)
# quantileでnum_buckets数に分割。
discretised_data["y"] = Discretiser(method="quantile", num_buckets=5).fit_transform(discretised_data["y"].values)
# ほぼ同数num_buckets数に分割。
discretised_data["x1"] = Discretiser(method="uniform", num_buckets=5).fit_transform(discretised_data["x1"].values)
# percentile_split_pointsを指定して分割。それぞれ<=が含まれる集合になる
discretised_data["c1"] = Discretiser(method="percentiles", percentile_split_points=[0.25, 0.5, 0.75]).fit_transform(discretised_data["c1"].values)
# split_pointsで分割。それぞれ<=が含まれる集合になる
discretised_data["c2"] = Discretiser(method="fixed", numeric_split_points=[0,1]).transform(discretised_data["c2"].values)
NOTEARSで推定した因果グラフと、離散化したデータを用いてベイジアンネットワークを推定します。 推定したfit_cpds
によって条件付き確率を計算することができます。計算した後は、対象の子ノードを指定してbn.cpds["y"]
といった形で条件付き確率表を出力することができます。
from causalnex.network import BayesianNetwork
# load estimated causal graph
bn = BayesianNetwork(sm)
bn = bn.fit_node_states(discretised_data) # 取りうるすべての状態を指定、 本来testデータに新出のカテゴリがある場合には判定し、変換する必要がある
bn = bn.fit_cpds(discretised_data, method="BayesianEstimator", bayes_prior="K2")
ベイジアンネットでの推論
推定したベイジアンネットワークでは、ラベルの予測(predict
), そのラベルへの予測所属確率(predict_probability
)を推定することができます。今回予測対象を$y$にしますが、推定されたデータ生成過程では$x2$などは$y$には不要なので、必要なカラムのみをスライスして予測してみます。
pred_target_col = "y"
parents_columns = bn.cpds[pred_target_col].columns.names # 予測対象カラムの親ノードのみ必要であるためsliceする、sliceせずとも実行可能
y_pred = bn.predict(discretised_data[parents_columns], pred_target_col) # ラベル予測
y_predict_proba = bn.predict_probability(discretised_data[parents_columns], pred_target_col) # 予測確率
# 予測結果
print("True Label :\t", discretised_data["y"].values[:5])
print("Pred Label :\t", np.ravel(y_pred)[:5])
print("Pred Proba [y=4] :\t", y_predict_proba["y_4"].values[:5])
# 出力結果
True Label : [4 0 2 2 4]
Pred Label : [4 1 3 1 4]
Pred Proba [y=4] : [0.88235294 0.01298701 0.30075188 0.01639344 0.85074627]
予測結果がでてきますね。こうして、NOTEARSで推定した因果グラフを元に、ベイジアンネットワークを推定し、予測まで行うことができました。
フルのコードはこちらにおいています。
まとめ
ベイジアンネットやLiNGAMなどの因果探索手法は理論面が難しいなーと思っていたのですがNOTEARSのように実質損失関数を追加するだけで、因果グラフを候補を出すことができるのは、説明もしやすいしとても便利ですね。また、tabu_edge
などで人の知識を取り入れることができるので、より納得感のあるモデルを作りやすそうですね。
人の知識をいれられるようなモデルは説明しやすい面白いので、今後も積極的にキャッチアップできればと思います。causalnex内でもまだまだ色んな機能があるので、次回紹介できればと思います。
今回は詳細を端折ってしまいましたが、そもそも因果・相関って何が違うのといった話や因果推論のイメージについてはこちらの本がわかりやすいので超おすすめです。
他にも同じ因果探索手法であるLiNGAMについて、LiNGAMの発明者である清水先生の本がとても行間が少なくわかりやすい本なのでオススメです。