/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.fitting;

import java.util.Arrays;
import java.util.Comparator;
import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.ZeroException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
import org.apache.commons.math3.util.FastMath;

/**
 * Fits points to a {@link
 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function.
 * <p>
 * Usage example:
 * <pre>
 *   GaussianFitter fitter = new GaussianFitter(
 *     new LevenbergMarquardtOptimizer());
 *   fitter.addObservedPoint(4.0254623,  531026.0);
 *   fitter.addObservedPoint(4.03128248, 984167.0);
 *   fitter.addObservedPoint(4.03839603, 1887233.0);
 *   fitter.addObservedPoint(4.04421621, 2687152.0);
 *   fitter.addObservedPoint(4.05132976, 3461228.0);
 *   fitter.addObservedPoint(4.05326982, 3580526.0);
 *   fitter.addObservedPoint(4.05779662, 3439750.0);
 *   fitter.addObservedPoint(4.0636168,  2877648.0);
 *   fitter.addObservedPoint(4.06943698, 2175960.0);
 *   fitter.addObservedPoint(4.07525716, 1447024.0);
 *   fitter.addObservedPoint(4.08237071, 717104.0);
 *   fitter.addObservedPoint(4.08366408, 620014.0);
 *   double[] parameters = fitter.fit();
 * </pre>
 *
 * @since 2.2
 * @version $Id: GaussianFitter.java 1416643 2012-12-03 19:37:14Z tn $
 */
public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
    /**
     * Constructs an instance using the specified optimizer.
     *
     * @param optimizer Optimizer to use for the fitting.
     */
    public GaussianFitter(MultivariateVectorOptimizer optimizer) {
        super(optimizer);
    }

    /**
     * Fits a Gaussian function to the observed points.
     *
     * @param initialGuess First guess values in the following order:
     * <ul>
     *  <li>Norm</li>
     *  <li>Mean</li>
     *  <li>Sigma</li>
     * </ul>
     * @return the parameters of the Gaussian function that best fits the
     * observed points (in the same order as above).
     * @since 3.0
     */
    public double[] fit(double[] initialGuess) {
        // Wraps Gaussian.Parametric so that parameter vectors rejected by the
        // parent implementation (it throws NotStrictlyPositiveException,
        // presumably for a non-positive sigma) evaluate to +infinity instead
        // of aborting the optimization.
        final Gaussian.Parametric f = new Gaussian.Parametric() {
                @Override
                public double value(double x, double ... p) {
                    double v = Double.POSITIVE_INFINITY;
                    try {
                        v = super.value(x, p);
                    } catch (NotStrictlyPositiveException e) { // NOPMD
                        // Intentionally ignored: keep +infinity as the value
                        // for an invalid parameter vector.
                    }
                    return v;
                }

                @Override
                public double[] gradient(double x, double ... p) {
                    double[] v = { Double.POSITIVE_INFINITY,
                                   Double.POSITIVE_INFINITY,
                                   Double.POSITIVE_INFINITY };
                    try {
                        v = super.gradient(x, p);
                    } catch (NotStrictlyPositiveException e) { // NOPMD
                        // Intentionally ignored: keep +infinity for every
                        // gradient component of an invalid parameter vector.
                    }
                    return v;
                }
            };

        return fit(f, initialGuess);
    }

    /**
     * Fits a Gaussian function to the observed points, using an initial
     * guess computed by {@link ParameterGuesser} from those same points.
     *
     * @return the parameters of the Gaussian function that best fits the
     * observed points, in the following order: norm, mean, sigma.
     * @throws NumberIsTooSmallException if there are fewer than 3 observed
     * points (raised while computing the initial guess).
     */
    public double[] fit() {
        final double[] guess = (new ParameterGuesser(getObservations())).guess();
        return fit(guess);
    }

    /**
     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
     * based on the specified observed points.
     */
    public static class ParameterGuesser {
        /** Normalization factor. */
        private final double norm;
        /** Mean. */
        private final double mean;
        /** Standard deviation. */
        private final double sigma;

        /**
         * Constructs instance with the specified observed points.
         *
         * @param observations Observed points from which to guess the
         * parameters of the Gaussian.
         * @throws NullArgumentException if {@code observations} is
         * {@code null}.
         * @throws NumberIsTooSmallException if there are less than 3
         * observations.
         */
        public ParameterGuesser(WeightedObservedPoint[] observations) {
            if (observations == null) {
                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
            }
            if (observations.length < 3) {
                throw new NumberIsTooSmallException(observations.length, 3, true);
            }

            final WeightedObservedPoint[] sorted = sortObservations(observations);
            final double[] params = basicGuess(sorted);

            norm = params[0];
            mean = params[1];
            sigma = params[2];
        }

        /**
         * Gets an estimation of the parameters.
         *
         * @return the guessed parameters, in the following order:
         * <ul>
         *  <li>Normalization factor</li>
         *  <li>Mean</li>
         *  <li>Standard deviation</li>
         * </ul>
         */
        public double[] guess() {
            return new double[] { norm, mean, sigma };
        }

        /**
         * Sort the observations.
         * Points are ordered by X, then by Y, then by weight; {@code null}
         * entries sort first. The input array is not modified.
         *
         * @param unsorted Input observations.
         * @return a sorted copy of the input observations.
         */
        private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
            final WeightedObservedPoint[] observations = unsorted.clone();
            final Comparator<WeightedObservedPoint> cmp
                = new Comparator<WeightedObservedPoint>() {
                public int compare(WeightedObservedPoint p1,
                                   WeightedObservedPoint p2) {
                    if (p1 == null && p2 == null) {
                        return 0;
                    }
                    if (p1 == null) {
                        return -1;
                    }
                    if (p2 == null) {
                        return 1;
                    }
                    // Double.compare defines a total order (NaN included), so
                    // the comparator always honours the Comparator contract
                    // required by Arrays.sort.
                    final int byX = Double.compare(p1.getX(), p2.getX());
                    if (byX != 0) {
                        return byX;
                    }
                    final int byY = Double.compare(p1.getY(), p2.getY());
                    if (byY != 0) {
                        return byY;
                    }
                    return Double.compare(p1.getWeight(), p2.getWeight());
                }
            };

            Arrays.sort(observations, cmp);
            return observations;
        }

        /**
         * Guesses the parameters based on the specified observed points.
         * The normalization factor and mean are taken from the point with the
         * largest Y; sigma is derived from the full width at half maximum.
         *
         * @param points Observed points, sorted.
         * @return the guessed parameters (normalization factor, mean and
         * sigma).
         */
        private double[] basicGuess(WeightedObservedPoint[] points) {
            final int maxYIdx = findMaxY(points);
            final double n = points[maxYIdx].getY();
            final double m = points[maxYIdx].getX();

            double fwhmApprox;
            try {
                // Half of the maximum height. The fitted model has no offset,
                // so the half-maximum level is simply n / 2. (The previous
                // formula, n + (m - n) / 2, mixed the peak's X position into
                // a Y value and only worked by accident when |m| << n.)
                final double halfY = n / 2;
                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
                fwhmApprox = fwhmX2 - fwhmX1;
            } catch (OutOfRangeException e) {
                // TODO: Exceptions should not be used for flow control.
                // Half maximum not bracketed on both sides: fall back to the
                // whole X range of the data.
                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
            }
            // For a Gaussian, FWHM = 2 * sqrt(2 * ln 2) * sigma.
            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));

            return new double[] { n, m, s };
        }

        /**
         * Finds index of point in specified points with the largest Y.
         * When several points share the largest Y, the first one is returned.
         *
         * @param points Points to search.
         * @return the index in specified points array.
         */
        private int findMaxY(WeightedObservedPoint[] points) {
            int maxYIdx = 0;
            for (int i = 1; i < points.length; i++) {
                if (points[i].getY() > points[maxYIdx].getY()) {
                    maxYIdx = i;
                }
            }
            return maxYIdx;
        }

        /**
         * Interpolates using the specified points to determine X at the
         * specified Y.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start the search for
         * interpolation bounds points.
         * @param idxStep Index step for searching interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the value of X for the specified Y.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private double interpolateXAtY(WeightedObservedPoint[] points,
                                       int startIdx,
                                       int idxStep,
                                       double y)
            throws OutOfRangeException {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            final WeightedObservedPoint[] twoPoints
                = getInterpolationPointsForY(points, startIdx, idxStep, y);
            final WeightedObservedPoint p1 = twoPoints[0];
            final WeightedObservedPoint p2 = twoPoints[1];
            // Exact hits are returned directly; this also covers the case
            // p1.y == p2.y, which would otherwise divide by zero below.
            if (p1.getY() == y) {
                return p1.getX();
            }
            if (p2.getY() == y) {
                return p2.getX();
            }
            // Linear interpolation between p1 and p2.
            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
                                (p2.getY() - p1.getY()));
        }

        /**
         * Gets the two bounding interpolation points from the specified points
         * suitable for determining X at the specified Y.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start search for
         * interpolation bounds points.
         * @param idxStep Index step for search for interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the array containing two points suitable for determining X at
         * the specified Y, ordered by increasing index in {@code points}.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
                                                                   int startIdx,
                                                                   int idxStep,
                                                                   double y)
            throws OutOfRangeException {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            for (int i = startIdx;
                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
                 i += idxStep) {
                final WeightedObservedPoint p1 = points[i];
                final WeightedObservedPoint p2 = points[i + idxStep];
                if (isBetween(y, p1.getY(), p2.getY())) {
                    if (idxStep < 0) {
                        return new WeightedObservedPoint[] { p2, p1 };
                    } else {
                        return new WeightedObservedPoint[] { p1, p2 };
                    }
                }
            }

            // Boundaries are replaced by dummy values because the raised
            // exception is caught and the message never displayed.
            // TODO: Exceptions should not be used for flow control.
            throw new OutOfRangeException(y,
                                          Double.NEGATIVE_INFINITY,
                                          Double.POSITIVE_INFINITY);
        }

        /**
         * Determines whether a value is between two other values.
         *
         * @param value Value to test whether it is between {@code boundary1}
         * and {@code boundary2}.
         * @param boundary1 One end of the range.
         * @param boundary2 Other end of the range.
         * @return {@code true} if {@code value} is between {@code boundary1} and
         * {@code boundary2} (inclusive), {@code false} otherwise.
         */
        private boolean isBetween(double value,
                                  double boundary1,
                                  double boundary2) {
            return (value >= boundary1 && value <= boundary2) ||
                (value >= boundary2 && value <= boundary1);
        }
    }
}