点击数:3

1.前言

最近公司项目需要在Java中计算一元线性回归,网上找了很久,发现一篇不错的文章。但是用原文的方法计算出来的结果和Excel的计算结果总是有一点误差,所以我在原文代码的基础上做了一点修改,修改后的结果和Excel计算出来的基本没有误差了。下面是原文地址:

原文链接

2.内容

2.1 定义实体类

定义一个DataPoint类,对X和Y坐标点进行封装:

/**
 * A single (x, y) coordinate pair used as input to the one-variable linear
 * regression code (which can be used, for example, to forecast statistical
 * indicators).
 * <p>
 * Both coordinates are exposed as public, mutable fields.
 */
public class DataPoint {

    /** Horizontal coordinate of the point. */
    public double x;

    /** Vertical coordinate of the point. */
    public double y;

    /**
     * Creates a point from its two coordinates.
     *
     * @param x
     *            the x value
     * @param y
     *            the y value
     */
    public DataPoint(double x, double y) {
        this.y = y;
        this.x = x;
    }
}

2.2 回归线实现类

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

/**
 * <p>
 * <b>Linear Regression</b> <br>
 * Builds the least-squares regression line for a growing set of data points.
 * <p>
 * Requires DataPoint.java.
 *
 * <p>
 * The line is derived from the running sums SumX, SumY, SumXX, SumXY and SumYY
 * (SumXX = Sum(x^2)), which are updated every time a point is added.
 * <p>
 * <b>Regression line: f(x) = a1*x + a0</b>
 * <p>
 * <b>Slope and intercept</b> (n = number of data points): <br>
 * a1 = (n*SumXY - SumX*SumY) / (n*SumXX - SumX^2) <br>
 * a0 = (SumY - a1*SumX) / n <br>
 * (equivalently a0 = averageY - a1*averageX)
 *
 * <p>
 * <b>Drawing the line:</b> two points determine a straight line. The first is
 * (0, a0); for the second, substitute any x1 into the equation to obtain y1.
 * To make the line cross the whole chart, x1 can be the largest x value Xmax,
 * giving (0, a0) and (Xmax, Y). If Y = a1*Xmax + a0 exceeds the largest y
 * value Ymax, use y = Ymax instead and solve for x, giving (0, a0) and
 * (X, Ymax).
 *
 * <p>
 * <b>Goodness of fit (the R^2 reported by Excel):</b> <br>
 * R^2 = 1 - SSE/SST, where SSE = sum((Yi - f(Xi))^2) and
 * SST = SumYY - (SumY*SumY)/n.
 *
 * <p>
 * Coefficients are kept at full precision internally; {@link #getA0()},
 * {@link #getA1()} and {@link #getR()} round their results to 4 decimals.
 */
public class RegressionLine // implements Evaluatable
{
    /** Number of decimals in the values returned by getA0, getA1 and getR. */
    private static final int RESULT_SCALE = 4;

    /** sum of x */
    private double sumX;

    /** sum of y */
    private double sumY;

    /** sum of x*x */
    private double sumXX;

    /** sum of x*y */
    private double sumXY;

    /** sum of y*y */
    private double sumYY;

    /** x coordinates of every added point, in insertion order */
    private final List<Double> xs = new ArrayList<Double>();

    /** y coordinates of every added point, in insertion order */
    private final List<Double> ys = new ArrayList<Double>();

    /** Integer bounds of the data; all start at 0, so they always include the origin. */
    private int xMin, xMax, yMin, yMax;

    /** intercept a0 (full precision) */
    private double a0;

    /** slope a1 (full precision) */
    private double a1;

    /** number of data points */
    private int pn;

    /** true if a0 and a1 reflect the current sums */
    private boolean coefsValid;

    /**
     * Creates an empty regression line with no data points.
     */
    public RegressionLine() {
        // All state starts at its zero/empty default.
    }

    /**
     * Creates a regression line and adds every point of the given array.
     *
     * @param data
     *            the array of data points
     */
    public RegressionLine(DataPoint data[]) {
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     *
     * @return the count
     */
    public int getDataPointCount() {
        return pn;
    }

    /**
     * Return the intercept a0, rounded half-up to 4 decimals.
     *
     * @return the value of a0, or NaN when the line is undefined (fewer than
     *         two points, or all points share the same x)
     */
    public double getA0() {
        validateCoefficients();
        return round(a0, RESULT_SCALE);
    }

    /**
     * Return the slope a1, rounded half-up to 4 decimals.
     *
     * @return the value of a1, or NaN when the line is undefined (fewer than
     *         two points, or all points share the same x)
     */
    public double getA1() {
        validateCoefficients();
        return round(a1, RESULT_SCALE);
    }

    /**
     * Return the sum of the x values.
     *
     * @return the sum
     */
    public double getSumX() {
        return sumX;
    }

    /**
     * Return the sum of the y values.
     *
     * @return the sum
     */
    public double getSumY() {
        return sumY;
    }

    /**
     * Return the sum of the x*x values.
     *
     * @return the sum
     */
    public double getSumXX() {
        return sumXX;
    }

    /**
     * Return the sum of the x*y values.
     *
     * @return the sum
     */
    public double getSumXY() {
        return sumXY;
    }

    /**
     * Return the sum of the y*y values.
     *
     * @return the sum
     */
    public double getSumYY() {
        return sumYY;
    }

    /**
     * Return the smallest x seen so far, truncated to an int and never above 0.
     *
     * @return the lower x bound
     */
    public int getXMin() {
        return xMin;
    }

    /**
     * Return the largest x seen so far, truncated to an int and never below 0.
     *
     * @return the upper x bound
     */
    public int getXMax() {
        return xMax;
    }

    /**
     * Return the smallest y seen so far, truncated to an int and never above 0.
     *
     * @return the lower y bound
     */
    public int getYMin() {
        return yMin;
    }

    /**
     * Return the largest y seen so far, truncated to an int and never below 0.
     *
     * @return the upper y bound
     */
    public int getYMax() {
        return yMax;
    }

    /**
     * Add a new data point: update the running sums and bounds, and remember
     * the coordinates for the later R^2 computation.
     *
     * @param dataPoint
     *            the new data point
     */
    public void addDataPoint(DataPoint dataPoint) {
        sumX += dataPoint.x;
        sumY += dataPoint.y;
        sumXX += dataPoint.x * dataPoint.x;
        sumXY += dataPoint.x * dataPoint.y;
        sumYY += dataPoint.y * dataPoint.y;

        if (dataPoint.x > xMax) {
            xMax = (int) dataPoint.x;
        }
        if (dataPoint.x < xMin) {
            xMin = (int) dataPoint.x;
        }
        if (dataPoint.y > yMax) {
            yMax = (int) dataPoint.y;
        }
        if (dataPoint.y < yMin) {
            yMin = (int) dataPoint.y;
        }

        // Every point is stored, including those with y == 0: they contribute
        // to the sums above, so they must contribute to the residuals as well.
        xs.add(dataPoint.x);
        ys.add(dataPoint.y);

        // Console echo kept from the original implementation (it was skipped
        // for y == 0) so that existing program output does not change.
        if (dataPoint.y != 0) {
            System.out.println(dataPoint.x + "," + dataPoint.y);
        }

        ++pn;
        coefsValid = false;
    }

    /**
     * Return the value of the regression line at x, computed from the
     * full-precision coefficients. (Implementation of Evaluatable.)
     *
     * @param x
     *            the value of x
     * @return the value of the function at x, or NaN when the line is
     *         undefined
     */
    public double at(double x) {
        if (pn < 2) {
            return Double.NaN;
        }

        validateCoefficients();
        return a0 + a1 * x;
    }

    /**
     * Reset to the empty state: clears the sums, the stored points, the
     * bounds and the cached coefficients.
     */
    public void reset() {
        pn = 0;
        sumX = sumY = sumXX = sumXY = sumYY = 0;
        xMin = xMax = yMin = yMax = 0;
        xs.clear();
        ys.clear();
        coefsValid = false;
    }

    /**
     * Recompute a0 and a1 from the running sums if they are stale. Both are
     * set to NaN when there are fewer than two points or when the denominator
     * of the slope formula is zero (all x values identical).
     */
    private void validateCoefficients() {
        if (coefsValid) {
            return;
        }

        a0 = a1 = Double.NaN;
        if (pn >= 2) {
            double denominator = pn * sumXX - sumX * sumX;
            if (denominator != 0) {
                double xBar = sumX / pn;
                double yBar = sumY / pn;
                a1 = (pn * sumXY - sumX * sumY) / denominator;
                a0 = yBar - a1 * xBar;
            }
        }

        coefsValid = true;
    }

    /**
     * Return the goodness of fit R^2 = 1 - SSE/SST over all added points,
     * rounded half-up to 4 decimals. The result is computed from scratch on
     * every call, so repeated calls return the same value.
     *
     * @return R^2, or NaN when it is undefined (fewer than two points, an
     *         undefined line, or all y values identical)
     */
    public double getR() {
        if (pn < 2) {
            return Double.NaN;
        }

        // Sum of squared residuals over every stored point.
        double sse = 0;
        for (int i = 0; i < pn; i++) {
            double residual = ys.get(i) - at(xs.get(i));
            sse += residual * residual;
        }

        double sst = sumYY - (sumY * sumY) / pn;
        if (sst == 0) {
            // No variance in y: the ratio SSE/SST is not defined.
            return Double.NaN;
        }
        return round(1 - sse / sst, RESULT_SCALE);
    }

    /**
     * Round a double half-up to the given number of decimals, going through
     * its shortest decimal representation to avoid binary artefacts.
     *
     * @param v
     *            the value to round
     * @param scale
     *            number of decimals to keep; must be zero or positive
     * @return the rounded value; NaN and infinities are returned unchanged
     * @throws IllegalArgumentException
     *             if scale is negative
     */
    public double round(double v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("比例必须是一个正整数或零");
        }
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            // BigDecimal cannot represent these values.
            return v;
        }

        return new BigDecimal(Double.toString(v)).setScale(scale, RoundingMode.HALF_UP).doubleValue();
    }

    /**
     * Round a float half-up to the given number of decimals, going through
     * its shortest decimal representation to avoid binary artefacts.
     *
     * @param v
     *            the value to round
     * @param scale
     *            number of decimals to keep; must be zero or positive
     * @return the rounded value; NaN and infinities are returned unchanged
     * @throws IllegalArgumentException
     *             if scale is negative
     */
    public float round(float v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("比例必须是一个正整数或零");
        }
        if (Float.isNaN(v) || Float.isInfinite(v)) {
            // BigDecimal cannot represent these values.
            return v;
        }

        // Float.toString keeps the float's own shortest representation;
        // widening to double first would expose binary noise such as
        // 0.10000000149011612 for 0.1f.
        return new BigDecimal(Float.toString(v)).setScale(scale, RoundingMode.HALF_UP).floatValue();
    }
}

2.3 线性回归测试类

public class LinearRegression {

    /**
     * Feeds a small sample data set into a RegressionLine and prints the
     * accumulated sums, the fitted line and its R^2.
     *
     * @param args
     *            command-line arguments (unused)
     */
    public static void main(String args[]) {
        // Sample data that can be compared against Excel. A second set to try:
        // x = {0.1, 0.2, 0.5, 1, 5, 10}, y = {0.79, 1.84, 3.45, 5.08, 25.51, 50.36}
        final double[] xValues = {0, 10, 50, 100};
        final double[] yValues = {0.1, 1.03, 5.23, 9.906};

        RegressionLine line = new RegressionLine();
        int index = 0;
        while (index < xValues.length) {
            DataPoint point = new DataPoint(xValues[index], yValues[index]);
            line.addDataPoint(point);
            index++;
        }

        printSums(line);
        printLine(line);
    }

    /**
     * Prints the point count and the running sums held by the regression line.
     *
     * @param line
     *            the regression line
     */
    private static void printSums(RegressionLine line) {
        final int count = line.getDataPointCount();
        System.out.println("\n数据点个数 n = " + count);

        final double totalX = line.getSumX();
        final double totalY = line.getSumY();
        final double totalXX = line.getSumXX();
        final double totalXY = line.getSumXY();
        final double totalYY = line.getSumYY();

        System.out.println("\nSum x  = " + totalX);
        System.out.println("Sum y  = " + totalY);
        System.out.println("Sum xx = " + totalXX);
        System.out.println("Sum xy = " + totalXY);
        System.out.println("Sum yy = " + totalYY);
    }

    /**
     * Prints the regression line equation followed by its R^2.
     *
     * @param line
     *            the regression line
     */
    private static void printLine(RegressionLine line) {
        final double slope = line.getA1();
        final double intercept = line.getA0();
        System.out.println("\n回归线公式:  y = " + slope + "x + " + intercept);

        final double rSquared = line.getR();
        System.out.println("误差:     R^2 = " + rSquared);
    }
}

测试类运行结果:

0.0,0.1
10.0,1.03
50.0,5.23
100.0,9.906

数据点个数 n = 4

Sum x = 160.0
Sum y = 16.266000000000002
Sum xx = 12600.0
Sum xy = 1262.4
Sum yy = 126.552636

回归线公式: y = 0.0987x + 0.1197
误差: R^2 = 0.9994

3. 总结

对于熟悉线性回归的人来说,这段代码很好理解。后续我也还会持续使用,如果有错漏,欢迎指正。


心之所向,素履以往;生之逆旅,一苇以航。