
          Training a LightGBM Model with Spark-Scala

          2021-07-19 16:59


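          Spark-Scala can use LightGBM models for both distributed training and distributed prediction, and supports a wide range of parameter settings.
          Models can be saved, and a saved model is interchangeable with other languages such as Python.
          Note that when training a LightGBM model with Spark-Scala, the training dataset must be prepared as a DataFrame: use spark.ml.feature.VectorAssembler to combine the feature columns into a single features vector column, with label as a separate column.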
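As a minimal sketch of that input format, the snippet below packs every non-label column of a toy DataFrame into one `features` vector. The object name `AssembleFeaturesSketch`, the helper `assemble`, the toy columns `x1`/`x2`, and the local master setting are illustrative assumptions, not part of the original article.

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, SparkSession}

object AssembleFeaturesSketch {
  // Packs every column except `labelCol` into one vector column named "features".
  def assemble(df: DataFrame, labelCol: String = "label"): DataFrame = {
    val featureCols = df.columns.filter(_ != labelCol)
    new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")
      .transform(df)
      .select("features", labelCol)
  }

  def main(args: Array[String]): Unit = {
    // local[1] is only for trying the sketch out on one machine (assumption).
    val spark = SparkSession.builder().master("local[1]").appName("assemble-sketch").getOrCreate()
    import spark.implicits._
    // Toy rows standing in for real training data.
    val raw = Seq((1.0, 2.0, 0), (3.0, 4.0, 1)).toDF("x1", "x2", "label")
    assemble(raw).show(false)
    spark.stop()
  }
}
```

The resulting two-column DataFrame (`features`, `label`) is the shape the classifier in the example code expects.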

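          I. Environment Setup

          To use LightGBM from Spark-Scala, add the following dependencies to the pom file. The example code also imports com.microsoft.ml.spark.lightgbm, which comes from the MMLSpark library; that dependency is not shown in this listing and must be on the classpath as well.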
          <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-mllib_${scala.version}</artifactId>
          <version>${spark.version}</version>
          <!-- spark-ml must exclude the pmml-model dependency -->
          <exclusions>
              <exclusion>
                  <groupId>org.jpmml</groupId>
                  <artifactId>pmml-model</artifactId>
              </exclusion>
          </exclusions>
          </dependency>

          <dependency>
              <groupId>org.jpmml</groupId>
              <artifactId>jpmml-sparkml</artifactId>
              <version>1.3.4</version>
          </dependency>
          <dependency>
              <groupId>org.jpmml</groupId>
              <artifactId>jpmml-lightgbm</artifactId>
              <version>1.3.4</version>
          </dependency>

          II. Example Code

          The example below uses a binary classification problem and follows these familiar steps.
          • 1. Prepare the data
          • 2. Define the model
          • 3. Train the model
          • 4. Evaluate the model
          • 5. Use the model
          • 6. Save the model
          import org.apache.spark.sql.SparkSession
          import org.apache.spark.sql.DataFrame
          import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType, IntegerType}
          import org.apache.spark.ml.Pipeline
          import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
          import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
          import org.apache.spark.ml.linalg.Vector
          import org.apache.spark.ml.feature.VectorAssembler
          import org.apache.spark.ml.attribute.Attribute
          import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
          import com.microsoft.ml.spark.{lightgbm=>lgb}
          import com.google.gson.{JsonObject, JsonParser}
          import scala.collection.JavaConverters._

          object LgbDemo extends Serializable {
              
              def printlog(info:String): Unit ={
                  val dt = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new java.util.Date)
                  println("=========="*8+dt)
                  println(info+"\n")
              }
              
              def main(args:Array[String]):Unit= {


              /*================================================================================*/
              //  1. Load the data
              /*================================================================================*/
              printlog("step1: preparing data ...")

              //Load the data (tab-separated file with a header row)
              val spark = SparkSession.builder().getOrCreate()
              val dfdata_raw = spark.read.option("header","true")
                  .option("delimiter","\t")
                  .option("inferSchema","true")
                  .option("nullValue","")
                  .csv("data/breast_cancer.csv")

              dfdata_raw.sample(false,0.1,1).printSchema 

              //Combine the feature columns into a features vector
              val feature_cols = dfdata_raw.columns.filter(!Array("label").contains(_)) 
              val cate_cols = Array("mean_radius","mean_texture")


              val vectorAssembler = new VectorAssembler().
                setInputCols(feature_cols).
                setOutputCol("features")

              val dfdata = vectorAssembler.transform(dfdata_raw).select("features","label")
              val Array(dftrain,dfval)  = dfdata.randomSplit(Array(0.7,0.3), 666)

              //The feature names are stored in the schema metadata, so categorical features can be specified by name
              println(dfdata.schema("features").metadata)
              dfdata.show(10)

              /*================================================================================*/
              //  2. Define the model
              /*================================================================================*/
              printlog("step2: defining model ...")

              val lgbclassifier = new lgb.LightGBMClassifier()
                .setNumIterations(100)
                .setLearningRate(0.1)
                .setNumLeaves(31)
                .setMinSumHessianInLeaf(0.001)
                .setMaxDepth(-1)
                .setBoostFromAverage(false)
                .setFeatureFraction(1.0)
                .setMaxBin(255)
                .setLambdaL1(0.0)
                .setLambdaL2(0.0)
                .setBaggingFraction(1.0)
                .setBaggingFreq(0)
                .setBaggingSeed(1)
                .setBoostingType("gbdt") //other options: rf, dart, goss
                .setCategoricalSlotNames(cate_cols)
                .setObjective("binary") //binary, multiclass
                .setFeaturesCol("features")
                .setLabelCol("label")

              println(lgbclassifier.explainParams) 


              /*================================================================================*/
              //  3. Train the model
              /*================================================================================*/
              printlog("step3: training model ...")

              val lgbmodel = lgbclassifier.fit(dftrain)

              val feature_importances = lgbmodel.getFeatureImportances("gain")
              val arr = feature_cols.zip(feature_importances).sortBy[Double](t=> -t._2)
              val dfimportance = spark.createDataFrame(arr).toDF("feature_name","feature_importance(gain)")

              dfimportance.show(100)


              /*================================================================================*/
              //  4. Evaluate the model
              /*================================================================================*/
              printlog("step4: evaluating model ...")

              val evaluator = new BinaryClassificationEvaluator()
                .setLabelCol("label")
                .setRawPredictionCol("rawPrediction")
                .setMetricName("areaUnderROC")

              val dftrain_result = lgbmodel.transform(dftrain)
              val dfval_result = lgbmodel.transform(dfval)

              val train_auc  = evaluator.evaluate(dftrain_result)
              val val_auc = evaluator.evaluate(dfval_result)
              println(s"train_auc = ${train_auc}")
              println(s"val_auc = ${val_auc}")


              /*================================================================================*/
              //  5. Use the model
              /*================================================================================*/
              printlog("step5: using model ...")

              //Batch prediction
              val dfpredict = lgbmodel.transform(dfval)
              dfpredict.sample(false,0.1,1).show(20)

              //Predict a single sample
              val features = dfval.head().getAs[Vector]("features")
              val single_result = lgbmodel.predict(features)

              println(single_result)


              /*================================================================================*/
              //  6. Save the model
              /*================================================================================*/
              printlog("step6: saving model ...")

              //Save to the cluster (multiple files)
              lgbmodel.write.overwrite().save("lgbmodel.model")
              //Load the model from the cluster
              println("load model ...")
              val lgbmodel_loaded = lgb.LightGBMClassificationModel.load("lgbmodel.model")
              val dfresult = lgbmodel_loaded.transform(dfval)
              dfresult.show() 

              //Save locally as a single file, compatible with the Python API
              //lgbmodel.saveNativeModel("lgb_model",true)
              //Load the local model
              //val lgbmodel_loaded = lgb.LightGBMClassificationModel.loadNativeModelFromFile("lgb_model")
              
              }
              
          }
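The comment in step 1 notes that the feature names live in the schema metadata of the `features` column. A minimal sketch of reading them back with Spark's `AttributeGroup` follows; the object name `SlotNamesSketch`, the helper `slotNames`, and the `slot_<index>` fallback label are hypothetical names introduced for illustration.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.sql.DataFrame

object SlotNamesSketch {
  // Returns the slot names recorded in the metadata of a vector column,
  // or an empty array when the column carries no per-slot attributes.
  def slotNames(df: DataFrame, vectorCol: String = "features"): Array[String] = {
    val group = AttributeGroup.fromStructField(df.schema(vectorCol))
    group.attributes match {
      case Some(attrs) =>
        // Unnamed slots get a placeholder built from their index (assumed convention).
        attrs.map(a => a.name.getOrElse(s"slot_${a.index.getOrElse(-1)}"))
      case None => Array.empty[String]
    }
  }
}
```

Calling `SlotNamesSketch.slotNames(dfdata)` on the assembled DataFrame should list the same names that `println(dfdata.schema("features").metadata)` shows as JSON, which is a quick way to check that the names passed to `setCategoricalSlotNames` actually exist.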

          III. Reference Output

          Running the code above produces the following output.
          Note that println(lgbclassifier.explainParams) prints the meaning and default value of each LightGBM parameter.
          ================================================================================2021-07-17 22:16:29
          step1: preparing data ...

          root
          |-- mean_radius: integer (nullable = true)
          |-- mean_texture: integer (nullable = true)
          |-- mean_perimeter: double (nullable = true)
          |-- mean_area: double (nullable = true)
          |-- mean_smoothness: double (nullable = true)
          |-- mean_compactness: double (nullable = true)
          |-- mean_concavity: double (nullable = true)
          |-- mean_concave_points: double (nullable = true)
          |-- mean_symmetry: double (nullable = true)
          |-- mean_fractal_dimension: double (nullable = true)
          |-- radius_error: double (nullable = true)
          |-- texture_error: double (nullable = true)
          |-- perimeter_error: double (nullable = true)
          |-- area_error: double (nullable = true)
          |-- smoothness_error: double (nullable = true)
          |-- compactness_error: double (nullable = true)
          |-- concavity_error: double (nullable = true)
          |-- concave_points_error: double (nullable = true)
          |-- symmetry_error: double (nullable = true)
          |-- fractal_dimension_error: double (nullable = true)
          |-- worst_radius: double (nullable = true)
          |-- worst_texture: double (nullable = true)
          |-- worst_perimeter: double (nullable = true)
          |-- worst_area: double (nullable = true)
          |-- worst_smoothness: double (nullable = true)
          |-- worst_compactness: double (nullable = true)
          |-- worst_concavity: double (nullable = true)
          |-- worst_concave_points: double (nullable = true)
          |-- worst_symmetry: double (nullable = true)
          |-- worst_fractal_dimension: double (nullable = true)
          |-- label: integer (nullable = true)

          {"ml_attr":{"attrs":{"numeric":[{"idx":0,"name":"mean_radius"},{"idx":1,"name":"mean_texture"},{"idx":2,"name":"mean_perimeter"},{"idx":3,"name":"mean_area"},{"idx":4,"name":"mean_smoothness"},{"idx":5,"name":"mean_compactness"},{"idx":6,"name":"mean_concavity"},{"idx":7,"name":"mean_concave_points"},{"idx":8,"name":"mean_symmetry"},{"idx":9,"name":"mean_fractal_dimension"},{"idx":10,"name":"radius_error"},{"idx":11,"name":"texture_error"},{"idx":12,"name":"perimeter_error"},{"idx":13,"name":"area_error"},{"idx":14,"name":"smoothness_error"},{"idx":15,"name":"compactness_error"},{"idx":16,"name":"concavity_error"},{"idx":17,"name":"concave_points_error"},{"idx":18,"name":"symmetry_error"},{"idx":19,"name":"fractal_dimension_error"},{"idx":20,"name":"worst_radius"},{"idx":21,"name":"worst_texture"},{"idx":22,"name":"worst_perimeter"},{"idx":23,"name":"worst_area"},{"idx":24,"name":"worst_smoothness"},{"idx":25,"name":"worst_compactness"},{"idx":26,"name":"worst_concavity"},{"idx":27,"name":"worst_concave_points"},{"idx":28,"name":"worst_symmetry"},{"idx":29,"name":"worst_fractal_dimension"}]},"num_attrs":30}}
          +--------------------+-----+
          |            features|label|
          +--------------------+-----+
          |[17.0,10.0,122.8,...|    0|
          |[20.0,17.0,132.9,...|    0|
          |[19.0,21.0,130.0,...|    0|
          |[11.0,20.0,77.58,...|    0|
          |[20.0,14.0,135.1,...|    0|
          |[12.0,15.0,82.57,...|    0|
          |[18.0,19.0,119.6,...|    0|
          |[13.0,20.0,90.2,5...|    0|
          |[13.0,21.0,87.5,5...|    0|
          |[12.0,24.0,83.97,...|    0|
          +--------------------+-----+
          only showing top 10 rows

          ================================================================================2021-07-17 22:16:29
          step2: defining model ...

          baggingFraction: Bagging fraction (default: 1.0, current: 1.0)
          baggingFreq: Bagging frequency (default: 0, current: 0)
          baggingSeed: Bagging seed (default: 3, current: 1)
          boostFromAverage: Adjusts initial score to the mean of labels for faster convergence (default: true, current: false)
          boostingType: Default gbdt = traditional Gradient Boosting Decision Tree. Options are: gbdt, gbrt, rf (Random Forest), random_forest, dart (Dropouts meet Multiple Additive Regression Trees), goss (Gradient-based One-Side Sampling). (default: gbdt, current: gbdt)
          categoricalSlotIndexes: List of categorical column indexes, the slot index in the features column (undefined)
          categoricalSlotNames: List of categorical column slot names, the slot name in the features column (current: [Ljava.lang.String;@351fb3fc)
          defaultListenPort: The default listen port on executors, used for testing (default: 12400)
          earlyStoppingRound: Early stopping round (default: 0)
          featureFraction: Feature fraction (default: 1.0, current: 1.0)
          featuresCol: features column name (default: features, current: features)
          initScoreCol: The name of the initial score column, used for continued training (undefined)
          isProvideTrainingMetric: Whether output metric result over training dataset. (default: false)
          isUnbalance: Set to true if training data is unbalanced in binary classification scenario (default: false)
          labelCol: label column name (default: label, current: label)
          lambdaL1: L1 regularization (default: 0.0, current: 0.0)
          lambdaL2: L2 regularization (default: 0.0, current: 0.0)
          learningRate: Learning rate or shrinkage rate (default: 0.1, current: 0.1)
          maxBin: Max bin (default: 255, current: 255)
          maxDepth: Max depth (default: -1, current: -1)
          minSumHessianInLeaf: Minimal sum hessian in one leaf (default: 0.001, current: 0.001)
          modelString: LightGBM model to retrain (default: )
          numBatches: If greater than 0, splits data into separate batches during training (default: 0)
          numIterations: Number of iterations, LightGBM constructs num_class * num_iterations trees (default: 100, current: 100)
          numLeaves: Number of leaves (default: 31, current: 31)
          objective: The Objective. For regression applications, this can be: regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. For classification applications, this can be: binary, multiclass, or multiclassova. (default: binary, current: binary)
          parallelism: Tree learner parallelism, can be set to data_parallel or voting_parallel (default: data_parallel)
          predictionCol: prediction column name (default: prediction)
          probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
          rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
          thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
          timeout: Timeout in seconds (default: 1200.0)
          useBarrierExecutionMode: Use new barrier execution mode in Beta testing, off by default. (default: false)
          validationIndicatorCol: Indicates whether the row is for training or validation (undefined)
          verbosity: Verbosity where lt 0 is Fatal, eq 0 is Error, eq 1 is Info, gt 1 is Debug (default: 1)
          weightCol: The name of the weight column (undefined)
          ================================================================================2021-07-17 22:16:29
          step3: training model ...

          +--------------------+------------------------+
          |        feature_name|feature_importance(gain)|
          +--------------------+------------------------+
          |          worst_area|       974.9349449056517|
          |     worst_perimeter|       885.3691593843923|
          |worst_concave_points|      255.67364284247745|
          | mean_concave_points|      250.21955942230738|
          |       worst_texture|      151.07745621304454|
          |          area_error|       65.75557372416814|
          |    worst_smoothness|       62.29973236144293|
          |     mean_smoothness|      19.902610011957194|
          |        worst_radius|        16.8275272153341|
          |           mean_area|       12.41261211467938|
          |      mean_perimeter|      12.127510878875537|
          |     worst_concavity|      11.414242858900646|
          |   compactness_error|      10.996194651604892|
          |        mean_texture|       9.274276675339683|
          |     concavity_error|       8.009578698471008|
          |      symmetry_error|        7.93458393366217|
          |        radius_error|       7.357747321194173|
          |      worst_symmetry|       5.951699663755868|
          |fractal_dimension...|       4.811246624133022|
          |concave_points_error|        4.73140145466917|
          |   worst_compactness|       4.469820723182832|
          |       texture_error|       4.356178728700959|
          |    mean_compactness|       3.123736411467967|
          |       mean_symmetry|      1.9968633063354835|
          |      mean_concavity|      1.9701941942285224|
          |    smoothness_error|       1.673042485476758|
          |worst_fractal_dim...|      1.3582115541525612|
          |mean_fractal_dime...|      0.6050912755332459|
          |     perimeter_error|      0.3889888676278275|
          |         mean_radius|    5.684356116234315...|
          +--------------------+------------------------+

          ================================================================================2021-07-17 22:16:30
          step4: evaluating model ...

          train_auc = 1.0
          val_auc = 0.9890340267698758
          ================================================================================2021-07-17 22:16:31
          step5: using model ...

          +--------------------+-----+--------------------+--------------------+----------+
          |            features|label|       rawPrediction|         probability|prediction|
          +--------------------+-----+--------------------+--------------------+----------+
          |[9.0,12.0,60.34,2...|    1|[-10.570726382467...|[-9.5707263824679...|       1.0|
          |[10.0,16.0,65.85,...|    1|[-10.120435089856...|[-9.1204350898567...|       1.0|
          |[10.0,21.0,68.51,...|    1|[-8.8020346337692...|[-7.8020346337692...|       1.0|
          |[11.0,14.0,73.53,...|    1|[-10.315758226759...|[-9.3157582267596...|       1.0|
          |[11.0,15.0,73.38,...|    1|[-10.086077130817...|[-9.0860771308174...|       1.0|
          |[11.0,16.0,74.72,...|    1|[-6.9649803118554...|[-5.9649803118554...|       1.0|
          |[11.0,17.0,71.25,...|    1|[-10.694667171248...|[-9.6946671712481...|       1.0|
          |[11.0,17.0,75.27,...|    1|[-9.0156792680894...|[-8.0156792680894...|       1.0|
          |[11.0,18.0,75.17,...|    1|[-5.7513546284621...|[-4.7513546284621...|       1.0|
          |[11.0,18.0,76.38,...|    1|[-4.3134421808792...|[-3.3134421808792...|       1.0|
          |[12.0,15.0,82.57,...|    0|[2.49310942805160...|[3.49310942805160...|       0.0|
          |[12.0,17.0,78.27,...|    1|[-10.516042459712...|[-9.5160424597122...|       1.0|
          |[12.0,18.0,83.19,...|    1|[-9.4899850168431...|[-8.4899850168431...|       1.0|
          |[12.0,22.0,78.75,...|    1|[-8.9917629958319...|[-7.9917629958319...|       1.0|
          |[14.0,15.0,92.68,...|    1|[-7.2724968676775...|[-6.2724968676775...|       1.0|
          |[14.0,15.0,95.77,...|    1|[-5.0143190624015...|[-4.0143190624015...|       1.0|
          |[14.0,16.0,96.22,...|    1|[-5.3849620427583...|[-4.3849620427583...|       1.0|
          |[14.0,19.0,97.83,...|    1|[-3.3292007261919...|[-2.3292007261919...|       1.0|
          |[16.0,14.0,104.3,...|    1|[4.66077729134426...|[5.66077729134426...|       0.0|
          |[19.0,24.0,122.0,...|    0|[10.1503565558166...|[11.1503565558166...|       0.0|
          +--------------------+-----+--------------------+--------------------+----------+

          1.0
          ================================================================================2021-07-17 22:16:31
          step6: saving model ...

          load model ...
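
Step 4 of the log reports `areaUnderROC` for the training and validation sets. As a rough illustration of what that number measures, the plain-Scala sketch below computes AUC as the fraction of (positive, negative) pairs in which the positive sample gets the higher score, counting ties as one half. This is a teaching aid under stated assumptions, not how `BinaryClassificationEvaluator` computes the metric internally; the object name `AucSketch` and the toy scores are made up.

```scala
object AucSketch {
  // AUC as the probability that a randomly chosen positive (label 1) scores
  // higher than a randomly chosen negative (label 0); ties count as half.
  // Quadratic in the number of samples, so only suitable for small inputs.
  def auc(scored: Seq[(Double, Int)]): Double = {
    val pos = scored.collect { case (s, 1) => s }
    val neg = scored.collect { case (s, 0) => s }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one positive and one negative")
    val wins = for (p <- pos; n <- neg) yield
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size.toLong * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // Toy (score, label) pairs, not taken from the breast cancer data.
    val toy = Seq((0.9, 1), (0.8, 1), (0.4, 0), (0.35, 1), (0.1, 0))
    println(f"auc = ${auc(toy)}%.4f")
  }
}
```

Read this way, a training AUC of 1.0 means every positive training sample is ranked above every negative one, while the validation value just below 1.0 means a small share of validation pairs are ordered the wrong way round.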

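          That's a wrap.
          The code and dataset for this Spark-Scala LightGBM binary classification example, along with example code and datasets for training multiclass and regression models, are available by replying with the keyword spark+lightgbm to the WeChat official account 算法美食屋.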