<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Apache SparkMLlib構建機器學習分類模型

          共 27447字,需瀏覽 55分鐘

           ·

          2024-04-11 14:41

          一、引言

          1.1 Spark MLlib簡介

          Apache Spark MLlib(Machine Learning library)是一個開源機器學習框架,建立在Apache Spark之上,支持分布式計算和大規(guī)模數(shù)據(jù)處理。它提供了許多經(jīng)典機器學習算法和工具,如分類、回歸、聚類、協(xié)同過濾、特征提取和數(shù)據(jù)預處理等。

          Spark MLlib使用基于DataFrame的API,提供了一個易于使用的高級API,使得用戶能夠快速構建、訓練和調(diào)整機器學習模型,而無需擔心底層分布式計算的復雜性。它還支持分布式模型選擇和調(diào)整,以及與其他Apache Spark組件的集成,如Spark SQL、Spark Streaming和GraphX。

          Spark MLlib還提供了Python、Java和Scala等多種編程語言的API,使得不同的開發(fā)人員可以使用他們最喜歡的編程語言來開發(fā)機器學習應用程序。

          總之,Spark MLlib是一個非常強大和靈活的機器學習框架,適用于處理大規(guī)模數(shù)據(jù)和需要分布式計算的場景。

          1.2 為什么選擇使用Spark MLlib

          1. 處理大規(guī)模數(shù)據(jù):Spark MLlib支持分布式計算和大規(guī)模數(shù)據(jù)處理,使得處理大規(guī)模數(shù)據(jù)集變得容易。1. 豐富的算法庫:Spark MLlib包含了許多經(jīng)典的機器學習算法和工具,如分類、回歸、聚類、協(xié)同過濾、特征提取和數(shù)據(jù)預處理等,覆蓋了大部分機器學習應用場景。1. 高性能:Spark MLlib基于Apache Spark,使用內(nèi)存計算和RDD(彈性分布式數(shù)據(jù)集)等優(yōu)化技術,可以在處理大規(guī)模數(shù)據(jù)時提供高性能和可擴展性。1. 易于使用:Spark MLlib提供了一個易于使用的高級API,使得用戶可以快速構建、訓練和調(diào)整機器學習模型,而無需擔心底層分布式計算的復雜性。1. 多語言支持:Spark MLlib支持多種編程語言的API,包括Python、Java和Scala等,使得不同的開發(fā)人員可以使用他們最喜歡的編程語言來開發(fā)機器學習應用程序。

          二、Spark MLlib基礎

          2.1 RDD和DataFrame的比較

          1. 數(shù)據(jù)類型:基礎RDD可以包含任意類型的數(shù)據(jù),包括對象、原始類型、數(shù)組和集合等;DataFrame則是一種表格化的數(shù)據(jù)結構,其數(shù)據(jù)類型必須是統(tǒng)一的,且可以使用SQL-like的語法進行查詢。1. 內(nèi)存計算:DataFrame利用內(nèi)存計算技術,相比基礎RDD更加高效。1. 可讀性:DataFrame比基礎RDD更加易于閱讀和理解,可以使用SQL-like的語法進行查詢,更加直觀。1. 類型安全:DataFrame是類型安全的,可以在編譯期間捕獲類型錯誤,避免運行時錯誤;而基礎RDD則是類型不安全的,需要在運行時進行類型檢查。1. 執(zhí)行計劃:基礎RDD提供了更加靈活的執(zhí)行計劃,用戶可以控制計算的方式和順序,但這也增加了開發(fā)復雜度;而DataFrame則有一個自動優(yōu)化的執(zhí)行計劃,可以自動優(yōu)化查詢性能。 總之,基礎RDD更加靈活和可控,但需要開發(fā)人員自己掌握計算的方式和順序;而DataFrame則更加易于使用和高效,適合快速開發(fā)和迭代。選擇使用哪種數(shù)據(jù)結構,取決于具體的場景和需求。

          2.2 數(shù)據(jù)準備和預處理

          在使用Spark MLlib進行機器學習之前,需要對原始數(shù)據(jù)進行預處理和準備。以下是一些常見的數(shù)據(jù)準備和預處理步驟:

          1. 數(shù)據(jù)清洗:刪除缺失值、處理異常值和重復值等。1. 特征選擇:選擇對模型有用的特征,去除冗余和無關的特征。1. 特征縮放:對特征進行縮放,以便它們具有相似的范圍和重要性。1. 特征變換:將原始特征轉(zhuǎn)換為更有意義的特征,如使用對數(shù)、指數(shù)、平方根等函數(shù)進行變換。1. 特征歸一化:將特征值歸一化為標準正態(tài)分布,使得模型更容易學習。1. 數(shù)據(jù)轉(zhuǎn)換:將數(shù)據(jù)轉(zhuǎn)換為適合模型訓練的格式,如將分類變量轉(zhuǎn)換為二進制變量、將文本轉(zhuǎn)換為向量等。 在Spark MLlib中,可以使用各種預處理和數(shù)據(jù)準備工具,如:
          2. Imputer:用于填充缺失值。1. StandardScaler:用于特征縮放和歸一化。1. VectorAssembler:用于將多個特征列組合成一個向量列。1. OneHotEncoder:用于將分類變量轉(zhuǎn)換為二進制變量。1. StringIndexer和IndexToString:用于將字符串類型的變量轉(zhuǎn)換為數(shù)字類型的變量。1. Tokenizer和StopWordsRemover:用于將文本轉(zhuǎn)換為向量。 總之,在使用Spark MLlib進行機器學習之前,需要對原始數(shù)據(jù)進行預處理和準備。Spark MLlib提供了許多工具和功能,可以幫助我們輕松地完成這些任務。

          2.3 特征提取和轉(zhuǎn)換

          在Spark MLlib中,有許多常用的特征提取和轉(zhuǎn)換工具,包括:

          1. Tokenizer:用于將文本轉(zhuǎn)換為單詞或詞條。1. StopWordsRemover:用于去除文本中的停用詞,如“the”、“and”等。1. CountVectorizer:用于將文本轉(zhuǎn)換為詞頻向量。1. HashingTF:用于將文本轉(zhuǎn)換為哈希向量,可以減少維度并提高計算效率。1. IDF:用于計算逆文檔頻率,可以減少常見詞語的權重,提高稀有詞語的權重。1. Word2Vec:用于將文本轉(zhuǎn)換為向量,可以捕捉詞語之間的語義關系。1. PCA:用于將高維特征空間降維,可以提高計算效率并避免過擬合。1. StringIndexer:用于將分類變量轉(zhuǎn)換為數(shù)字類型的變量。1. OneHotEncoder:用于將數(shù)字類型的變量轉(zhuǎn)換為二進制變量。 以上這些工具都可以用于特征提取和轉(zhuǎn)換,幫助我們將原始數(shù)據(jù)轉(zhuǎn)換為模型可以處理的格式。我們可以根據(jù)具體的任務和數(shù)據(jù)類型選擇適當?shù)墓ぞ撸垣@得更好的結果。值得注意的是,這些工具的使用通常需要進行適當?shù)膮?shù)設置和調(diào)整,以達到最佳的效果。

          三、監(jiān)督學習

          3.1 分類問題

          3.1.1 邏輯回歸

          邏輯回歸是一種二元分類模型,它的目標是根據(jù)已知數(shù)據(jù)對一個事物進行分類。邏輯回歸的輸出是一個概率值,代表該事物屬于某個類別的概率。如果概率值大于閾值,則將其分類為正類,否則分類為負類。

          在 Spark MLlib 中,可以使用 LogisticRegression 類來實現(xiàn)邏輯回歸。下面是一個 Java 版本的示例代碼:

          pom引用:

                
                <dependencies>
              <!-- Spark core dependencies -->
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_2.12</artifactId>
                <version>3.2.0</version>
              </dependency>
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_2.12</artifactId>
                <version>3.2.0</version>
              </dependency>
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-mllib_2.12</artifactId>
                <version>3.2.0</version>
              </dependency>

              <!-- Spark testing dependencies (optional) -->
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-streaming_2.12</artifactId>
                <version>3.2.0</version>
                <scope>test</scope>
              </dependency>
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-streaming-kafka-0-10_2.12</artifactId>
                <version>3.2.0</version>
                <scope>test</scope>
              </dependency>
              <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql-kafka-0-10_2.12</artifactId>
                <version>3.2.0</version>
                <scope>test</scope>
              </dependency>
            </dependencies>
                
                import org.apache.spark.ml.classification.LogisticRegression;
          import org.apache.spark.ml.classification.LogisticRegressionModel;
          import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
          import org.apache.spark.ml.feature.VectorAssembler;
          import org.apache.spark.ml.linalg.Vector;
          import org.apache.spark.sql.Dataset;
          import org.apache.spark.sql.Row;
          import org.apache.spark.sql.SparkSession;

          public class LogisticRegressionDemo {

              public static void main(String[] args) {
                  SparkSession spark = SparkSession
                          .builder()
                          .appName("LogisticRegressionDemo")
                          .master("local[*]")
                          .getOrCreate();

                  // 加載數(shù)據(jù)
                  Dataset<Row> data = spark.read().format("libsvm").load("data/sample_libsvm_data.txt");

                  // 將特征向量轉(zhuǎn)換成一列
                  VectorAssembler assembler = new VectorAssembler()
                          .setInputCols(new String[]{"features"})
                          .setOutputCol("feature");

                  Dataset<Row> newData = assembler.transform(data).select("label""feature");

                  // 將數(shù)據(jù)集分為訓練集和測試集
                  Dataset<Row>[] splits = newData.randomSplit(new double[]{0.70.3});
                  Dataset<Row> trainData = splits[0];
                  Dataset<Row> testData = splits[1];

                  // 創(chuàng)建邏輯回歸模型
                  LogisticRegression lr = new LogisticRegression();

                  // 訓練模型
                  LogisticRegressionModel lrModel = lr.fit(trainData);

                  // 在測試集上進行預測
                  Dataset<Row> predictions = lrModel.transform(testData);

                  // 計算模型評估指標
                  BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
                  double auc = evaluator.evaluate(predictions);

                  System.out.println("Area under ROC curve = " + auc);

                  spark.stop();
              }
          }

          這個示例代碼首先加載了一個 libsvm 格式的數(shù)據(jù)集,然后將特征向量轉(zhuǎn)換成一列,將數(shù)據(jù)集分為訓練集和測試集,創(chuàng)建邏輯回歸模型并訓練模型,最后在測試集上進行預測并計算模型評估指標。在這個例子中,我們使用了 BinaryClassificationEvaluator 來計算模型的 AUC 指標,它是評估二元分類器性能的一種常用指標。

          需要注意的是,以上代碼僅供參考,實際情況可能需要根據(jù)數(shù)據(jù)集的特點和任務的要求進行相應的修改。

          3.1.2 決策樹

          Spark MLlib 分類決策樹是一種基于樹結構的分類算法,通過一系列特征對數(shù)據(jù)進行劃分和分類。該算法在 Spark MLlib 中的實現(xiàn)采用 CART(Classification And Regression Tree)算法,使用信息熵或 Gini 系數(shù)等指標進行特征選擇和劃分。Spark MLlib 分類決策樹可用于二分類、多分類和概率預測問題。

                
                import org.apache.spark.ml.Pipeline;
          import org.apache.spark.ml.PipelineModel;
          import org.apache.spark.ml.PipelineStage;
          import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
          import org.apache.spark.ml.classification.DecisionTreeClassifier;
          import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
          import org.apache.spark.ml.feature.IndexToString;
          import org.apache.spark.ml.feature.StringIndexer;
          import org.apache.spark.ml.feature.StringIndexerModel;
          import org.apache.spark.ml.feature.VectorAssembler;
          import org.apache.spark.sql.Dataset;
          import org.apache.spark.sql.Row;
          import org.apache.spark.sql.SparkSession;

          public class DecisionTreeClassificationExample {
            public static void main(String[] args) {
              SparkSession spark = SparkSession.builder()
                .appName("DecisionTreeClassificationExample")
                .master("local[*]")
                .getOrCreate();

              // 讀取數(shù)據(jù)集
              Dataset<Row> data = spark.read().format("csv")
                .option("header""true")
                .option("inferSchema""true")
                .load("path/to/data.csv");

              // 將標簽列轉(zhuǎn)換為數(shù)值類型
              StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data);
              data = labelIndexer.transform(data);

              // 將特征列轉(zhuǎn)換為特征向量
              VectorAssembler featureAssembler = new VectorAssembler()
                .setInputCols(new String[]{"feature1""feature2""feature3"})
                .setOutputCol("features");
              data = featureAssembler.transform(data);

              // 將數(shù)據(jù)集分為訓練集和測試集
              Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3}, 12345);
              Dataset<Row> trainData = splits[0];
              Dataset<Row> testData = splits[1];

              // 創(chuàng)建決策樹分類器
              DecisionTreeClassifier dt = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("features");

              // 將標簽數(shù)值轉(zhuǎn)換回原始標簽
              IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

              // 創(chuàng)建管道并擬合模型
              Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{labelIndexer, featureAssembler, dt, labelConverter});
              PipelineModel model = pipeline.fit(trainData);

              // 在測試集上進行預測和評估
              Dataset<Row> predictions = model.transform(testData);
              MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
              double accuracy = evaluator.evaluate(predictions);
              System.out.println("Test Error = " + (1.0 - accuracy));
              // 輸出決策樹結構
              DecisionTreeClassificationModel treeModel =
              (DecisionTreeClassificationModel) (model.stages()[2]);
              System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());

              spark.stop();
              }
          }

          以上示例中,我們首先使用 SparkSession 讀取 CSV 格式的數(shù)據(jù)集。然后,使用 StringIndexer 將標簽列轉(zhuǎn)換為數(shù)值類型,并使用 VectorAssembler 將特征列轉(zhuǎn)換為特征向量。接著,將數(shù)據(jù)集分為訓練集和測試集,并創(chuàng)建 DecisionTreeClassifier 決策樹分類器。最后,將管道中的各個階段組合在一起,擬合模型并在測試集上進行預測和評估。

          3.1.3 隨機森林

          隨機森林是一種集成學習算法,它將多棵決策樹組合起來,通過投票或平均來決定分類結果。該算法在 Spark MLlib 中的實現(xiàn)使用基于 CART(Classification And Regression Tree)算法的決策樹作為基分類器,可以用于二分類、多分類和概率預測問題。

          以下是一個基于 Java 的 Spark MLlib 分類隨機森林示例:

                
                import org.apache.spark.ml.Pipeline;
          import org.apache.spark.ml.PipelineModel;
          import org.apache.spark.ml.PipelineStage;
          import org.apache.spark.ml.classification.RandomForestClassificationModel;
          import org.apache.spark.ml.classification.RandomForestClassifier;
          import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
          import org.apache.spark.ml.feature.IndexToString;
          import org.apache.spark.ml.feature.StringIndexer;
          import org.apache.spark.ml.feature.StringIndexerModel;
          import org.apache.spark.ml.feature.VectorAssembler;
          import org.apache.spark.sql.Dataset;
          import org.apache.spark.sql.Row;
          import org.apache.spark.sql.SparkSession;

          public class RandomForestClassificationExample {
            public static void main(String[] args) {
              SparkSession spark = SparkSession.builder()
                .appName("RandomForestClassificationExample")
                .master("local[*]")
                .getOrCreate();

              // 讀取數(shù)據(jù)集
              Dataset<Row> data = spark.read().format("csv")
                .option("header""true")
                .option("inferSchema""true")
                .load("path/to/data.csv");

              // 將標簽列轉(zhuǎn)換為數(shù)值類型
              StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data);
              data = labelIndexer.transform(data);

              // 將特征列轉(zhuǎn)換為特征向量
              VectorAssembler featureAssembler = new VectorAssembler()
                .setInputCols(new String[]{"feature1""feature2""feature3"})
                .setOutputCol("features");
              data = featureAssembler.transform(data);

              // 將數(shù)據(jù)集分為訓練集和測試集
              Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3}, 12345);
              Dataset<Row> trainData = splits[0];
              Dataset<Row> testData = splits[1];

              // 創(chuàng)建隨機森林分類器
              RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("features")
                .setNumTrees(10);

              // 將標簽數(shù)值轉(zhuǎn)換回原始標簽
              IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

              // 創(chuàng)建管道并擬合模型
              Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{labelIndexer, featureAssembler, rf, labelConverter});
              PipelineModel model = pipeline.fit(trainData);

              // 在測試集上進行預測和評估
              Dataset<Row> predictions = model.transform(testData);
              MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
              double accuracy = evaluator.evaluate(predictions);
              System.out.println("Test Error = " + (1.0 - accuracy));
              // 獲取訓練好的隨機森林模型并打印樹的重要性
              RandomForestClassificationModel rfModel = (RandomForestClassificationModel) model.stages()[2];
              System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());

              spark.stop();
            }
          }

          該示例代碼首先使用 SparkSession 讀取 CSV 格式的數(shù)據(jù)集。接下來,使用 StringIndexer 將標簽列轉(zhuǎn)換為數(shù)值類型,并使用 VectorAssembler 將特征列轉(zhuǎn)換為特征向量。然后,將數(shù)據(jù)集分為訓練集和測試集。創(chuàng)建 RandomForestClassifier,并將其作為管道的一部分進行擬合。擬合后,使用 MulticlassClassificationEvaluator 對測試集進行預測和評估。最后,獲取訓練好的隨機森林模型并打印樹的重要性。

          請注意,上面的示例中,數(shù)據(jù)集的路徑應該被替換為實際數(shù)據(jù)集的路徑,特征列的名稱也應該被替換為實際特征列的名稱。

          3.1.4 梯度提升樹

          Spark MLlib 提供了一個強大的算法——分類梯度提升樹(Gradient-Boosted Trees, GBT),它可以用于二元分類和多類分類。GBT 是一種集成學習算法,它通過在先前樹的殘差上逐步擬合一系列決策樹來提高模型的準確性。

          在 Spark MLlib 中,可以使用 GBTClassifier 類來構建分類 GBT 模型。GBT 分類器使用一系列決策樹來逐步提高模型的準確性,每個決策樹都是在之前決策樹的殘差上訓練得到的。通過這種方式,GBT 可以在更少的迭代次數(shù)下得到比隨機森林更準確的模型。

          與其他 Spark MLlib 分類器類似,GBT 分類器也使用管道(Pipeline)來處理數(shù)據(jù)。管道通常包括以下幾個步驟:

          1. 數(shù)據(jù)預處理:包括數(shù)據(jù)清洗、特征提取、特征轉(zhuǎn)換等操作。1. 特征工程:根據(jù)特定的特征工程需求,對特征進行過濾、選擇、轉(zhuǎn)換等操作。1. 模型訓練:使用訓練集對模型進行擬合。1. 模型評估:使用測試集對模型進行評估。1. 模型應用:將模型應用到新的數(shù)據(jù)集上進行預測。 在使用 GBT 分類器時,你需要指定以下參數(shù):
          • featuresCol:特征列的名稱。- labelCol:標簽列的名稱。- maxIter:訓練迭代次數(shù)。- maxDepth:決策樹的最大深度。- minInstancesPerNode:每個節(jié)點上的最小實例數(shù)。- stepSize:每個迭代步驟的步長。- subsamplingRate:用于訓練每棵樹的數(shù)據(jù)子樣本的比例。
                
                import org.apache.spark.ml.Pipeline;
          import org.apache.spark.ml.PipelineModel;
          import org.apache.spark.ml.PipelineStage;
          import org.apache.spark.ml.classification.GBTClassificationModel;
          import org.apache.spark.ml.classification.GBTClassifier;
          import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
          import org.apache.spark.ml.feature.IndexToString;
          import org.apache.spark.ml.feature.StringIndexer;
          import org.apache.spark.ml.feature.StringIndexerModel;
          import org.apache.spark.ml.feature.VectorIndexer;
          import org.apache.spark.ml.feature.VectorIndexerModel;
          import org.apache.spark.sql.Dataset;
          import org.apache.spark.sql.Row;
          import org.apache.spark.sql.SparkSession;


          public class GBTExample {
              public static void main(String[] args) {
                  // 創(chuàng)建一個 SparkSession
                  SparkSession spark = SparkSession
                          .builder()
                          .appName("GBTExample")
                          .getOrCreate();

                  // 讀取數(shù)據(jù)集
                  Dataset<Row> data = spark.read()
                          .format("libsvm")
                          .load("data/mllib/sample_libsvm_data.txt");

                  // 對標簽列進行索引
                  StringIndexerModel labelIndexer = new StringIndexer()
                          .setInputCol("label")
                          .setOutputCol("indexedLabel")
                          .fit(data);

                  // 對特征列進行索引
                  VectorIndexerModel featureIndexer = new VectorIndexer()
                          .setInputCol("features")
                          .setOutputCol("indexedFeatures")
                          .setMaxCategories(4// 特征具有少于 4 個不同的值
                          .fit(data);

                  // 將數(shù)據(jù)集拆分為訓練集和測試集
                  Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3});
                  Dataset<Row> trainingData = splits[0];
                  Dataset<Row> testData = splits[1];

                  // 定義 GBT 分類器
                  GBTClassifier gbt = new GBTClassifier()
                          .setLabelCol("indexedLabel")
                          .setFeaturesCol("indexedFeatures")
                          .setMaxIter(10)
                          .setFeatureSubsetStrategy("auto");

                  // 將索引的標簽轉(zhuǎn)換回原始標簽
                  IndexToString labelConverter = new IndexToString()
                          .setInputCol("prediction")
                          .setOutputCol("predictedLabel")
                          .setLabels(labelIndexer.labels());

                  // 創(chuàng)建管道
                  Pipeline pipeline = new Pipeline()
                          .setStages(new PipelineStage[]{
                                  labelIndexer,
                                  featureIndexer,
                                  gbt,
                                  labelConverter
                          });

                  // 訓練模型
                  PipelineModel model = pipeline.fit(trainingData);

                  // 進行預測
                  Dataset<Row> predictions = model.transform(testData);

                  // 選擇樣例行顯示
                  predictions.select("predictedLabel""label""features").show(5);

                  // 評估模型
                  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                          .setLabelCol("indexedLabel")
                          .setPredictionCol("prediction")
                          .setMetricName("accuracy");
                  double accuracy = evaluator.evaluate(predictions);
                  System.out.println("Test Error = " + (1.0 - accuracy));

                  // 獲取訓練得到的 GBT 模型
                  GBTClassificationModel gbtModel = (GBTClassificationModel) (model.stages()[2]);
                  System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());

                  spark.stop();
              }
          }

          該示例使用了 Spark MLlib 內(nèi)置的 sample_libsvm_data.txt 數(shù)據(jù)集。首先,將數(shù)據(jù)集加載到 DataFrame 中。接下來,對標簽列和特征列進行索引。然后,將數(shù)據(jù)集拆分為訓練集和測試集。接下來,創(chuàng)建 GBT 分類器,并使用管道將標簽轉(zhuǎn)換回原始標簽。最后,使用訓練數(shù)據(jù)擬合管道并進行預測。最終評估模型并輸出模型學習到的 GBT 分類模型的調(diào)試字符串。該字符串顯示了樹的結構和分裂標準,以及在每個節(jié)點處對特征的使用情況和分裂點。


          瀏覽 43
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国产福利91 | 青青草超碰在线 | 六区,七区视频在线播放 | 日韩成人拍拍视频在线 | 欧美在线看片 |