Wrappers for the Scikit-Learn API

您可以通过keras.wrappers.scikit_learn.py的包装器,将Sequential Keras模型(仅单输入)用作Scikit-Learn工作流程的keras.wrappers.scikit_learn.py .

有两个包装器可用:

keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params) ,它实现了Scikit-Learn分类器接口,

keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params) ,它实现了Scikit-Learn回归器接口.

Arguments

  • build_fn :可调用函数或类实例
  • sk_params :模型参数和拟合参数

build_fn应该构造,编译并返回build_fn模型,然后将其用于拟合/预测. 以下三个值之一可以传递给build_fn

  1. 功能
  2. 实现__call__方法的类的实例
  3. 没有. 这意味着您实现了一个从KerasClassifierKerasRegressor继承的类. 当前类的__call__方法将被视为默认的build_fn .

sk_params接受模型参数和拟合参数. 法律模型参数是build_fn的参数. 请注意,像scikit-learn中的所有其他估计器一样, build_fn应该为其参数提供默认值,以便您可以创建估计器而无需将任何值传递给sk_params .

sk_params还可以接受的参数调用fitpredictpredict_probascore方法(例如, epochsbatch_size ). 拟合(预测)参数按以下顺序选择:

  1. 传递给fitpredictpredict_probascore方法的字典参数的值
  2. 传递给sk_params
  3. keras.models.Sequential fitpredictpredict_probascore方法的默认值

使用scikit-learn的grid_search API时,合法的可调参数是可以传递给sk_params参数,包括拟合参数. 换句话说,您可以使用grid_search搜索最佳的batch_sizeepochs以及模型参数.