Change Dir

先知cd——热爱生活是一切艺术的开始

统计

留言簿(18)

积分与排名

“牛”们的博客

各个公司技术

我的链接

淘宝技术

阅读排行榜

评论排行榜

Commons Math学习笔记——聚类和回归

 

看其他篇章到目录选择。

聚类可以见我以前写过的聚类分析的文章。

回归是一个统计中非常重要的概念了。在Commons Math库中有一个regression的子包转么实现了线性回归的一些基本类型。在regression包中,有个基本接口就是MultipleLinearRegression,这个接口表达y=X*b+u这样的基本线性回归式。线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。简单看这个公式,y代表了一个n维的列向量(回归子),X代表了[n,k]大小的观测值矩阵(回归量),bk维的回归参数,u是一个n维的剩余误差。回归分析干什么用的?具体讲就是预测。我们在数据挖掘中定义,定性的分析叫做分类,而定量的分析叫做回归。回归就是根据已有的观察值去预测未来的一个定量的指标。记得前一段阿里云到学院来做技术交流,讲到阿里和淘宝通过数据分析对中国商品交易(还是具体什么贸易,忘记了,尴尬)的预测就是工程师做的一个简单的线性回归分析,模型虽然简单,但是后来与实际数据一比较,预测值与实际值的曲线基本吻合。

回到Commons Math中的实现,MultipleLinearRegression接口共有5个方法,double estimateRegressandVariance()返回回归子的方差,double[] estimateRegressionParameters()返回回归参数b(当然是一个k维的向量),double[] estimateRegressionParametersStandardErrors()返回回归参数的标准误差,double[][] estimateRegressionParametersVariance()返回回归参数b的方差,double[] estimateResiduals()返回剩余误差数组(n维向量)。AbstractMultipleLinearRegression类实现了这个接口,作为一个抽象类,它没有完全实现这些方法,仍旧作为抽象方法抛给具体的实现类去完成,典型的模板方法设计模式。继承这个类的方法有GLSMultipleLinearRegressionOLSMultipleLinearRegression,前者应该是广义最小二乘法,而后者是普通最小二乘法。

我们拿普通最小二乘法作为测试,看看它的source code如何实现:首先对于xy的表示,在AbstractMultipleLinearRegression内部,x被定义为一个RealMatrixy是一个RealVector,这些在第一章都有讲到。而OLSMultipleLinearRegression里用到定义了一个内部的QRDecomposition。通过void newSampleData(double[] y, double[][] x)方法来载入回归模型;接着就可以通过各种借口方法去计算这个回归的统计量了。算出bu之后,那么根据新的观测回归量x就可以预测y了。

 1/**
 2 * 
 3 */

 4package algorithm.math;
 5
 6import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;
 7import org.apache.commons.math.stat.regression.SimpleRegression;
 8
 9/**
10 * @author Jia Yu
11 * @date   2010-12-6
12 */

13public class RegressionTest {
14
15    /**
16     * @param args
17     */

18    public static void main(String[] args) {
19        // TODO Auto-generated method stub
20        regression();
21        System.out.println("-------------------------------------");
22        simple();
23    }

24
25    private static void simple() {
26        // TODO Auto-generated method stub
27        double[][] data = 0.10.2 }{338.8337.4 }{118.1118.2 }
28                {888.0884.6 }{9.210.1 }{228.1226.5 }{668.5666.3 }{998.5996.3 }
29                {449.1448.6 }{778.9777.0 }{559.2558.2 }{0.30.4 }{0.10.6 }{778.1775.5 }
30                {668.8666.9 }{339.3338.0 }{448.9447.5 }{10.811.6 }{557.7556.0 }
31                {228.3228.1 }{998.0995.8 }{888.8887.6 }{119.6120.2 }{0.30.3 }
32                {0.60.3 }{557.6556.8 }{339.3339.1 }{888.0887.2 }{998.5999.0 }
33                {778.9779.0 }{10.211.1 }{117.6118.3 }{228.9229.2 }{668.4669.1 }
34                {449.2448.9 }{0.20.5 }
35        }
;
36        SimpleRegression regression = new SimpleRegression();
37        for (int i = 0; i < data.length; i++{
38            regression.addData(data[i][1], data[i][0]);
39        }

40        System.out.println("slope is "+regression.getSlope());
41        System.out.println("slope std err is "+regression.getSlopeStdErr());
42        System.out.println("number of observations is "+regression.getN());
43        System.out.println("intercept is "+regression.getIntercept());
44        System.out.println("std err intercept is "+regression.getInterceptStdErr());
45        System.out.println("r-square is "+regression.getRSquare());
46        System.out.println("SSR is "+regression.getRegressionSumSquares());
47        System.out.println("MSE is "+regression.getMeanSquareError());
48        System.out.println("SSE is "+regression.getSumSquaredErrors());
49        System.out.println("predict(0) is "+regression.predict(0));
50        System.out.println("predict(1) is "+regression.predict(1));
51    }

52
53    private static void regression() {
54        // TODO Auto-generated method stub
55        double[] y;
56        double[][] x;
57        y = new double[]{11.012.013.014.015.016.0};
58        x = new double[6][];
59        x[0= new double[]{1.000000};
60        x[1= new double[]{1.02.00000};
61        x[2= new double[]{1.003.0000};
62        x[3= new double[]{1.0004.000};
63        x[4= new double[]{1.00005.00};
64        x[5= new double[]{1.000006.0};
65        
66        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
67        regression.newSampleData(y, x);
68        
69        double[] betaHat = regression.estimateRegressionParameters();
70        System.out.println("Estimates the regression parameters b:");
71        print(betaHat);
72        double[] residuals = regression.estimateResiduals();
73        System.out.println("Estimates the residuals, ie u = y - X*b:");
74        print(residuals);
75        double vary = regression.estimateRegressandVariance();
76        System.out.println("Returns the variance of the regressand Var(y):");
77        System.out.println(vary);
78        double[] erros = regression.estimateRegressionParametersStandardErrors();
79        System.out.println("Returns the standard errors of the regression parameters:");
80        print(erros);
81        double[][] varb = regression.estimateRegressionParametersVariance();
82    }

83
84    private static void print(double[] v) {
85        // TODO Auto-generated method stub
86        for(int i=0;i<v.length;i++){
87            System.out.print(v[i]+ " ");
88        }

89        System.out.println();
90    }

91
92}

93

输出结果:
Estimates the regression parameters b:
11.000000000000004 0.4999999999999988 0.6666666666666657 0.7499999999999993 0.7999999999999993 0.8333333333333329
Estimates the residuals, ie u = y - X*b:
-3.552713678800501E-15 -1.7763568394002505E-15 0.0 0.0 0.0 0.0
Returns the variance of the regressand Var(y):
Infinity
Returns the standard errors of the regression parameters:
Infinity Infinity Infinity Infinity Infinity Infinity
-------------------------------------
slope is 1.0021168180204547
slope std err is 4.297968481840198E-4
number of observations is 36
intercept is -0.26232307377414243
std err intercept is 0.2328182342925303
r-square is 0.9999937458837121
SSR is 4255954.132323695
MSE is 0.7828646625720841
SSE is 26.61739852745086
predict(0) is -0.26232307377414243
predict(1) is 0.7397937442463123
 

Regression包中还有一个SimpleRegression类,可以通过直接addValue(x,y)来添加观测数据,再通过调用predict(x)方法来返回预测数据,使用非常方便,这里就不细讲了,给出了测试的代码,以后在具体的统计计算的代码中有用到会再说,感兴趣的也可以去自己阅读代码,只有短短的600行,非常好懂。

相关资料:

线性回归:http://zh.wikipedia.org/zh-cn/%E7%B7%9A%E6%80%A7%E5%9B%9E%E6%AD%B8#.E7.B7.9A.E6.80.A7.E8.BF.B4.E6.AD.B8.E7.9A.84.E6.87.89.E7.94.A8

最小二乘法:http://zh.wikipedia.org/zh-cn/%E6%9C%80%E5%B0%8F%E4%BA%8C%E4%B9%98%E6%B3%95

Commons math包:http://commons.apache.org/math/index.html

posted on 2011-01-01 18:35 changedi 阅读(6329) 评论(0)  编辑  收藏 所属分类: 数学


只有注册用户登录后才能发表评论。


网站导航: