- 使用 Spark MLlib 线性回归做车流量预测,训练后打印出的权重及其他系数全为 NaN
数据格式:
520221|119|0009|223|292|000541875150|2018|04|18|11|3|137
520626|120|0038|223|140|203030001000|2018|04|18|11|3|119
520621|120|0024|223|005|000530002050|2018|04|18|11|3|91
每行各字段以“|”分隔,最后一项为标签(车流量),其余各项为特征。
2.代码如下
package com.spark;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.feature.StandardScaler;
import org.apache.spark.mllib.feature.StandardScalerModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;
public class CarPassRegression {
public static void main(String[] args){
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR);//屏蔽日志
//入口
SparkConf conf= new SparkConf();
conf.setAppName("pass_regression").setMaster("local[*]")
.set("spark.sql.warehouse.dir","file:///");
JavaSparkContext sc = new JavaSparkContext(conf);
String trainDataPath ="E://test_data//target_carPass//traindata//*";
//
JavaRDD<String> rdd= sc.textFile(trainDataPath);
JavaRDD<LabeledPoint> traindata=rdd.map(new Function<String, LabeledPoint>() {
@Override
public LabeledPoint call(String s) throws Exception {
String [] part = s.split("\\|");
//获取label
double lable =Double.parseDouble(part[part.length-1]);
double [] features = new double[part.length-1];
for(int i=0;i<features.length;i++){
features[i] =Double.parseDouble(part[i]);
}
return new LabeledPoint(lable, Vectors.dense(features));
}
});
traindata.cache();
/*
训练模型
*/
int numIterations = 10000; //迭代次数
double stepSize = 0.000001;//步长
final LinearRegressionModel model= LinearRegressionWithSGD.
train(JavaRDD.toRDD(traindata),numIterations,stepSize);
System.out.println(model.weights()); //打印权重
//预测
JavaRDD<Tuple2<Double, Double>> valuesAndPreds = traindata.map(
new Function<LabeledPoint, Tuple2<Double, Double>>(){
public Tuple2<Double, Double> call(LabeledPoint point){
double prediction = model.predict(point.features());
return new Tuple2<Double, Double>(prediction, point.label());
}
}
);
//计算误差
double MSE = new JavaDoubleRDD(valuesAndPreds.map(
new Function<Tuple2<Double, Double>, Object>(){
public Object call(Tuple2<Double, Double> pair){
return Math.pow(pair._1() - pair._2(), 2.0);
}
}
).rdd()).mean();
System.out.println("training MeanSquared Error = " + MSE);
//模型评测
JavaRDD<Tuple2<Object, Object>> valuesAndPreds2= traindata.map(new Function<LabeledPoint, Tuple2<Object, Object>>(){
public Tuple2<Object, Object> call(LabeledPoint point)
throws Exception {
double prediction = model.predict(point.features());
return new Tuple2<Object, Object>(prediction, point.label());
}
});
RegressionMetrics metrics = new RegressionMetrics(JavaRDD.toRDD(valuesAndPreds2));
System.out.println("R2(决定系数)= "+metrics.r2());
System.out.println("MSE(均方差、方差) = "+metrics.meanSquaredError());
System.out.println("RMSE(均方根、标准差) "+metrics.rootMeanSquaredError());
System.out.println("MAE(平均绝对差值)= "+metrics.meanAbsoluteError());
// 保存加载模型
model.save(sc.sc(), "target/tmp/carPassLinearRegressionWithSGDModel");
LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
"target/tmp/carPassLinearRegressionWithSGDModel");
}
}
3.结果: