In this article, you will learn how to use scikit-learn to save and load your machine learning model in Python.
1. Tutorial overview
This tutorial is divided into three parts:
- Save your model with pickle
- Save your model with joblib
- Tips for saving your model
2. Save your model with pickle
Pickle is the standard method of serializing objects in Python.
You can use pickle to serialize your trained machine learning model and save the serialized bytes to a file.
Later, you can load this file to deserialize the model and use it to make new predictions.
The following example trains a logistic regression model on the Pima Indians diabetes dataset, saves the model to a file, and then loads it to make predictions on the unseen test set. The dataset is read directly from the URL in the code.
```python
# Save Model Using Pickle
import pickle

import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression

url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
X = array[:, 0:8]  # the eight input features
Y = array[:, 8]    # the class label
test_size = 0.33
seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)
# Fit the model on the training set
# (max_iter is raised to avoid a ConvergenceWarning from the default lbfgs solver on unscaled data)
model = LogisticRegression(max_iter=1000)
model.fit(X_train, Y_train)
# save the model to disk; "with" closes the file handle when done
filename = 'finalized_model.pkl'
with open(filename, 'wb') as f:
    pickle.dump(model, f)
# some time later...
# load the model from disk
with open(filename, 'rb') as f:
    loaded_model = pickle.load(f)
result = loaded_model.score(X_test, Y_test)
print(result)
```
Running the example saves the model to finalized_model.pkl in the current working directory.
The saved model is then loaded and evaluated on the test set, giving an estimate of its accuracy on unseen data. The exact value may vary with your scikit-learn version:
0.755905511811
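The same save-and-load cycle can be checked without touching the disk. The following minimal sketch assumes synthetic data from scikit-learn's make_classification in place of the Pima CSV. It uses pickle.dumps and pickle.loads (the in-memory counterparts of pickle.dump and pickle.load) and confirms that the restored model has the same coefficients and makes the same predictions as the original:

```python
import pickle

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in data (an assumption; the article itself uses the Pima CSV)
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

# Serialize the fitted model to an in-memory bytes object instead of a file
payload = pickle.dumps(model)
print(type(payload).__name__, len(payload), "bytes")

# Deserialize the bytes into a new, independent estimator object
restored = pickle.loads(payload)

# The learned coefficients survive the round trip, so predictions match
same_coef = np.allclose(model.coef_, restored.coef_)
same_pred = np.array_equal(model.predict(X), restored.predict(X))
print("same coefficients:", same_coef)
print("same predictions:", same_pred)
```

Keep in mind that unpickling can execute arbitrary code, so only load pickle files from sources you trust. The same applies to files loaded with joblib, which is built on pickle.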
3. Save your model with joblib
Joblib is part of the SciPy ecosystem and provides utilities for pipelining Python jobs.
It provides utilities for saving and loading Python objects that make efficient use of NumPy data structures.
This is useful for machine learning algorithms that have many parameters or that store the entire dataset, such as K-Nearest Neighbors.
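To illustrate the K-Nearest Neighbors case, here is a minimal sketch that assumes synthetic data and scikit-learn's KNeighborsClassifier. It saves a fitted model with joblib.dump, once plain and once with the compress option, then reloads the compressed copy and checks that it predicts exactly like the original. Because the model keeps its training data, the plain file ends up larger than the raw training array:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# Synthetic data (an assumption, chosen only for illustration)
X, y = make_classification(n_samples=1000, n_features=8, random_state=0)

# K-Nearest Neighbors keeps the training data inside the fitted model
knn = KNeighborsClassifier(n_neighbors=5).fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    plain_path = os.path.join(tmp, "knn.joblib")
    packed_path = os.path.join(tmp, "knn_compressed.joblib")

    # Plain dump, then a compressed dump (an integer compress value uses zlib at that level)
    joblib.dump(knn, plain_path)
    joblib.dump(knn, packed_path, compress=3)
    plain_size = os.path.getsize(plain_path)
    packed_size = os.path.getsize(packed_path)

    # Load the compressed copy and confirm it predicts exactly like the original
    loaded = joblib.load(packed_path)
    same_pred = np.array_equal(knn.predict(X), loaded.predict(X))

print("training data size:", X.nbytes, "bytes")
print("plain file:", plain_size, "bytes; compressed file:", packed_size, "bytes")
print("same predictions:", same_pred)
```

How much compress saves depends on the data being stored, so it is worth comparing the two file sizes on your own model before settling on a level.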
The following example trains a logistic regression model on the Pima Indians diabetes dataset, uses joblib to save the model to a file, and then loads it to make predictions on the unseen test set.
```python
# Save Model Using joblib
import joblib
import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression

url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
X = array[:, 0:8]  # the eight input features
Y = array[:, 8]    # the class label
test_size = 0.33
seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)
# Fit the model on the training set
# (max_iter is raised to avoid a ConvergenceWarning from the default lbfgs solver on unscaled data)
model = LogisticRegression(max_iter=1000)
model.fit(X_train, Y_train)
# save the model to disk; joblib opens and closes the file itself
filename = 'finalized_model.pkl'
joblib.dump(model, filename)
# some time later...
# load the model from disk
loaded_model = joblib.load(filename)
result = loaded_model.score(X_test, Y_test)
print(result)
```
Running this example saves the model to finalized_model.pkl. Older joblib releases also created an additional file for each NumPy array in the model (four extra files in this example); current releases store everything in a single file by default.
After the model is loaded, it is evaluated and an estimate of its accuracy on unseen data is reported:
0.755905511811
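Whether joblib writes extra files depends on the joblib version you have installed. The following sketch assumes a recent joblib release and synthetic data in place of the Pima CSV. It uses the list of file names returned by joblib.dump, together with a directory listing, to show what was actually written:

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in data (an assumption; the article uses the Pima CSV)
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "finalized_model.pkl")
    # joblib.dump returns the list of file names it wrote
    written = [os.path.basename(p) for p in joblib.dump(model, path)]
    # Also list what actually ended up in the directory
    on_disk = sorted(os.listdir(tmp))

print("reported by joblib.dump:", written)
print("files on disk:", on_disk)
```

With a current joblib release, both lists should contain only finalized_model.pkl.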
Summary
In this article, you learned how to use scikit-learn to persist your machine learning models in Python.
You learned two techniques you can use:
- The pickle API for serializing standard Python objects.
- The joblib API for efficiently serializing Python objects that contain NumPy arrays.