更新:2024/12/06
【scikit-learn】モデルを保存・読み込みする方法について


はるか
Scikit-learnでモデルを保存する方法、知りたい?

ふゅか
もちろん!モデルを保存しておけば、後で簡単に使えるって聞いたけど、詳しく教えて!
目次
1. Scikit-learnでモデルを保存する方法
Scikit-learnは、Pythonで機械学習を行うためのライブラリとして非常に人気があります。モデルを作成し、訓練を終えた後、そのモデルを保存して後で再利用することがよくあります。この記事では、Scikit-learnでモデルを保存する方法をわかりやすく解説します。
1.1. なぜモデルを保存するのか?
モデルの保存は、時間と計算リソースの節約につながります。大規模なデータセットでモデルを訓練するには時間がかかりますが、一度保存しておけば、次回からは訓練済みのモデルをロードしてすぐに使うことができます。例えば、以下のような場面で役立ちます。
- 他のプロジェクトやシステムで同じモデルを使う
- WebアプリケーションやAPIで予測を行う
- モデルの再現性を確保する(同じ結果を何度でも得られるようにする)

はるか
モデルを保存すると、他プロジェクトで再利用も可能。

ふゅか
そうそう!あとAPIに組み込んで、いろんなアプリで活用できるのも魅力だよね!
2. モデル保存の方法
Scikit-learnでは、モデルを保存する方法として以下の2つを紹介します。
joblib
モジュールを使用する方法pickle
モジュールを使用する方法
それぞれの方法を具体的に見ていきましょう。
3. joblibモジュールを使う方法
3.1. 手順
- モジュールをインポートします。
- モデルを保存します。
- 保存したモデルを読み込みます。
3.2. ランダムフォレストを利用したコード例
以下は、ランダムフォレストのモデルを保存する例です。
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import joblib
# データセットを読み込み
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# モデルを作成・訓練
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# モデルを保存
joblib.dump(model, 'random_forest_model.joblib')
print("モデルを保存しました。")
# 保存したモデルを読み込む
loaded_model = joblib.load('random_forest_model.joblib')
print("モデルを読み込みました。")
# 読み込んだモデルで予測
predictions = loaded_model.predict(X_test)
print("予測結果:", predictions)
3.3. ポイント
- 保存したモデルは、ファイル名(例:
random_forest_model.joblib
)として保存されます。 - 読み込む際は
joblib.load()
を使います。
4. pickleモジュールを使う方法
pickle
はPython標準ライブラリで、オブジェクトを保存できます。Scikit-learnのモデル保存にも対応しています。
4.1. 手順
- モジュールをインポートします。
- モデルを保存します。
- 保存したモデルを読み込みます。
4.2. サポートベクターマシンを利用したコード例
以下は、サポートベクターマシン(SVM)のモデルを保存する例です。
from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pickle
# データセットを読み込み
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# モデルを作成・訓練
model = SVC()
model.fit(X_train, y_train)
# モデルを保存
with open('svm_model.pkl', 'wb') as f:
pickle.dump(model, f)
print("モデルを保存しました。")
# 保存したモデルを読み込む
with open('svm_model.pkl', 'rb') as f:
loaded_model = pickle.load(f)
print("モデルを読み込みました。")
# 読み込んだモデルで予測
predictions = loaded_model.predict(X_test)
print("予測結果:", predictions)
4.3. ポイント
pickle.dump()
でモデルを保存します。pickle.load()
でモデルを読み込みます。- ファイル操作のために
open()
関数を使用します。