/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.moment.SecondMoment;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix</i>,
 * has rows corresponding to sample observations and columns corresponding to independent
 * variables.  When the model is estimated using an intercept term (i.e., when
 * {@link #isNoIntercept() isNoIntercept} is false, as it is by default), the <code>X</code>
 * matrix includes an initial column identically equal to 1.  We solve the normal equations
 * as follows:
 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
 * R b = Q<sup>T</sup> y </code></pre></p>
 *
 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
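 *
 * <p>A minimal usage sketch (the sample arrays <code>y</code> and <code>x</code> below are
 * placeholders supplied by the caller):
 * <pre>{@code
 * OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
 * regression.newSampleData(y, x);   // y is a double[n], x is a double[n][k]
 * double[] beta = regression.estimateRegressionParameters();
 * double[] residuals = regression.estimateResiduals();
 * double rSquared = regression.calculateRSquared();
 * }</pre></p>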
 *
 * @version $Id: OLSMultipleLinearRegression.java 1416643 2012-12-03 19:37:14Z tn $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws MathIllegalArgumentException if the x and y array data are not
     *             compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecomposition(getX());
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
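     * <p>For example, the leverage of the i-th observation is the i-th diagonal entry of the
     * returned matrix (a brief sketch; {@code regression} holds loaded sample data and
     * {@code i} is a valid row index):
     * <pre>{@code
     * RealMatrix hat = regression.calculateHat();
     * double leverage = hat.getEntry(i, i);
     * }</pre></p>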
     *
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        // No DME advertised - args valid if we get here
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * <p>Returns the sum of squared deviations of Y from its mean.</p>
     *
     * <p>If the model has no intercept term, <code>0</code> is used for the
     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
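     *
     * <p>In symbols,
     * <pre><code> SSTO = &sum; (y<sub>i</sub> - mean(y))<sup>2</sup>   (model with an intercept)
     * SSTO = &sum; y<sub>i</sub><sup>2</sup>               (model without an intercept) </code></pre></p>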
     *
     * <p>The value returned by this method is the SSTO value used in
     * the {@link #calculateRSquared() R-squared} computation.</p>
     *
     * @return SSTO - the total sum of squares
     * @throws MathIllegalArgumentException if the sample has not been set or does
     * not contain at least 3 observations
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateTotalSumOfSquares() throws MathIllegalArgumentException {
        if (isNoIntercept()) {
            return StatUtils.sumSq(getY().toArray());
        } else {
            return new SecondMoment().evaluate(getY().toArray());
        }
    }


    /**
     * Returns the sum of squared residuals.
     *
     * @return residual sum of squares
     * @since 2.2
     */
    public double calculateResidualSumOfSquares() {
        final RealVector residuals = calculateResiduals();
        // No advertised DME, args are valid
        return residuals.dotProduct(residuals);
    }

    /**
     * Returns the R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup> = 1 - SSR / SSTO
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}.
     *
     * @return R-squared statistic
     * @throws MathIllegalArgumentException if the sample has not been set or does
     * not contain at least 3 observations
     * @since 2.2
     */
    public double calculateRSquared() throws MathIllegalArgumentException {
        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
    }

    /**
     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
     * of observations and p is the number of parameters estimated (including the intercept).</p>
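     *
     * <p>Equivalently, since R<sup>2</sup> = 1 - SSR / SSTO, the formula above can be
     * rearranged as <pre>
     * R<sup>2</sup><sub>adj</sub> = 1 - (1 - R<sup>2</sup>) * [(n - 1) / (n - p)]
     * </pre></p>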
     *
     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
     * </pre></p>
     *
     * @return adjusted R-squared statistic
     * @throws MathIllegalArgumentException if the sample has not been set or does
     * not contain at least 3 observations
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateAdjustedRSquared() throws MathIllegalArgumentException {
        final double n = getX().getRowDimension();
        if (isNoIntercept()) {
            return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
        } else {
            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
                (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
        }
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix
     * once it is successfully loaded.</p>
     */
    @Override
    protected void newXSampleData(double[][] x) {
        super.newXSampleData(x);
        qr = new QRDecomposition(getX());
    }

    /**
     * Calculates the regression coefficients using OLS.
     *
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
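        // Least-squares solution of X b = y, obtained by back-substitution on R b = Q^T y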
        return qr.getSolver().solve(getY());
    }

    /**
     * <p>Calculates the variance-covariance matrix of the regression parameters.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.</p>
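     * <p>This works because Q has orthonormal columns, so
     * X<sup>T</sup>X = (QR)<sup>T</sup>(QR) = R<sup>T</sup>Q<sup>T</sup>Q R = R<sup>T</sup>R, and hence
     * (X<sup>T</sup>X)<sup>-1</sup> = (R<sup>T</sup>R)<sup>-1</sup> = R<sup>-1</sup> (R<sup>-1</sup>)<sup>T</sup>,
     * which is what the implementation computes from the inverse of the top p-by-p block of R.</p>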
     *
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
     *
     * @return The beta variance-covariance matrix
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = getX().getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
        // (X^T X)^-1 = (R^T R)^-1 = R^-1 (R^-1)^T
        return Rinv.multiply(Rinv.transpose());
    }

}