Deep Learning: Simple Classification
Simple binary classification
Generating the data
from sklearn.model_selection import train_test_split
from sklearn import datasets
import matplotlib.pyplot as plt
from tensorflow import keras
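# 1000 points in two Gaussian clusters; y holds the cluster label (0 or 1)
X, y = datasets.make_blobs(n_samples=1000, random_state=8, centers=2)
# Scatter plot of the two features, colored by label
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()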
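`train_test_split` is imported above but never used. A minimal sketch of how it could hold out a separate test set before training, assuming a 25% test fraction and `random_state=0` (both values chosen only for illustration). `stratify=y` keeps the 50/50 class ratio the same in both splits.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Same two-cluster data as above
X, y = datasets.make_blobs(n_samples=1000, random_state=8, centers=2)

# Hold out 25% as a test set; stratify=y keeps the class ratio equal in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

print(X_train.shape, X_test.shape)  # (750, 2) (250, 2)
```

The model could then be trained on `X_train, y_train` only and scored once at the end with `model.evaluate(X_test, y_test)`, so the reported accuracy comes from points the model never saw during training or validation.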
Building and training the model
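model = keras.models.Sequential([
    # Hidden layer: 32 units with no activation, i.e. linear (enough for well-separated blobs)
    keras.layers.Dense(32, input_shape=X.shape[1:]),
    # Output layer: one sigmoid unit giving the probability of class 1
    keras.layers.Dense(1, activation=keras.activations.sigmoid),
])
model.summary()
# keras.metrics.Accuracy() checks exact equality, so it stays near 0 on sigmoid
# probabilities; BinaryAccuracy thresholds the predictions at 0.5 first
model.compile(loss=keras.losses.binary_crossentropy,
              optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
              metrics=[keras.metrics.BinaryAccuracy()])
# The last 25% of the samples are held out for validation
model.fit(X, y, validation_split=0.25, epochs=20)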
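The loss passed to `compile`, binary cross-entropy, charges a predicted probability p for a true label y the amount -[y·log(p) + (1-y)·log(1-p)], averaged over the samples. The NumPy sketch below recomputes it by hand as an illustration, not as Keras's actual implementation; the clipping value `eps=1e-7` is an assumption meant to mirror Keras's default fuzz factor.

```python
import numpy as np

def sigmoid(z):
    """Logistic function: maps any real score to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def binary_crossentropy(y_true, p, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over samples.
    p is clipped to [eps, 1-eps] so log(0) never occurs."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# A coin-flip prediction costs ln 2 per sample
coin_flip = binary_crossentropy([1, 0], [0.5, 0.5])
# Confident, correct predictions (large positive / negative scores) cost almost nothing
confident = binary_crossentropy([1, 0], sigmoid(np.array([8.0, -8.0])))
print(coin_flip, confident)
```

This is why the loss keeps falling during training even after accuracy reaches 100%: pushing correct predictions closer to 0 or 1 still lowers it.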
Comparing true labels with predictions
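# First 10 true labels
print(y[0:10])
# Predicted probabilities of class 1, shape (10, 1)
y_pre = model.predict(X[0:10])
# Print as a single row so it lines up with the labels
print(y_pre.reshape(1, -1))
Output: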
[0 1 1 0 0 1 0 1 1 1]
[[0. 1. 1. 0. 0. 1. 0. 1. 1. 1.]]
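The sigmoid output is a probability, not a label. A minimal sketch of turning it into class labels by thresholding at 0.5, which is also what `keras.metrics.BinaryAccuracy` does by default; the probability values here are hypothetical, chosen to include less confident predictions than the printout above.

```python
import numpy as np

# Hypothetical sigmoid outputs of shape (n, 1), as model.predict returns for Dense(1)
probs = np.array([[0.02], [0.97], [0.61], [0.40], [0.88]])
labels_true = np.array([0, 1, 1, 0, 1])

# Probability > 0.5 -> class 1, otherwise class 0; ravel() flattens (n, 1) to (n,)
labels_pred = (probs.ravel() > 0.5).astype(int)
print(labels_pred)                          # [0 1 1 0 1]
print(np.mean(labels_pred == labels_true))  # 1.0
```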
Multi-class classification
Generating the data
from sklearn.model_selection import train_test_split
from sklearn import datasets
import matplotlib.pyplot as plt
from tensorflow import keras
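# 1000 points in three Gaussian clusters; y holds the cluster label (0, 1 or 2)
X, y = datasets.make_blobs(n_samples=1000, random_state=8, centers=3)
# Scatter plot of the two features, colored by label
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()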
Building and training the model
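model = keras.models.Sequential([
    # Hidden layer: 32 ReLU units
    keras.layers.Dense(32, input_shape=X.shape[1:], activation='relu'),
    # Output layer: one unit per class; softmax turns the scores into probabilities
    keras.layers.Dense(3, activation=keras.activations.softmax),
])
model.summary()
# "sparse" because y holds integer labels (0, 1, 2) rather than one-hot vectors
model.compile(loss=keras.losses.sparse_categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])
model.fit(X, y, validation_split=0.25, epochs=20)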
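`sparse_categorical_crossentropy` fits here because `make_blobs` returns integer labels. With one-hot labels (for example from `keras.utils.to_categorical`), the matching loss would be `categorical_crossentropy`, and the two give the same value. The NumPy sketch below illustrates this on two hypothetical samples; the logits are made up and `eps=1e-7` is an assumed clipping value, so treat it as an illustration rather than Keras's own code.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy(y_int, probs, eps=1e-7):
    """Mean of -log(probability assigned to the true class); labels are integers."""
    p_true = probs[np.arange(len(y_int)), y_int]
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))

def categorical_crossentropy(y_onehot, probs, eps=1e-7):
    """The same loss, but the labels are one-hot rows."""
    return float(np.mean(-np.sum(y_onehot * np.log(np.clip(probs, eps, 1.0)), axis=1)))

# Hypothetical raw scores for two samples over three classes
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
y_int = np.array([0, 2])      # integer labels, as make_blobs produces
y_onehot = np.eye(3)[y_int]   # the same labels as one-hot rows

probs = softmax(logits)
print(probs.sum(axis=1))      # each row sums to 1
print(sparse_categorical_crossentropy(y_int, probs))
print(categorical_crossentropy(y_onehot, probs))  # same value
```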
Inspecting the predictions
import numpy as np
# First 10 true labels
print(y[0:10])
# predict already returns shape (10, 3): one probability per class for each sample
y_pre = model.predict(X[0:10])
print(y_pre)
Output:
[1 2 1 1 1 2 1 2 1 1]
[[4.50088549e-03 9.95355964e-01 1.43211320e-04]
[3.52771860e-03 2.17666663e-03 9.94295657e-01]
[5.39137749e-04 9.99391794e-01 6.91057867e-05]
[3.10646836e-03 9.93669093e-01 3.22450022e-03]
[1.59081508e-04 9.99381661e-01 4.59307077e-04]
[3.76076205e-04 2.09796475e-03 9.97525990e-01]
[1.03477845e-02 9.88485038e-01 1.16714649e-03]
[8.82121618e-04 1.39025709e-04 9.98978853e-01]
[2.29390264e-02 9.75332916e-01 1.72808918e-03]
[4.69710241e-04 9.99316335e-01 2.13949577e-04]]
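Each row above is a softmax distribution over the three classes. Taking the column with the highest probability (`np.argmax` along axis 1) gives the predicted label. The sketch below applies this to the probabilities exactly as printed and compares the result with the printed labels.

```python
import numpy as np

# Softmax outputs copied from the printout above
y_pre = np.array([
    [4.50088549e-03, 9.95355964e-01, 1.43211320e-04],
    [3.52771860e-03, 2.17666663e-03, 9.94295657e-01],
    [5.39137749e-04, 9.99391794e-01, 6.91057867e-05],
    [3.10646836e-03, 9.93669093e-01, 3.22450022e-03],
    [1.59081508e-04, 9.99381661e-01, 4.59307077e-04],
    [3.76076205e-04, 2.09796475e-03, 9.97525990e-01],
    [1.03477845e-02, 9.88485038e-01, 1.16714649e-03],
    [8.82121618e-04, 1.39025709e-04, 9.98978853e-01],
    [2.29390264e-02, 9.75332916e-01, 1.72808918e-03],
    [4.69710241e-04, 9.99316335e-01, 2.13949577e-04],
])
y_true = np.array([1, 2, 1, 1, 1, 2, 1, 2, 1, 1])

# The predicted class is the column holding the highest probability
y_cls = np.argmax(y_pre, axis=1)
print(y_cls)                     # [1 2 1 1 1 2 1 2 1 1]
print(np.mean(y_cls == y_true))  # 1.0
```

All ten predictions match the true labels, and every row sums to 1 up to rounding, as expected from a softmax output.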