001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
018    
019    import java.util.Comparator;
020    
021    import org.apache.commons.math3.optim.PointValuePair;
022    import org.apache.commons.math3.analysis.MultivariateFunction;
023    
024    /**
025     * This class implements the Nelder-Mead simplex algorithm.
026     *
027     * @version $Id: NelderMeadSimplex.java 1364392 2012-07-22 18:27:12Z tn $
028     * @since 3.0
029     */
030    public class NelderMeadSimplex extends AbstractSimplex {
031        /** Default value for {@link #rho}: {@value}. */
032        private static final double DEFAULT_RHO = 1;
033        /** Default value for {@link #khi}: {@value}. */
034        private static final double DEFAULT_KHI = 2;
035        /** Default value for {@link #gamma}: {@value}. */
036        private static final double DEFAULT_GAMMA = 0.5;
037        /** Default value for {@link #sigma}: {@value}. */
038        private static final double DEFAULT_SIGMA = 0.5;
039        /** Reflection coefficient. */
040        private final double rho;
041        /** Expansion coefficient. */
042        private final double khi;
043        /** Contraction coefficient. */
044        private final double gamma;
045        /** Shrinkage coefficient. */
046        private final double sigma;
047    
048        /**
049         * Build a Nelder-Mead simplex with default coefficients.
050         * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
051         * for both gamma and sigma.
052         *
053         * @param n Dimension of the simplex.
054         */
055        public NelderMeadSimplex(final int n) {
056            this(n, 1d);
057        }
058    
059        /**
060         * Build a Nelder-Mead simplex with default coefficients.
061         * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
062         * for both gamma and sigma.
063         *
064         * @param n Dimension of the simplex.
065         * @param sideLength Length of the sides of the default (hypercube)
066         * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
067         */
068        public NelderMeadSimplex(final int n, double sideLength) {
069            this(n, sideLength,
070                 DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
071        }
072    
073        /**
074         * Build a Nelder-Mead simplex with specified coefficients.
075         *
076         * @param n Dimension of the simplex. See
077         * {@link AbstractSimplex#AbstractSimplex(int,double)}.
078         * @param sideLength Length of the sides of the default (hypercube)
079         * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
080         * @param rho Reflection coefficient.
081         * @param khi Expansion coefficient.
082         * @param gamma Contraction coefficient.
083         * @param sigma Shrinkage coefficient.
084         */
085        public NelderMeadSimplex(final int n, double sideLength,
086                                 final double rho, final double khi,
087                                 final double gamma, final double sigma) {
088            super(n, sideLength);
089    
090            this.rho = rho;
091            this.khi = khi;
092            this.gamma = gamma;
093            this.sigma = sigma;
094        }
095    
096        /**
097         * Build a Nelder-Mead simplex with specified coefficients.
098         *
099         * @param n Dimension of the simplex. See
100         * {@link AbstractSimplex#AbstractSimplex(int)}.
101         * @param rho Reflection coefficient.
102         * @param khi Expansion coefficient.
103         * @param gamma Contraction coefficient.
104         * @param sigma Shrinkage coefficient.
105         */
106        public NelderMeadSimplex(final int n,
107                                 final double rho, final double khi,
108                                 final double gamma, final double sigma) {
109            this(n, 1d, rho, khi, gamma, sigma);
110        }
111    
112        /**
113         * Build a Nelder-Mead simplex with default coefficients.
114         * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
115         * for both gamma and sigma.
116         *
117         * @param steps Steps along the canonical axes representing box edges.
118         * They may be negative but not zero. See
119         */
120        public NelderMeadSimplex(final double[] steps) {
121            this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
122        }
123    
124        /**
125         * Build a Nelder-Mead simplex with specified coefficients.
126         *
127         * @param steps Steps along the canonical axes representing box edges.
128         * They may be negative but not zero. See
129         * {@link AbstractSimplex#AbstractSimplex(double[])}.
130         * @param rho Reflection coefficient.
131         * @param khi Expansion coefficient.
132         * @param gamma Contraction coefficient.
133         * @param sigma Shrinkage coefficient.
134         * @throws IllegalArgumentException if one of the steps is zero.
135         */
136        public NelderMeadSimplex(final double[] steps,
137                                 final double rho, final double khi,
138                                 final double gamma, final double sigma) {
139            super(steps);
140    
141            this.rho = rho;
142            this.khi = khi;
143            this.gamma = gamma;
144            this.sigma = sigma;
145        }
146    
147        /**
148         * Build a Nelder-Mead simplex with default coefficients.
149         * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
150         * for both gamma and sigma.
151         *
152         * @param referenceSimplex Reference simplex. See
153         * {@link AbstractSimplex#AbstractSimplex(double[][])}.
154         */
155        public NelderMeadSimplex(final double[][] referenceSimplex) {
156            this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
157        }
158    
159        /**
160         * Build a Nelder-Mead simplex with specified coefficients.
161         *
162         * @param referenceSimplex Reference simplex. See
163         * {@link AbstractSimplex#AbstractSimplex(double[][])}.
164         * @param rho Reflection coefficient.
165         * @param khi Expansion coefficient.
166         * @param gamma Contraction coefficient.
167         * @param sigma Shrinkage coefficient.
168         * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
169         * if the reference simplex does not contain at least one point.
170         * @throws org.apache.commons.math3.exception.DimensionMismatchException
171         * if there is a dimension mismatch in the reference simplex.
172         */
173        public NelderMeadSimplex(final double[][] referenceSimplex,
174                                 final double rho, final double khi,
175                                 final double gamma, final double sigma) {
176            super(referenceSimplex);
177    
178            this.rho = rho;
179            this.khi = khi;
180            this.gamma = gamma;
181            this.sigma = sigma;
182        }
183    
184        /** {@inheritDoc} */
185        @Override
186        public void iterate(final MultivariateFunction evaluationFunction,
187                            final Comparator<PointValuePair> comparator) {
188            // The simplex has n + 1 points if dimension is n.
189            final int n = getDimension();
190    
191            // Interesting values.
192            final PointValuePair best = getPoint(0);
193            final PointValuePair secondBest = getPoint(n - 1);
194            final PointValuePair worst = getPoint(n);
195            final double[] xWorst = worst.getPointRef();
196    
197            // Compute the centroid of the best vertices (dismissing the worst
198            // point at index n).
199            final double[] centroid = new double[n];
200            for (int i = 0; i < n; i++) {
201                final double[] x = getPoint(i).getPointRef();
202                for (int j = 0; j < n; j++) {
203                    centroid[j] += x[j];
204                }
205            }
206            final double scaling = 1.0 / n;
207            for (int j = 0; j < n; j++) {
208                centroid[j] *= scaling;
209            }
210    
211            // compute the reflection point
212            final double[] xR = new double[n];
213            for (int j = 0; j < n; j++) {
214                xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
215            }
216            final PointValuePair reflected
217                = new PointValuePair(xR, evaluationFunction.value(xR), false);
218    
219            if (comparator.compare(best, reflected) <= 0 &&
220                comparator.compare(reflected, secondBest) < 0) {
221                // Accept the reflected point.
222                replaceWorstPoint(reflected, comparator);
223            } else if (comparator.compare(reflected, best) < 0) {
224                // Compute the expansion point.
225                final double[] xE = new double[n];
226                for (int j = 0; j < n; j++) {
227                    xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
228                }
229                final PointValuePair expanded
230                    = new PointValuePair(xE, evaluationFunction.value(xE), false);
231    
232                if (comparator.compare(expanded, reflected) < 0) {
233                    // Accept the expansion point.
234                    replaceWorstPoint(expanded, comparator);
235                } else {
236                    // Accept the reflected point.
237                    replaceWorstPoint(reflected, comparator);
238                }
239            } else {
240                if (comparator.compare(reflected, worst) < 0) {
241                    // Perform an outside contraction.
242                    final double[] xC = new double[n];
243                    for (int j = 0; j < n; j++) {
244                        xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
245                    }
246                    final PointValuePair outContracted
247                        = new PointValuePair(xC, evaluationFunction.value(xC), false);
248                    if (comparator.compare(outContracted, reflected) <= 0) {
249                        // Accept the contraction point.
250                        replaceWorstPoint(outContracted, comparator);
251                        return;
252                    }
253                } else {
254                    // Perform an inside contraction.
255                    final double[] xC = new double[n];
256                    for (int j = 0; j < n; j++) {
257                        xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
258                    }
259                    final PointValuePair inContracted
260                        = new PointValuePair(xC, evaluationFunction.value(xC), false);
261    
262                    if (comparator.compare(inContracted, worst) < 0) {
263                        // Accept the contraction point.
264                        replaceWorstPoint(inContracted, comparator);
265                        return;
266                    }
267                }
268    
269                // Perform a shrink.
270                final double[] xSmallest = getPoint(0).getPointRef();
271                for (int i = 1; i <= n; i++) {
272                    final double[] x = getPoint(i).getPoint();
273                    for (int j = 0; j < n; j++) {
274                        x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
275                    }
276                    setPoint(i, new PointValuePair(x, Double.NaN, false));
277                }
278                evaluate(evaluationFunction, comparator);
279            }
280        }
281    }