001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math3.distribution; 018 019 import org.apache.commons.math3.exception.DimensionMismatchException; 020 import org.apache.commons.math3.linear.Array2DRowRealMatrix; 021 import org.apache.commons.math3.linear.EigenDecomposition; 022 import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException; 023 import org.apache.commons.math3.linear.RealMatrix; 024 import org.apache.commons.math3.linear.SingularMatrixException; 025 import org.apache.commons.math3.random.RandomGenerator; 026 import org.apache.commons.math3.random.Well19937c; 027 import org.apache.commons.math3.util.FastMath; 028 import org.apache.commons.math3.util.MathArrays; 029 030 /** 031 * Implementation of the multivariate normal (Gaussian) distribution. 032 * 033 * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution"> 034 * Multivariate normal distribution (Wikipedia)</a> 035 * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html"> 036 * Multivariate normal distribution (MathWorld)</a> 037 * 038 * @version $Id: MultivariateNormalDistribution.java 1416643 2012-12-03 19:37:14Z tn $ 039 * @since 3.1 040 */ 041 public class MultivariateNormalDistribution 042 extends AbstractMultivariateRealDistribution { 043 /** Vector of means. */ 044 private final double[] means; 045 /** Covariance matrix. */ 046 private final RealMatrix covarianceMatrix; 047 /** The matrix inverse of the covariance matrix. */ 048 private final RealMatrix covarianceMatrixInverse; 049 /** The determinant of the covariance matrix. */ 050 private final double covarianceMatrixDeterminant; 051 /** Matrix used in computation of samples. */ 052 private final RealMatrix samplingMatrix; 053 054 /** 055 * Creates a multivariate normal distribution with the given mean vector and 056 * covariance matrix. 057 * <br/> 058 * The number of dimensions is equal to the length of the mean vector 059 * and to the number of rows and columns of the covariance matrix. 060 * It is frequently written as "p" in formulae. 061 * 062 * @param means Vector of means. 063 * @param covariances Covariance matrix. 064 * @throws DimensionMismatchException if the arrays length are 065 * inconsistent. 066 * @throws SingularMatrixException if the eigenvalue decomposition cannot 067 * be performed on the provided covariance matrix. 068 * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is 069 * negative. 070 */ 071 public MultivariateNormalDistribution(final double[] means, 072 final double[][] covariances) 073 throws SingularMatrixException, 074 DimensionMismatchException, 075 NonPositiveDefiniteMatrixException { 076 this(new Well19937c(), means, covariances); 077 } 078 079 /** 080 * Creates a multivariate normal distribution with the given mean vector and 081 * covariance matrix. 082 * <br/> 083 * The number of dimensions is equal to the length of the mean vector 084 * and to the number of rows and columns of the covariance matrix. 085 * It is frequently written as "p" in formulae. 086 * 087 * @param rng Random Number Generator. 088 * @param means Vector of means. 089 * @param covariances Covariance matrix. 090 * @throws DimensionMismatchException if the arrays length are 091 * inconsistent. 092 * @throws SingularMatrixException if the eigenvalue decomposition cannot 093 * be performed on the provided covariance matrix. 094 * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is 095 * negative. 096 */ 097 public MultivariateNormalDistribution(RandomGenerator rng, 098 final double[] means, 099 final double[][] covariances) 100 throws SingularMatrixException, 101 DimensionMismatchException, 102 NonPositiveDefiniteMatrixException { 103 super(rng, means.length); 104 105 final int dim = means.length; 106 107 if (covariances.length != dim) { 108 throw new DimensionMismatchException(covariances.length, dim); 109 } 110 111 for (int i = 0; i < dim; i++) { 112 if (dim != covariances[i].length) { 113 throw new DimensionMismatchException(covariances[i].length, dim); 114 } 115 } 116 117 this.means = MathArrays.copyOf(means); 118 119 covarianceMatrix = new Array2DRowRealMatrix(covariances); 120 121 // Covariance matrix eigen decomposition. 122 final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix); 123 124 // Compute and store the inverse. 125 covarianceMatrixInverse = covMatDec.getSolver().getInverse(); 126 // Compute and store the determinant. 127 covarianceMatrixDeterminant = covMatDec.getDeterminant(); 128 129 // Eigenvalues of the covariance matrix. 130 final double[] covMatEigenvalues = covMatDec.getRealEigenvalues(); 131 132 for (int i = 0; i < covMatEigenvalues.length; i++) { 133 if (covMatEigenvalues[i] < 0) { 134 throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0); 135 } 136 } 137 138 // Matrix where each column is an eigenvector of the covariance matrix. 139 final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim); 140 for (int v = 0; v < dim; v++) { 141 final double[] evec = covMatDec.getEigenvector(v).toArray(); 142 covMatEigenvectors.setColumn(v, evec); 143 } 144 145 final RealMatrix tmpMatrix = covMatEigenvectors.transpose(); 146 147 // Scale each eigenvector by the square root of its eigenvalue. 148 for (int row = 0; row < dim; row++) { 149 final double factor = FastMath.sqrt(covMatEigenvalues[row]); 150 for (int col = 0; col < dim; col++) { 151 tmpMatrix.multiplyEntry(row, col, factor); 152 } 153 } 154 155 samplingMatrix = covMatEigenvectors.multiply(tmpMatrix); 156 } 157 158 /** 159 * Gets the mean vector. 160 * 161 * @return the mean vector. 162 */ 163 public double[] getMeans() { 164 return MathArrays.copyOf(means); 165 } 166 167 /** 168 * Gets the covariance matrix. 169 * 170 * @return the covariance matrix. 171 */ 172 public RealMatrix getCovariances() { 173 return covarianceMatrix.copy(); 174 } 175 176 /** {@inheritDoc} */ 177 public double density(final double[] vals) throws DimensionMismatchException { 178 final int dim = getDimension(); 179 if (vals.length != dim) { 180 throw new DimensionMismatchException(vals.length, dim); 181 } 182 183 return FastMath.pow(2 * FastMath.PI, -dim / 2) * 184 FastMath.pow(covarianceMatrixDeterminant, -0.5) * 185 getExponentTerm(vals); 186 } 187 188 /** 189 * Gets the square root of each element on the diagonal of the covariance 190 * matrix. 191 * 192 * @return the standard deviations. 193 */ 194 public double[] getStandardDeviations() { 195 final int dim = getDimension(); 196 final double[] std = new double[dim]; 197 final double[][] s = covarianceMatrix.getData(); 198 for (int i = 0; i < dim; i++) { 199 std[i] = FastMath.sqrt(s[i][i]); 200 } 201 return std; 202 } 203 204 /** {@inheritDoc} */ 205 public double[] sample() { 206 final int dim = getDimension(); 207 final double[] normalVals = new double[dim]; 208 209 for (int i = 0; i < dim; i++) { 210 normalVals[i] = random.nextGaussian(); 211 } 212 213 final double[] vals = samplingMatrix.operate(normalVals); 214 215 for (int i = 0; i < dim; i++) { 216 vals[i] += means[i]; 217 } 218 219 return vals; 220 } 221 222 /** 223 * Computes the term used in the exponent (see definition of the distribution). 224 * 225 * @param values Values at which to compute density. 226 * @return the multiplication factor of density calculations. 227 */ 228 private double getExponentTerm(final double[] values) { 229 final double[] centered = new double[values.length]; 230 for (int i = 0; i < centered.length; i++) { 231 centered[i] = values[i] - getMeans()[i]; 232 } 233 final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered); 234 double sum = 0; 235 for (int i = 0; i < preMultiplied.length; i++) { 236 sum += preMultiplied[i] * centered[i]; 237 } 238 return FastMath.exp(-0.5 * sum); 239 } 240 }