调用 Spark MLlib 线性回归后,打印出的权重及其他评估指标全为 NaN

新手上路,请多包涵
  1. 使用 Spark MLlib 的线性回归(LinearRegressionWithSGD)做车流量预测,训练后打印出的权重及其他评估指标(R2、MSE、RMSE、MAE)全为 NaN。

数据格式:
520221|119|0009|223|292|000541875150|2018|04|18|11|3|137
520626|120|0038|223|140|203030001000|2018|04|18|11|3|119
520621|120|0024|223|005|000530002050|2018|04|18|11|3|91

每行各字段以 | 分隔,最后一项为标签(车流量),其余各项作为特征。

2. 代码如下:
package com.spark;

import java.io.File;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.feature.StandardScaler;
import org.apache.spark.mllib.feature.StandardScalerModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

public class CarPassRegression {

public static void main(String[] args){

    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR);//屏蔽日志
    //入口
    SparkConf conf= new SparkConf();
    conf.setAppName("pass_regression").setMaster("local[*]")
            .set("spark.sql.warehouse.dir","file:///");

    JavaSparkContext sc = new JavaSparkContext(conf);

    String trainDataPath ="E://test_data//target_carPass//traindata//*";
    //
    JavaRDD<String> rdd= sc.textFile(trainDataPath);

    JavaRDD<LabeledPoint> traindata=rdd.map(new Function<String, LabeledPoint>() {
        @Override
        public LabeledPoint call(String s) throws Exception {

            String [] part = s.split("\\|");
            //获取label
            double lable =Double.parseDouble(part[part.length-1]);

            double [] features = new double[part.length-1];
            for(int i=0;i<features.length;i++){
                features[i] =Double.parseDouble(part[i]);
            }

            return new LabeledPoint(lable, Vectors.dense(features));
        }
    });

    traindata.cache();

    /*
    训练模型
     */
    int numIterations = 10000;  //迭代次数
    double stepSize = 0.000001;//步长

    final LinearRegressionModel model= LinearRegressionWithSGD.
            train(JavaRDD.toRDD(traindata),numIterations,stepSize);

    System.out.println(model.weights()); //打印权重

    //预测
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = traindata.map(
            new Function<LabeledPoint, Tuple2<Double, Double>>(){
                public Tuple2<Double, Double> call(LabeledPoint point){
                    double prediction = model.predict(point.features());
                    return new Tuple2<Double, Double>(prediction, point.label());
                }
            }
    );

    //计算误差
    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
            new Function<Tuple2<Double, Double>, Object>(){
                public Object call(Tuple2<Double, Double> pair){
                    return Math.pow(pair._1() - pair._2(), 2.0);
                }
            }
    ).rdd()).mean();
    System.out.println("training MeanSquared Error = " + MSE);

    //模型评测
    JavaRDD<Tuple2<Object, Object>>  valuesAndPreds2= traindata.map(new Function<LabeledPoint, Tuple2<Object, Object>>(){
        public Tuple2<Object, Object> call(LabeledPoint point)
                throws Exception {
            double prediction = model.predict(point.features());
            return new Tuple2<Object, Object>(prediction, point.label());
        }

    });
    RegressionMetrics metrics = new RegressionMetrics(JavaRDD.toRDD(valuesAndPreds2));
    System.out.println("R2(决定系数)= "+metrics.r2());
    System.out.println("MSE(均方差、方差) = "+metrics.meanSquaredError());
    System.out.println("RMSE(均方根、标准差) "+metrics.rootMeanSquaredError());
    System.out.println("MAE(平均绝对差值)= "+metrics.meanAbsoluteError());

    // 保存加载模型
    model.save(sc.sc(), "target/tmp/carPassLinearRegressionWithSGDModel");

    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
            "target/tmp/carPassLinearRegressionWithSGDModel");




}

}
3. 结果:

(结果截图 clipboard.png 未能保留:打印出的权重及其他评估指标全部为 NaN。)

阅读 3.9k
撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进