001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.distribution;
018    
019    import org.apache.commons.math3.exception.DimensionMismatchException;
020    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
021    import org.apache.commons.math3.linear.EigenDecomposition;
022    import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
023    import org.apache.commons.math3.linear.RealMatrix;
024    import org.apache.commons.math3.linear.SingularMatrixException;
025    import org.apache.commons.math3.random.RandomGenerator;
026    import org.apache.commons.math3.random.Well19937c;
027    import org.apache.commons.math3.util.FastMath;
028    import org.apache.commons.math3.util.MathArrays;
029    
030    /**
031     * Implementation of the multivariate normal (Gaussian) distribution.
032     *
033     * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
034     * Multivariate normal distribution (Wikipedia)</a>
035     * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
036     * Multivariate normal distribution (MathWorld)</a>
037     *
038     * @version $Id: MultivariateNormalDistribution.java 1416643 2012-12-03 19:37:14Z tn $
039     * @since 3.1
040     */
041    public class MultivariateNormalDistribution
042        extends AbstractMultivariateRealDistribution {
043        /** Vector of means. */
044        private final double[] means;
045        /** Covariance matrix. */
046        private final RealMatrix covarianceMatrix;
047        /** The matrix inverse of the covariance matrix. */
048        private final RealMatrix covarianceMatrixInverse;
049        /** The determinant of the covariance matrix. */
050        private final double covarianceMatrixDeterminant;
051        /** Matrix used in computation of samples. */
052        private final RealMatrix samplingMatrix;
053    
054        /**
055         * Creates a multivariate normal distribution with the given mean vector and
056         * covariance matrix.
057         * <br/>
058         * The number of dimensions is equal to the length of the mean vector
059         * and to the number of rows and columns of the covariance matrix.
060         * It is frequently written as "p" in formulae.
061         *
062         * @param means Vector of means.
063         * @param covariances Covariance matrix.
064         * @throws DimensionMismatchException if the arrays length are
065         * inconsistent.
066         * @throws SingularMatrixException if the eigenvalue decomposition cannot
067         * be performed on the provided covariance matrix.
068         * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
069         * negative.
070         */
071        public MultivariateNormalDistribution(final double[] means,
072                                              final double[][] covariances)
073            throws SingularMatrixException,
074                   DimensionMismatchException,
075                   NonPositiveDefiniteMatrixException {
076            this(new Well19937c(), means, covariances);
077        }
078    
079        /**
080         * Creates a multivariate normal distribution with the given mean vector and
081         * covariance matrix.
082         * <br/>
083         * The number of dimensions is equal to the length of the mean vector
084         * and to the number of rows and columns of the covariance matrix.
085         * It is frequently written as "p" in formulae.
086         *
087         * @param rng Random Number Generator.
088         * @param means Vector of means.
089         * @param covariances Covariance matrix.
090         * @throws DimensionMismatchException if the arrays length are
091         * inconsistent.
092         * @throws SingularMatrixException if the eigenvalue decomposition cannot
093         * be performed on the provided covariance matrix.
094         * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
095         * negative.
096         */
097        public MultivariateNormalDistribution(RandomGenerator rng,
098                                              final double[] means,
099                                              final double[][] covariances)
100                throws SingularMatrixException,
101                       DimensionMismatchException,
102                       NonPositiveDefiniteMatrixException {
103            super(rng, means.length);
104    
105            final int dim = means.length;
106    
107            if (covariances.length != dim) {
108                throw new DimensionMismatchException(covariances.length, dim);
109            }
110    
111            for (int i = 0; i < dim; i++) {
112                if (dim != covariances[i].length) {
113                    throw new DimensionMismatchException(covariances[i].length, dim);
114                }
115            }
116    
117            this.means = MathArrays.copyOf(means);
118    
119            covarianceMatrix = new Array2DRowRealMatrix(covariances);
120    
121            // Covariance matrix eigen decomposition.
122            final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
123    
124            // Compute and store the inverse.
125            covarianceMatrixInverse = covMatDec.getSolver().getInverse();
126            // Compute and store the determinant.
127            covarianceMatrixDeterminant = covMatDec.getDeterminant();
128    
129            // Eigenvalues of the covariance matrix.
130            final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
131    
132            for (int i = 0; i < covMatEigenvalues.length; i++) {
133                if (covMatEigenvalues[i] < 0) {
134                    throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
135                }
136            }
137    
138            // Matrix where each column is an eigenvector of the covariance matrix.
139            final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
140            for (int v = 0; v < dim; v++) {
141                final double[] evec = covMatDec.getEigenvector(v).toArray();
142                covMatEigenvectors.setColumn(v, evec);
143            }
144    
145            final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
146    
147            // Scale each eigenvector by the square root of its eigenvalue.
148            for (int row = 0; row < dim; row++) {
149                final double factor = FastMath.sqrt(covMatEigenvalues[row]);
150                for (int col = 0; col < dim; col++) {
151                    tmpMatrix.multiplyEntry(row, col, factor);
152                }
153            }
154    
155            samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
156        }
157    
158        /**
159         * Gets the mean vector.
160         *
161         * @return the mean vector.
162         */
163        public double[] getMeans() {
164            return MathArrays.copyOf(means);
165        }
166    
167        /**
168         * Gets the covariance matrix.
169         *
170         * @return the covariance matrix.
171         */
172        public RealMatrix getCovariances() {
173            return covarianceMatrix.copy();
174        }
175    
176        /** {@inheritDoc} */
177        public double density(final double[] vals) throws DimensionMismatchException {
178            final int dim = getDimension();
179            if (vals.length != dim) {
180                throw new DimensionMismatchException(vals.length, dim);
181            }
182    
183            return FastMath.pow(2 * FastMath.PI, -dim / 2) *
184                FastMath.pow(covarianceMatrixDeterminant, -0.5) *
185                getExponentTerm(vals);
186        }
187    
188        /**
189         * Gets the square root of each element on the diagonal of the covariance
190         * matrix.
191         *
192         * @return the standard deviations.
193         */
194        public double[] getStandardDeviations() {
195            final int dim = getDimension();
196            final double[] std = new double[dim];
197            final double[][] s = covarianceMatrix.getData();
198            for (int i = 0; i < dim; i++) {
199                std[i] = FastMath.sqrt(s[i][i]);
200            }
201            return std;
202        }
203    
204        /** {@inheritDoc} */
205        public double[] sample() {
206            final int dim = getDimension();
207            final double[] normalVals = new double[dim];
208    
209            for (int i = 0; i < dim; i++) {
210                normalVals[i] = random.nextGaussian();
211            }
212    
213            final double[] vals = samplingMatrix.operate(normalVals);
214    
215            for (int i = 0; i < dim; i++) {
216                vals[i] += means[i];
217            }
218    
219            return vals;
220        }
221    
222        /**
223         * Computes the term used in the exponent (see definition of the distribution).
224         *
225         * @param values Values at which to compute density.
226         * @return the multiplication factor of density calculations.
227         */
228        private double getExponentTerm(final double[] values) {
229            final double[] centered = new double[values.length];
230            for (int i = 0; i < centered.length; i++) {
231                centered[i] = values[i] - getMeans()[i];
232            }
233            final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
234            double sum = 0;
235            for (int i = 0; i < preMultiplied.length; i++) {
236                sum += preMultiplied[i] * centered[i];
237            }
238            return FastMath.exp(-0.5 * sum);
239        }
240    }