1 /** Functions related to the solving of nonlinear equations, i.e. finding
2     roots of nonlinear functions.
3 
4     Authors:    Lars Tandle Kyllingstad
5     Copyright:  Copyright (c) 2009-2010, Lars T. Kyllingstad. All rights reserved.
6     License:    Boost License 1.0
7 */
8 module scid.nonlinear;
9 
10 
11 import std.algorithm;
12 import std.exception;
13 import std.functional;
14 import std.math;
15 static import std.numeric;
16 import std.range;
17 import std.traits;
18 import std.typecons;
19 
20 import scid.core.fortran;
21 import scid.core.memory;
22 import scid.core.traits;
23 import scid.ports.minpack.hybrd;
24 import scid.util;
25 
26 
27 
28 
/** Searches for a root of N functions of N variables using a variant
    of the Powell Hybrid Method (the HYBRD routine from MINPACK).

    Params:
        f = The set of equations, given as a function, delegate or
            functor that takes an array of length N as input and
            returns an array of length N. If the function has
            an additional array input parameter this will be
            assumed to be a buffer for the output value, and the
            function must return exactly that buffer.
        guess = A starting point for the algorithm. The closer this
            guess is to the true root, the greater chance that the
            algorithm converges.
        epsRel = Success criterion: The algorithm stops when
            the relative error between two consecutive iterations
            is at most epsRel.
        maxFuncEvals = (optional) The maximum number of function evaluations.
            If maxFuncEvals<1, it is set to 200*(N+1).
        buffer = (optional) A buffer of length at least N, for the return value.
            It may be the same array as guess, in which case the guess
            is overwritten with the root.

    Returns:
    An array of length N containing the root.  Its storage is that of
    buffer after buffer has been resized to length N.

    Throws:
    Exception if HYBRD reports invalid input parameters, if the function
    has been called the maximum number of times, if the requested
    accuracy cannot be reached, or if the algorithm fails to converge.

    Example:
    The Rosenbrock function is a commonly used test problem for
    optimisation algorithms. It has a global minimum at (1,1) that is hard
    to locate numerically because it lies in a long, narrow valley.
    Instead of using an optimisation algorithm, let us try to locate
    the minimum by finding the root of the Rosenbrock function's
    gradient.
    ---
    // The Rosenbrock function is defined as
    //     f(x,y) = (1-x)^2 + 100 (y-x^2)^2.
    // Thus, its gradient is:
    real[] dRosenbrock(real[] v, real[] buf)
    {
        auto x = v[0], y = v[1];
        buf[0] = -2*(1-x) - 400*x*(y-x*x);
        buf[1] = 200 * (y-x*x);

        return buf;
    }

    real[] guess = [ 2.0, 2.0 ];
    auto root = findRoot(&dRosenbrock, guess, 0.0L);

    writeln(root);  // Prints "1 1". Yay!
    ---
*/
Real[] findRoot (Real, Func)
    (
        scope Func f,
        Real[] guess,
        Real epsRel,
        int maxFuncEvals = 0,
        Real[] buffer=null
    )
    if (isFloatingPoint!Real && isVectorField!(Func, Real))
in
{
    assert (guess.length > 0, "findRoot: empty guess vector given");
}
body
{
    mixin (newFrame);

    // Wrap the user-supplied function so it has the raw-pointer
    // signature that hybrd() is called with below.
    void fcn(size_t m, Real* x, Real* fvec, ref int iflag)
    {
        static if (isBufferVectorField!(Func, Real))
        {
            // When the function takes a buffer, we check that it
            // actually uses it.
            auto fTest = f(x[0 .. m], fvec[0 .. m]);
            assert (fTest.ptr == fvec, "findRoot: The function must "
                ~"return the buffer it was given");
            assert (fTest.length == m, "findRoot: The number of "
                ~"equations must be equal to the number of variables");
        }
        else
        {
            fvec[0 .. m] = f(x[0 .. m])[];
        }
    }

    immutable int n = toInt(guess.length);

    // Total workspace, in elements, for the arrays carved out below:
    // fvec, diag, qtf, wa1..wa4 (n each), fjac (n*n) and r (n*(n+1)/2).
    immutable int wslen = (n*(3*n + 15))/2;
    if (maxFuncEvals < 1) maxFuncEvals = 200*(n+1);

    // Copy the guessed vector into the buffer.  When the caller passes
    // guess itself as the buffer the data is already in place, and an
    // array copy onto itself would be an (illegal) overlapping copy.
    buffer.length = n;
    if (buffer.ptr !is guess.ptr) buffer[] = guess[];

    // There are a lot of parameters to the hybrd function, and we set them
    // in the "correct" order and use the "correct" names.
    Real* x = buffer.ptr;
    Real* fvec = cast(Real*) TempAlloc.malloc(wslen*Real.sizeof);
    alias epsRel xtol;
    alias maxFuncEvals maxfev;
    size_t ml_mu = n-1;
    Real epsfcn = 0.0;
    Real* diag = fvec + n;  diag[0 .. n] = 1.0;
    int mode = 2;
    enum Real factor = 100.0;
    int nprint = 0;
    int info = 0;
    uint nfev;
    Real* fjac = diag + n;
    alias n ldfjac;
    Real* r = fjac + ldfjac*n;
    size_t lr = (n*(n+1))/2;
    Real* qtf = r + lr;
    Real* wa1 = qtf + n;
    Real* wa2 = wa1 + n;
    Real* wa3 = wa2 + n;
    Real* wa4 = wa3 + n;

    // Phew! Call hybrd() now.
    hybrd!(Real, typeof(&fcn))(&fcn, n, x, fvec, xtol, maxfev, ml_mu, ml_mu,
        epsfcn, diag, mode, factor, nprint, info, nfev, fjac, ldfjac, r, lr,
        qtf, wa1, wa2, wa3, wa4);

    // Translate the status code written to info into a result or an
    // exception.
    switch (info)
    {
        case 1: // Success!
            return x[0 .. n];
        case 0:
            throw new Exception("Invalid input parameters");
        case 2:
            throw new Exception(
                "The function has been called the maximum number of times");
        case 3:
            throw new Exception("Cannot reach the requested accuracy");
        case 4:
        case 5:
            throw new Exception("Algorithm failed to converge");
        default:
            assert(0);
    }
}
164 
unittest
{
    // Gradient of the Rosenbrock function, written into the buffer fx.
    real[] gradientInto(real[] v, real[] fx)
    {
        assert (v.length == 2 && fx.length == 2);
        auto x = v[0], y = v[1];
        fx[0] = -2*(1-x) - 400*x*(y-x*x);
        fx[1] = 200 * (y - x*x);
        return fx;
    }

    // Same gradient, but allocating its own result array.
    real[] gradient(real[] v)
    {
        return gradientInto(v, new real[2]);
    }

    real[] start = [ 2.0, 2.0 ];

    // Test bufferless function
    auto plainRoot = findRoot(&gradient, start, 0.0L);
    assert (approxEqual(plainRoot, [1.0L, 1.0L].dup, 1e-6));

    // Test buffered function
    auto bufferedRoot = findRoot(&gradientInto, start, 0.0L);
    assert (approxEqual(bufferedRoot, [1.0L, 1.0L].dup, 1e-6));
}
187 
188 
189 
190 
/** Find a root of the function f.

    The search is done in two stages: $(LINK2 #bracketRoot,bracketRoot)
    is called first to obtain an interval inside which there must be a
    root, and
    $(LINK2 http://www.digitalmars.com/d/2.0/phobos/std_numeric.html#findRoot,std.numeric.findRoot())
    is then called to pin down the location of the root.

    The parameters x0, scale, xMin, and xMax are just passed on to
    $(LINK2 #bracketRoot,bracketRoot), and they are described in detail
    in its documentation.  In brief, x0 should be an estimate of the
    root's location, while scale should be a characteristic scale for
    the function, i.e. a distance over which the function changes
    significantly.  [xMin,xMax] is the interval inside which the algorithm
    is allowed to search.

    You may specify the desired (minimum) number of digits of precision
    in the answer.  If this is left out, the algorithm will attempt
    to achieve full machine precision.
*/
T findRoot(F, T)(scope F f, T x0, T scale, T xMin, T xMax, int precision)
    if (isUnaryFunction!(F, T) && isFloatingPoint!T)
{
    // Termination test handed to the root finder: are the two points
    // equal to the requested number of digits?
    bool digitsAgree(T lo, T hi) { return matchDigits(lo, hi, precision); }

    return findRootImpl(f, x0, scale, xMin, xMax, &digitsAgree);
}
216 
/// ditto
T findRoot(F, T)(scope F f, T x0, T scale, int precision)
    if (isUnaryFunction!(F, T) && isFloatingPoint!T)
{
    // No search limits were given, so search the entire real axis.
    enum T lowerLimit = -T.infinity;
    enum T upperLimit = T.infinity;

    return findRoot(f, x0, scale, lowerLimit, upperLimit, precision);
}
223 
/// ditto
T findRoot(F, T)
    (scope F f, T x0, T scale, T xMin = -T.infinity, T xMax = T.infinity)
    if (isUnaryFunction!(F, T) && isFloatingPoint!T)
{
    // A tolerance test that is never satisfied, so that no early
    // termination is requested from the root finder.
    bool neverSatisfied(T, T) { return false; }

    return findRootImpl(f, x0, scale, xMin, xMax, &neverSatisfied);
}
232 
233 
// Implementation of findRoot()
//
// Brackets a root with bracketRoot() and then refines it with
// std.numeric.findRoot().
//
// f           = the function whose root is sought
// x0, scale,
// xMin, xMax  = passed on unchanged to bracketRoot()
// tolerance   = termination test passed on to std.numeric.findRoot()
//
// Returns the bracket endpoint itself if the function is exactly zero
// there, otherwise the first element of the tuple returned by
// std.numeric.findRoot().  Exceptions thrown by bracketRoot() propagate.
private T findRootImpl(F, T)
    (
        scope F f,
        T x0, T scale,
        T xMin, T xMax,
        scope bool delegate(T, T) tolerance
    )
    if (isUnaryFunction!(F, T) && isFloatingPoint!T)
{
    auto bracket = bracketRoot(f, x0, scale, xMin, xMax);

    // A root sitting exactly on one of the bracket's endpoints needs
    // no further refinement.
    if (bracket.y1 == 0) return bracket.x1;
    if (bracket.y2 == 0) return bracket.x2;

    // std.numeric.findRoot() only takes a delegate
    static if (is (F == delegate))
        auto dg = f;
    else static if (isFunctionPointer!F)
        auto dg = toDelegate(f);
    else static if (isFunctor!F)
        scope ReturnType!F delegate(ParameterTypeTuple!F) dg = &f.opCall;
    else
        // Without this branch dg would simply be undeclared below,
        // producing a confusing "undefined identifier" error.
        static assert (false, "findRoot: cannot convert "~F.stringof
            ~" to a delegate");

    return std.numeric.findRoot(dg,
        bracket.x1, bracket.x2,
        bracket.y1, bracket.y2,
        tolerance
    )[0];
}
262 
263 
unittest
{
    // log(x) has its only root at x = 1.
    real naturalLog(real x) { return log(x); }

    // Ask for two digits only: the answer must be good to two digits,
    // but not to ten.
    immutable roughRoot =
        findRoot(&naturalLog, 0.5L, 1.0L, real.epsilon, real.infinity, 2);
    assert (matchDigits(roughRoot, 1.0, 2));
    assert (!matchDigits(roughRoot, 1.0, 10));

    // Without a precision argument the root must be exact.
    immutable exactRoot =
        findRoot(&naturalLog, 0.5L, 1.0L, real.epsilon, real.infinity);
    assert (exactRoot == 1.0);
}
277 
unittest
{
    // findRoot() must accept a plain function pointer.
    static real identity(real x) { return x; }

    immutable root = findRoot(&identity, 1.0L, 1.0L);
    assert (root == 0.0L);
}
284 
unittest
{
    // findRoot() must accept a functor, i.e. a struct with opCall.
    struct Cube { real opCall(real x) { return x^^3; } }
    Cube cube;

    immutable root = findRoot(cube, 1.0L, 1.0L);
    assert (root == 0.0L);
}
292 
293 
294 
295 
/** Bracket a root of the function f.

    If a function f(x) is continuous on an interval [x1,x2],
    and f(x1) and f(x2) have opposite sign, we know the function
    must pass through zero somewhere in the interval.
    The points x1 and x2 are then said to 'bracket' the
    root.  This is usually the first step in locating the root
    of a function.

    If this function succeeds, it returns a RootBracket containing
    the points x1 and x2, together with the function values f(x1)
    and f(x2).  If it fails, an exception is thrown.
    Note that this library considers the points to be bracketing
    a root also if the root is located exactly at x1 and/or x2,
    i.e. if f(x1)=0 and/or f(x2)=0.

    Details:

    This function will start by evaluating f(x) in the points
    x0 and x0+scale and see if those
    points bracket a root of the given function.  If not, the interval
    is expanded geometrically (i.e. the distance between the points
    is multiplied by a constant factor), always in the direction where
    |f(x)| is smallest, until the points bracket a root.

    You may optionally specify a limiting interval [xMin, xMax], and the
    algorithm will never search outside it.  This is useful,
    for instance, for functions that are only defined for certain
    values of x.  If you do specify such an interval, the
    initial point x0 must lie inside it.

    It is usually worthwhile analysing the behaviour of the function
    in order to find appropriate values for x0 and scale.
    The closer x0 is to the actual root, the fewer steps (i.e. the
    fewer evaluations of f) this algorithm will require to succeed.
    If scale is too large, and the function has several roots,
    there is a chance that it will just step across both roots
    and not find any of them.  On the other hand, if it is too small,
    it may again cause the algorithm to take more steps than would
    otherwise be necessary.

    Whether two function values have opposite sign is decided by
    comparing each of them with zero, not by multiplying them, so the
    outcome is not affected by underflow of the product.

    Params:
        f     = The function.
        x0    = The initial point; must lie in [xMin,xMax].
        scale = The (nonzero) initial step.
        xMin  = (optional) Lower limit of the search interval.
        xMax  = (optional) Upper limit of the search interval.

    Throws:
    Exception (through enforce) with the message "Unable to bracket root"
    if the search reaches the limit of the interval without finding
    a sign change.
*/
RootBracket!(T, ReturnType!F) bracketRoot(F, T)
    (
        scope F f,
        in T x0, in T scale,
        in T xMin = -T.infinity, in T xMax = T.infinity,
    )
    if (isFloatingPoint!T && isUnaryFunction!(F, T))
in
{
    assert (scale != 0, "scale must be nonzero");
    assert (xMin < xMax, "xMin must be smaller than xMax");
    assert (xMin <= x0 && x0 <= xMax, "x0 must be in the interval [xMin,xMax]");
}
body
{
    alias typeof(return) B;
    enum expandFactor = 1.6;


    // True if the two function values have opposite sign, or if at
    // least one of them is zero.  The signs are compared directly:
    // the product ya*yb underflows to zero when both values are tiny,
    // which would make two values of the same sign look like a bracket.
    static bool bracketsZero(real ya, real yb)
    {
        return (ya <= 0 && yb >= 0) || (ya >= 0 && yb <= 0);
    }


    // Function that searches upwards from xMin
    B upwards(real x, real fx, real dx)
    {
        immutable fxMin = f(xMin);
        for (;;)
        {
            if (bracketsZero(fxMin, fx)) return B(xMin, x, fxMin, fx);
            enforce(x != xMax, "Unable to bracket root");

            // Take a larger step each time, but never past xMax.
            dx *= expandFactor;
            x = min(x + dx, xMax);
            fx = f(x);
        }
        assert(0);
    }


    // Function that searches downwards from xMax
    B downwards(real x, real fx, real dx)
    {
        immutable fxMax = f(xMax);
        for (;;)
        {
            if (bracketsZero(fxMax, fx)) return B(x, xMax, fx, fxMax);
            enforce(x != xMin, "Unable to bracket root");

            // Take a larger step each time, but never past xMin.
            dx *= expandFactor;
            x = max(x - dx, xMin);
            fx = f(x);
        }
        assert(0);
    }


    // These are the initial points
    real x1 = x0;
    real x2 = x0 + scale;

    // If x0 is either endpoint of the allowed interval, or if x2
    // falls outside the interval, search only in one direction.
    if (x1 == xMin || x2 <= xMin)
    {
        immutable x = min(x1 + abs(scale), xMax);
        return upwards(x, f(x), abs(scale));
    }
    if (x1 == xMax || x2 >= xMax)
    {
        immutable x = max(x1 - abs(scale), xMin);
        return downwards(x, f(x), abs(scale));
    }


    // Both x1 and x2 fall inside [xMin, xMax], so we use bidirectional search
    if (x1 > x2) swap(x1, x2);
    real fx1 = f(x1);
    real fx2 = f(x2);

    for (;;)
    {
        // Check whether interval brackets a root
        if (bracketsZero(fx1, fx2))  return B(x1, x2, fx1, fx2);

        // Expand interval in the direction where f(x) is closest to zero.
        // Once an end of the interval reaches the limit of the allowed
        // range, continue with a one-directional search from that limit.
        if (fabs(fx1) < fabs(fx2))
        {
            x1 += expandFactor * (x1 - x2);
            if (x1 <= xMin) return upwards(x2, fx2, x2-x1);
            fx1 = f(x1);
        }
        else
        {
            x2 += expandFactor * (x2 - x1);
            if (x2 >= xMax) return downwards(x1, fx1, x2-x1);
            fx2 = f(x2);
        }
    }
    assert(0);
}
434 
435 
unittest
{
    // A straight line with its root at x = 1, searched for from far away.
    real line(real x) { return 1 - x; }

    auto bracket = bracketRoot(&line, -100.0L, 1.0L);
    assert (bracket.contains(1));
}
442 
unittest
{
    // log(x) has its root at x = 1; the search interval is limited
    // to [real.epsilon, infinity].
    real naturalLog(real x) { return log(x); }

    auto bracket = bracketRoot(&naturalLog, 2*real.epsilon, 0.1L,
        real.epsilon, real.infinity);
    assert (bracket.contains(1));
}
449 
450 
451 
452 
/** Two points that bracket a root of some function, together with the
    function values at those points.
*/
struct RootBracket(X, Y)
{
    /// Two points that bracket a root.
    X x1;
    X x2;   /// ditto

    /// The function value at x1 and x2, respectively
    Y y1;
    Y y2;   /// ditto


    // Test helper: whether the given point lies between x1 and x2
    // (inclusive), regardless of which of the two is the larger.
    version(unittest) private bool contains(X point)
    {
        immutable ascending = (x1 <= x2);
        X lower = ascending ? x1 : x2;
        X upper = ascending ? x2 : x1;
        return point >= lower && point <= upper;
    }
}
471 
472 
473 
474 
/** Locates roots of f in the interval [a,b].

    bracketRoots() is used to divide [a,b] into nIntervals subintervals
    and to pick out the ones that bracket roots.  Each such
    subinterval is then passed to std.numeric.findRoot(), and an array
    containing the roots is returned.

    A buffer of length at least nIntervals+1, for storing the roots,
    may optionally be provided.
*/
T[] findRoots(T, Func)(scope Func f, T a, T b, uint nIntervals,
    T[] buffer=null)
{
    mixin(scid.core.memory.newFrame);

    alias RootBracket!(T, ReturnType!Func) Bracket;

    // Find bracketing subintervals, using temporary storage for them.
    auto scratch = newStack!Bracket(nIntervals+1);
    auto brackets = bracketRoots!(T,Func)(f, a, b, nIntervals, scratch);

    // Find all the bracketed roots, one per bracket.
    buffer.length = brackets.length;
    foreach (idx, br; brackets)
    {
        // A root located at the "lower" endpoint needs no refinement.
        if (br.y1 == 0)
        {
            buffer[idx] = br.x1;
            continue;
        }

        // A root located at the "higher" endpoint is left for the
        // next bracket to pick up.
        if (br.y2 == 0) continue;

        // Otherwise, let std.numeric.findRoot() locate the root, with a
        // tolerance test that is never satisfied.
        auto refined = std.numeric.findRoot(f, br.x1, br.x2, br.y1, br.y2,
            (T lo, T hi) { return false; });
        buffer[idx] = refined[0];
    }

    return buffer;
}
511 
512 
unittest
{
    // Quintic polynomial with roots at -2, -1, 0, 1 and 2.
    real quintic(real x)
    {
        return (2+x) * (1+x) * x * (1-x) * (2-x);
    }

    auto roots = findRoots(&quintic, -2.0L, 2.0L, 15);
    assert (roots.length == 5);
    assert (approxEqual(roots, [-2.0, -1.0, 0.0, 1.0, 2.0], real.epsilon));
}
524 
525 
526 
527 
/** Divides the interval [a,b] into the given number of equal-sized
    subintervals,
    checks whether any of the subintervals bracket a root, and returns
    the ones that do, together with the function values at those points.

    A subinterval is reported if the function is exactly zero at its
    lower end, or if the function values at its two ends are nonzero
    and of opposite sign.  A root located exactly at b is reported as
    a bracket with x1 == x2 == b.  The signs are compared directly
    rather than through the product of the function values, so a sign
    change between two very small values is not lost to underflow.

    A buffer of length at least nIntervals+1, for storing the brackets, may
    optionally be provided.  If not, one will be allocated.
*/
RootBracket!(T, ReturnType!Func)[] bracketRoots(T, Func)
    (scope Func f, T a, T b, uint nIntervals,
     RootBracket!(T, ReturnType!Func)[] buffer = null)
{
    static assert (is (typeof(buffer) == typeof(return)));
    alias ElementType!(typeof(return)) B;

    // At most one bracket per subinterval, plus one for the endpoint b.
    buffer.length = nIntervals+1;
    int numBrackets = 0;

    auto lo = a;
    auto flo = f(lo);
    immutable step = (b - a)/nIntervals;

    foreach (i; 0 .. nIntervals)
    {
        // The last subinterval ends exactly at b, so that accumulated
        // rounding errors in lo do not move the endpoint.
        immutable hi = (i < nIntervals - 1 ? lo + step : b);
        immutable fhi = f(hi);

        // A root at the higher end (fhi == 0) is deliberately not
        // counted here; it is picked up as the lower end of the next
        // subinterval, or by the endpoint check after the loop.
        immutable signChange = (flo < 0 && fhi > 0) || (flo > 0 && fhi < 0);
        if (flo == 0 || signChange)
        {
            B br;
            br.x1 = lo;  br.y1 = flo;
            br.x2 = hi;  br.y2 = fhi;
            buffer[numBrackets] = br;
            ++numBrackets;
        }

        lo = hi;
        flo = fhi;
    }

    // Check for a root in the endpoint as well.
    if (flo == 0)
    {
        B br;
        br.x1 = lo;  br.y1 = flo;
        br.x2 = lo;  br.y2 = flo;
        buffer[numBrackets] = br;
        ++numBrackets;
    }

    buffer.length = numBrackets;
    return buffer;
}
581 
582 
unittest
{
    // Quintic polynomial with roots at -2, -1, 0, 1 and 2.
    real quintic(real x)
    {
        return (2+x) * (1+x) * x * (1-x) * (2-x);
    }

    auto brackets = bracketRoots(&quintic, -2.0L, 2.0L, 15);
    assert (brackets.length == 5);

    // The stored function values must match the stored points.
    foreach (br; brackets)
    {
        assert (br.y1 == quintic(br.x1));
        assert (br.y2 == quintic(br.x2));
    }

    // One bracket per root, in increasing order.
    assert (brackets[0].x1 == -2);
    assert (brackets[1].contains(-1));
    assert (brackets[2].contains(0));
    assert (brackets[3].contains(1));
    assert (brackets[4].x1 ==  2);
}
603 
604 
605 
606 
/** Use bisection to find the point where the given predicate goes from
    returning false to returning true.

    Params:
        f               =   The function.
        predicate       =   The predicate, which must take a point and
                            the function value at that point and return
                            a boolean.
        xTrue           =   A point where the predicate is true.
        xFalse          =   A point where the predicate is false.
        fTrue           =   (optional) The value of f at xTrue.
        fFalse          =   (optional) The value of f at xFalse.
        xTolerance      =   Success: When the absolute distance between
                            xTrue and xFalse is at most this number,
                            the function returns.
        maxIterations   =   Failure: When the algorithm has failed to
                            produce a result after maxIterations bisections,
                            an exception is thrown.

    Returns:
    A tuple containing values named xTrue, xFalse, fTrue, and fFalse, which
    satisfy
    ---
    f(xTrue) == fTrue
    f(xFalse) == fFalse
    predicate(xTrue, fTrue) == true
    predicate(xFalse, fFalse) == false
    abs(xTrue-xFalse) <= xTolerance
    ---

    Example:
    ---
    // Find a root by bisection
    auto r = bisect(
        (real x) { return x^^3; },
        (real x, real fx) { return fx < 0; },
        -1.0L, 1.5L, 1e-10L
        );

    // Let's check if we got the right answer.
    enum root = 0.0L;
    assert (abs(r.xTrue - root) <= 1e-10);
    assert (abs(r.xFalse - root) <= 1e-10);

    assert (r.fTrue < 0);
    assert (r.fFalse >= 0);
    assert (abs(r.xTrue - r.xFalse) <= 1e-10);
    ---
*/
Tuple!(T, "xTrue", T, "xFalse", R, "fTrue", R, "fFalse")
bisect(F, T, R = ReturnType!F)
    (scope F f, bool delegate(T, R) predicate, T xTrue, T xFalse,
     T xTolerance, int maxIterations=40)
    if (isFloatingPoint!T && isFloatingPoint!R)
{
    // The function values at the initial points were not supplied, so
    // evaluate them here and pass everything on to the full version.
    return bisect(f, predicate, xTrue, xFalse, f(xTrue), f(xFalse),
        xTolerance, maxIterations);
}
666 
667 
/// ditto
Tuple!(T, "xTrue", T, "xFalse", R, "fTrue", R, "fFalse")
bisect(F, T, R = ReturnType!F)
    (scope F f, bool delegate(T, R) predicate, T xTrue, T xFalse,
     R fTrue, R fFalse, T xTolerance, int maxIterations=40)
    if (isFloatingPoint!T && isFloatingPoint!R)
in
{
    assert (predicate(xTrue, fTrue) == true, "Predicate is false at xTrue");
    assert (predicate(xFalse, fFalse) == false, "Predicate is true at xFalse");
    assert (xTolerance > 0, "xTolerance must be positive");
}
body
{
    // Number of bisections performed so far.
    int bisections = 0;

    for (;;)
    {
        // Success: the two points are close enough.  This is tested
        // before the iteration limit, so that a result which meets the
        // tolerance after exactly maxIterations bisections (or without
        // any bisection at all) is returned rather than discarded.
        if (fabs(xTrue-xFalse) <= xTolerance)
            return typeof(return)(xTrue, xFalse, fTrue, fFalse);

        if (bisections >= maxIterations) break;

        // Evaluate the midpoint, and let it replace whichever of the
        // two points it agrees with under the predicate.
        immutable xMid = (xTrue + xFalse) / 2;
        immutable fMid = f(xMid);
        if (predicate(xMid, fMid))
        {
            xTrue = xMid;
            fTrue = fMid;
        }
        else
        {
            xFalse = xMid;
            fFalse = fMid;
        }
        ++bisections;
    }

    throw new Exception("The maximum number of iterations was reached");
}
703 
704 
unittest
{
    // Bisect towards the root of x^3, using the sign of the function
    // value as the predicate.
    auto result = bisect(
        (real x) { return x^^3; },
        (real x, real fx) { return fx < 0; },
        -1.0L, 1.5L, 1e-10L
        );

    // Both points must be within the tolerance of the true root ...
    enum root = 0.0L;
    assert (abs(result.xTrue - root) <= 1e-10);
    assert (abs(result.xFalse - root) <= 1e-10);

    // ... the returned function values must belong to the returned
    // points ...
    assert (result.fTrue == result.xTrue^^3);
    assert (result.fFalse == result.xFalse^^3);

    // ... and the points must lie on the proper sides of the root, at
    // most the tolerance apart.
    assert (result.fTrue < 0);
    assert (result.fFalse >= 0);
    assert (abs(result.xTrue - result.xFalse) <= 1e-10);
}