熊猫 DataFrame 和 Keras

新手上路,请多包涵

我正在尝试使用 Keras 在 Python 中执行情绪分析。为此,我需要对文本进行词嵌入。当我尝试将数据拟合到我的模型时出现问题:

 model_1 = Sequential()
model_1.add(Embedding(1000,32, input_length = X_train.shape[0]))
model_1.add(Flatten())
model_1.add(Dense(250, activation='relu'))
model_1.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

我的火车数据的形状是

(4834,)

并且是 Pandas 系列对象。当我尝试拟合我的模型并使用其他一些数据对其进行验证时,出现此错误:

 model_1.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=2, batch_size=64, verbose=2)

ValueError:检查模型输入时出错:预期 embedding_1_input 具有形状(无,4834)但得到具有形状的数组(4834、1)

如何重塑我的数据以使其适合 Keras?我一直在尝试使用 np.reshape 但我无法将 None 元素与该功能一起放置。

提前致谢

原文由 Gonzalo Donoso 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 423
2 个回答

None 是进入训练的预期行数,因此您无法定义它。 Keras 还需要一个 numpy 数组作为输入,而不是 pandas 数据框。首先使用 df.values 将 df 转换为 numpy 数组,然后执行 np.reshape((-1, 4834)) 。请注意,您应该使用 np.float32 。如果您在 GPU 上训练它,这一点很重要。

原文由 Dat Tran 发布,翻译遵循 CC BY-SA 3.0 许可协议

https://pypi.org/project/keras-pandas/

最简单的方法是让 keras_pandas 包使 pandas 数据框适合 keras。下面显示的代码是包文档中的一般示例。

 from keras import Model
from keras.layers import Dense

from keras_pandas.Automater import Automater
from keras_pandas.lib import load_titanic

observations = load_titanic()

# Transform the data set, using keras_pandas
categorical_vars = ['pclass', 'sex', 'survived']
numerical_vars = ['age', 'siblings_spouses_aboard', 'parents_children_aboard', 'fare']
text_vars = ['name']

auto = Automater(categorical_vars=categorical_vars, numerical_vars=numerical_vars, text_vars=text_vars,
 response_var='survived')
X, y = auto.fit_transform(observations)

# Start model with provided input nub
x = auto.input_nub

# Fill in your own hidden layers
x = Dense(32)(x)
x = Dense(32, activation='relu')(x)
x = Dense(32)(x)

# End model with provided output nub
x = auto.output_nub(x)

model = Model(inputs=auto.input_layers, outputs=x)
model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train model
model.fit(X, y, epochs=4, validation_split=.2)

原文由 Pardhu 发布,翻译遵循 CC BY-SA 4.0 许可协议

推荐问题