/** Functions related to the solving of nonlinear equations, i.e. finding
    roots of nonlinear functions.

    Authors:    Lars Tandle Kyllingstad
    Copyright:  Copyright (c) 2009-2010, Lars T. Kyllingstad. All rights reserved.
    License:    Boost License 1.0
*/
module scid.nonlinear;


import std.algorithm;
import std.exception;
import std.functional;
import std.math;
static import std.numeric;
import std.range;
import std.traits;
import std.typecons;

import scid.core.fortran;
import scid.core.memory;
import scid.core.traits;
import scid.ports.minpack.hybrd;
import scid.util;




/** Searches for a root of N functions of N variables using a variant
    of the Powell Hybrid Method (the HYBRD routine from MINPACK).

    Params:
        f            = The system of equations: a function, delegate or
                       functor that takes an array of length N and
                       returns an array of length N.  If it accepts an
                       additional array parameter, that parameter is
                       assumed to be a buffer for the output value.
        guess        = The point from which the algorithm starts.  The
                       closer it is to the true root, the greater the
                       chance that the algorithm converges.
        epsRel       = Success criterion:  the algorithm stops when
                       the relative error between two consecutive
                       iterations is at most epsRel.
        maxFuncEvals = (optional) The maximum number of function
                       evaluations.  If maxFuncEvals<1, it is set to
                       200*(N+1).
        buffer       = (optional) A buffer of length at least N, for the
                       return value.

    Returns:
        The located root, as a slice of length N.

    Throws:
        Exception if hybrd() reports anything other than success.

    Example:
    The Rosenbrock function is a commonly used test problem for
    optimisation algorithms.  It has a global minimum at (1,1) that is hard
    to locate numerically because it lies in a long, narrow valley.
    Instead of using an optimisation algorithm, let us try to locate
    the minimum by finding the root of the Rosenbrock function's
    gradient.
    ---
    // The Rosenbrock function is defined as
    //    f(x,y) = (1-x)^2 + 100 (y-x^2)^2.
    // Thus, its gradient is:
    real[] dRosenbrock(real[] v, real[] buf)
    {
        auto x = v[0], y = v[1];
        buf[0] = -2*(1-x) - 400*x*(y-x*x);
        buf[1] = 200 * (y-x*x);

        return buf;
    }

    real[] guess = [ 2.0, 2.0 ];
    auto root = findRoot(&dRosenbrock, guess, 0.0L);

    writeln(root);   // Prints "1 1".  Yay!
    ---
*/
Real[] findRoot (Real, Func)
    (
        scope Func f,
        Real[] guess,
        Real epsRel,
        int maxFuncEvals = 0,
        Real[] buffer=null
    )
    if (isFloatingPoint!Real && isVectorField!(Func, Real))
in
{
    assert (guess.length > 0, "findRoot: empty guess vector given");
}
body
{
    mixin (newFrame);

    // Adapter handed to hybrd():  evaluates the user's equations at
    // x[0 .. m] and makes sure the result ends up in fvec[0 .. m].
    void fcn(size_t m, Real* x, Real* fvec, ref int iflag)
    {
        auto point = x[0 .. m];

        static if (isBufferVectorField!(Func, Real))
        {
            // The function was given fvec as its output buffer; verify
            // that this is really where it put the result.
            auto result = f(point, fvec[0 .. m]);
            assert (result.ptr == fvec);
            assert (result.length == m, "findRoot: The number of "
                ~"equations must be equal to the number of variables");
        }
        else
        {
            fvec[0 .. m] = f(point)[];
        }
    }

    immutable int n = toInt(guess.length);
    if (maxFuncEvals < 1) maxFuncEvals = 200*(n+1);

    // The return buffer doubles as hybrd's in/out vector x, so it starts
    // out as a copy of the guess.
    buffer.length = n;
    buffer[] = guess[];
    Real* x = buffer.ptr;

    // All remaining arrays are carved, back to back, out of one temporary
    // allocation:
    //   fvec[n] diag[n] fjac[n*n] r[n*(n+1)/2] qtf[n] wa1[n] .. wa4[n]
    // which adds up to n*(3*n + 15)/2 elements.
    immutable int wslen = (n*(3*n + 15))/2;
    size_t lr = (n*(n+1))/2;

    Real* fvec = cast(Real*) TempAlloc.malloc(wslen*Real.sizeof);
    Real* diag = fvec + n;
    Real* fjac = diag + n;
    Real* r    = fjac + n*n;
    Real* qtf  = r + lr;
    Real* wa1  = qtf + n;
    Real* wa2  = wa1 + n;
    Real* wa3  = wa2 + n;
    Real* wa4  = wa3 + n;

    // Scalar parameters, named as in hybrd().
    size_t ml_mu = n-1;
    Real epsfcn = 0.0;
    diag[0 .. n] = 1.0;
    int mode = 2;
    enum Real factor = 100.0;
    int nprint = 0;
    int info = 0;
    uint nfev;

    hybrd!(Real, typeof(&fcn))(&fcn, n, x, fvec, epsRel, maxFuncEvals,
        ml_mu, ml_mu, epsfcn, diag, mode, factor, nprint, info, nfev,
        fjac, n, r, lr, qtf, wa1, wa2, wa3, wa4);

    // Translate hybrd's status code into a result or an exception.
    string failure;
    switch (info)
    {
        case 1:     // Success!
            return x[0 .. n];
        case 0:
            failure = "Invalid input parameters";
            break;
        case 2:
            failure =
                "The function has been called the maximum number of times";
            break;
        case 3:
            failure = "Cannot reach the requested accuracy";
            break;
        case 4, 5:
            failure = "Algorithm failed to converge";
            break;
        default:
            assert(0);
    }
    throw new Exception(failure);
}

unittest
{
    // Gradient of the Rosenbrock function, buffered variant.
    real[] dRosenbrockB(real[] v, real[] fx)
    {
        assert (v.length == 2 && fx.length == 2);
        auto x = v[0], y = v[1];
        fx[0] = -2*(1-x) - 400*x*(y-x*x);
        fx[1] = 200 * (y - x*x);
        return fx;
    }

    // Same gradient, but allocating its own result.
    real[] dRosenbrock(real[] v)
    {
        return dRosenbrockB(v, new real[2]);
    }

    real[] guess = [ 2.0, 2.0 ];

    // Test bufferless function
    auto root = findRoot(&dRosenbrock, guess, 0.0L);
    assert (approxEqual(root, [1.0L, 1.0L].dup, 1e-6));

    // Test buffered function
    auto rootB = findRoot(&dRosenbrockB, guess, 0.0L);
    assert (approxEqual(rootB, [1.0L, 1.0L].dup, 1e-6));
}




/** Find a root of the function f.

    This function first calls $(LINK2 #bracketRoot,bracketRoot) to
    obtain an interval inside which there must be a root, and then calls
    $(LINK2 http://www.digitalmars.com/d/2.0/phobos/std_numeric.html#findRoot,std.numeric.findRoot())
    to pin down the location of the root.
197 198 The parameters x0, scale, xMin, and xMax are just passed on to 199 $(LINK2 #bracketRoot,bracketRoot), and they are described in detail 200 in its documentation. In brief, x0 should be an estimate of the 201 root's location, while scale should be a characteristic scale for 202 the function, i.e. a distance over which the function changes 203 significantly. [xMin,xMax] is the interval inside which the algorithm 204 is allowed to search. 205 206 You may specify the desired (minimum) number of digits of precision 207 in the answer. If this is left out, the algorithm will attempt 208 to achieve full machine precision. 209 */ 210 T findRoot(F, T)(scope F f, T x0, T scale, T xMin, T xMax, int precision) 211 if (isUnaryFunction!(F, T) && isFloatingPoint!T) 212 { 213 return findRootImpl(f, x0, scale, xMin, xMax, 214 (T a, T b) { return matchDigits(a, b, precision); }); 215 } 216 217 /// ditto 218 T findRoot(F, T)(scope F f, T x0, T scale, int precision) 219 if (isUnaryFunction!(F, T) && isFloatingPoint!T) 220 { 221 return findRoot(f, x0, scale, -T.infinity, T.infinity, precision); 222 } 223 224 /// ditto 225 T findRoot(F, T) 226 (scope F f, T x0, T scale, T xMin = -T.infinity, T xMax = T.infinity) 227 if (isUnaryFunction!(F, T) && isFloatingPoint!T) 228 { 229 return findRootImpl(f, x0, scale, xMin, xMax, 230 (T a, T b) { return false; }); 231 } 232 233 234 // Implementation of findRoot() 235 private T findRootImpl(F, T) 236 ( 237 scope F f, 238 T x0, T scale, 239 T xMin, T xMax, 240 scope bool delegate(T, T) tolerance 241 ) 242 if (isUnaryFunction!(F, T) && isFloatingPoint!T) 243 { 244 auto bracket = bracketRoot(f, x0, scale, xMin, xMax); 245 if (bracket.y1 == 0) return bracket.x1; 246 if (bracket.y2 == 0) return bracket.x2; 247 248 // std.numeric.findRoot() only takes a delegate 249 static if (is (F == delegate)) 250 auto dg = f; 251 else static if (isFunctionPointer!F) 252 auto dg = toDelegate(f); 253 else static if (isFunctor!F) 254 scope ReturnType!F 
delegate(ParameterTypeTuple!F) dg = &f.opCall; 255 256 return std.numeric.findRoot(dg, 257 bracket.x1, bracket.x2, 258 bracket.y1, bracket.y2, 259 tolerance 260 )[0]; 261 } 262 263 264 unittest 265 { 266 real f(real x) { return log(x); } 267 268 immutable inaccurateRoot = 269 findRoot(&f, 0.5L, 1.0L, real.epsilon, real.infinity, 2); 270 assert (matchDigits(inaccurateRoot, 1.0, 2)); 271 assert (!matchDigits(inaccurateRoot, 1.0, 10)); 272 273 immutable accurateRoot = 274 findRoot(&f, 0.5L, 1.0L, real.epsilon, real.infinity); 275 assert (accurateRoot == 1.0); 276 } 277 278 unittest 279 { 280 // Function 281 static real f(real x) { return x; } 282 assert (findRoot(&f, 1.0L, 1.0L) == 0.0L); 283 } 284 285 unittest 286 { 287 // Functor 288 struct Functor { real opCall(real x) { return x^^3; } } 289 Functor g; 290 assert (findRoot(g, 1.0L, 1.0L) == 0.0L); 291 } 292 293 294 295 296 /** Bracket a root of the function f. 297 298 If a function f(x) is continuous on an interval [x1,x2], 299 and f(x1) and f(x2) have opposite sign, we know the function 300 must pass through zero somewhere in the interval. 301 The points x1 and x2 are then said to 'bracket' the 302 root. This is usually the first step in locating the root 303 of a function. 304 305 If this function succeeds, it returns a RootBracket containing 306 the points x1 and x2, together with the function values f(x1) 307 and f(x2). If it fails, an exception is thrown. 308 Note that this library considers the points to be bracketing 309 a root also if the root is located exactly at x1 and/or x2, 310 i.e. if f(x1)=0 and/or f(x2)=0. 311 312 Details: 313 314 This function will start by evaluating f(x) in the points 315 x0 and x0+scale and see if those 316 points bracket a root of the given function. If not, the interval 317 is expanded geometrically (i.e. the distance between the points 318 is multiplied by a constant factor), always in the direction where 319 f(x) is smallest, until the points bracket a root. 
320 321 You may optionally specify a limiting interval [xMin, xMax], and the 322 algorithm will never search outside it. This is useful, 323 for instance, for functions that are only defined for certain 324 values of x. If you do specify such an interval, the 325 initial point x0 must lie inside it. 326 327 It is usually worthwhile analysing the behaviour of the function 328 in order to find appropriate values for x0 and scale. 329 The closer x0 is to the actual root, the fewer steps (i.e. the 330 fewer evaluations of f) this algorithm will require to succeed. 331 If scale is too large, and the function has several roots, 332 there is a chance that it will just step across both roots 333 and not find any of them. On the other hand, if it is too small, 334 it may again cause the algorithm to take more steps than would 335 otherwise be necessary. 336 */ 337 RootBracket!(T, ReturnType!F) bracketRoot(F, T) 338 ( 339 scope F f, 340 in T x0, in T scale, 341 in T xMin = -T.infinity, in T xMax = T.infinity, 342 ) 343 if (isFloatingPoint!T && isUnaryFunction!(F, T)) 344 in 345 { 346 assert (scale != 0, "scale must be nonzero"); 347 assert (xMin < xMax, "xMin must be smaller than xMax"); 348 assert (xMin <= x0 && x0 <= xMax, "x0 must be in the interval [xMin,xMax]"); 349 } 350 body 351 { 352 alias typeof(return) B; 353 enum expandFactor = 1.6; 354 355 356 // Function that searches upwards from xMin 357 B upwards(real x, real fx, real dx) 358 { 359 immutable fxMin = f(xMin); 360 for (;;) 361 { 362 if (fxMin * fx <= 0) return B(xMin, x, fxMin, fx); 363 enforce(x != xMax, "Unable to bracket root"); 364 365 dx *= expandFactor; 366 x = min(x + dx, xMax); 367 fx = f(x); 368 } 369 assert(0); 370 } 371 372 373 // Function that searches downwards from xMax 374 B downwards(real x, real fx, real dx) 375 { 376 immutable fxMax = f(xMax); 377 for (;;) 378 { 379 if (fxMax * fx <= 0) return B(x, xMax, fx, fxMax); 380 enforce(x != xMin, "Unable to bracket root"); 381 382 dx *= expandFactor; 
383 x = max(x - dx, xMin); 384 fx = f(x); 385 } 386 assert(0); 387 } 388 389 390 // These are the initial points 391 real x1 = x0; 392 real x2 = x0 + scale; 393 394 // If x0 is either endpoint of the allowed interval, or if x2 395 // falls outside the interval, search only in one direction. 396 if (x1 == xMin || x2 <= xMin) 397 { 398 immutable x = min(x1 + abs(scale), xMax); 399 return upwards(x, f(x), abs(scale)); 400 } 401 if (x1 == xMax || x2 >= xMax) 402 { 403 immutable x = max(x1 - abs(scale), xMin); 404 return downwards(x, f(x), abs(scale)); 405 } 406 407 408 // Both x1 and x2 fall inside [xMin, xMax], so we use bidirectional search 409 if (x1 > x2) swap(x1, x2); 410 real fx1 = f(x1); 411 real fx2 = f(x2); 412 413 for (;;) 414 { 415 // Check whether interval brackets a root 416 if (fx1 * fx2 <= 0) return B(x1, x2, fx1, fx2); 417 418 // Expand interval in the direction where f(x) is closest to zero. 419 if (fabs(fx1) < fabs(fx2)) 420 { 421 x1 += expandFactor * (x1 - x2); 422 if (x1 <= xMin) return upwards(x2, fx2, x2-x1); 423 fx1 = f(x1); 424 } 425 else 426 { 427 x2 += expandFactor * (x2 - x1); 428 if (x2 >= xMax) return downwards(x1, fx1, x2-x1); 429 fx2 = f(x2); 430 } 431 } 432 assert(0); 433 } 434 435 436 unittest 437 { 438 real f(real x) { return 1 - x; } 439 auto b = bracketRoot(&f, -100.0L, 1.0L); 440 assert (b.contains(1)); 441 } 442 443 unittest 444 { 445 real f(real x) { return log(x); } 446 auto b = bracketRoot(&f, 2*real.epsilon, 0.1L, real.epsilon, real.infinity); 447 assert (b.contains(1)); 448 } 449 450 451 452 453 /** A set of points that bracket a root of some function. */ 454 struct RootBracket(X, Y) 455 { 456 /// Two points that bracket a root. 
457 X x1; 458 X x2; /// ditto 459 460 /// The function value at x1 and x2, respectively 461 Y y1; 462 Y y2; /// ditto 463 464 465 version(unittest) private bool contains(X point) 466 { 467 if (x1 <= x2) return (point >= x1 && point <= x2); 468 else return (point >= x2 && point <= x1); 469 } 470 } 471 472 473 474 475 /** Uses bracketRoots() to divide the interval [a,b] into subintervals 476 and check which ones bracket roots. Then, findRoot() is applied 477 to each bracketing interval, and an array containing the roots 478 is returned. 479 480 A buffer of length at least nIntervals+1, for storing the roots, 481 may optionally be provided. 482 */ 483 T[] findRoots(T, Func)(scope Func f, T a, T b, uint nIntervals, 484 T[] buffer=null) 485 { 486 mixin(scid.core.memory.newFrame); 487 488 // Find bracketing subintervals. 489 auto bracketBuffer = 490 newStack!(RootBracket!(T, ReturnType!Func))(nIntervals+1); 491 auto intervals = 492 bracketRoots!(T,Func)(f, a, b, nIntervals, bracketBuffer); 493 494 // Find all the bracketed roots. 495 buffer.length = intervals.length; 496 foreach (i, iv; intervals) 497 { 498 // Check if a root is located at the "lower" endpoint. 499 if (iv.y1 == 0) buffer[i] = iv.x1; 500 501 // If not, call findRoot() to locate the root. 502 // Note that if it is located at the "higher" end point it will 503 // be caught in the next iteration. 
504 else if (iv.y2 != 0) buffer[i] = 505 std.numeric.findRoot(f, iv.x1, iv.x2, iv.y1, iv.y2, 506 (T a, T b) { return false; })[0]; 507 } 508 509 return buffer; 510 } 511 512 513 unittest 514 { 515 real f(real x) 516 { 517 return (2+x) * (1+x) * x * (1-x) * (2-x); 518 } 519 520 auto r = findRoots(&f, -2.0L, 2.0L, 15); 521 assert (r.length == 5); 522 assert (approxEqual(r, [-2.0, -1.0, 0.0, 1.0, 2.0], real.epsilon)); 523 } 524 525 526 527 528 /** Divides the interval [a,b] into the given number of equal-sized 529 subintervals, 530 checks whether any of the subintervals bracket a root, and returns 531 the ones that do, together with the function values at those points. 532 533 A buffer of length at least nIntervals+1, for storing the brackets, may 534 optionally be provided. If not, one will be allocated. 535 */ 536 RootBracket!(T, ReturnType!Func)[] bracketRoots(T, Func) 537 (scope Func f, T a, T b, uint nIntervals, 538 RootBracket!(T, ReturnType!Func)[] buffer = null) 539 { 540 static assert (is (typeof(buffer) == typeof(return))); 541 alias ElementType!(typeof(return)) B; 542 543 buffer.length = nIntervals+1; 544 int numBrackets = 0; 545 546 auto lo = a; 547 auto flo = f(lo); 548 immutable step = (b - a)/nIntervals; 549 550 foreach (i; 0 .. nIntervals) 551 { 552 immutable hi = (i < nIntervals - 1 ? lo + step : b); 553 immutable fhi = f(hi); 554 555 if (flo == 0 || flo * fhi < 0) 556 { 557 B br; 558 br.x1 = lo; br.y1 = flo; 559 br.x2 = hi; br.y2 = fhi; 560 buffer[numBrackets] = br; 561 ++numBrackets; 562 } 563 564 lo = hi; 565 flo = fhi; 566 } 567 568 // Check for a root in the endpoint as well. 
569 if (flo == 0) 570 { 571 B br; 572 br.x1 = lo; br.y1 = flo; 573 br.x2 = lo; br.y2 = flo; 574 buffer[numBrackets] = br; 575 ++numBrackets; 576 } 577 578 buffer.length = numBrackets; 579 return buffer; 580 } 581 582 583 unittest 584 { 585 real f(real x) 586 { 587 return (2+x) * (1+x) * x * (1-x) * (2-x); 588 } 589 590 auto b = bracketRoots(&f, -2.0L, 2.0L, 15); 591 assert (b.length == 5); 592 foreach (i; b) 593 { 594 assert (i.y1 == f(i.x1)); 595 assert (i.y2 == f(i.x2)); 596 } 597 assert (b[0].x1 == -2); 598 assert (b[1].contains(-1)); 599 assert (b[2].contains(0)); 600 assert (b[3].contains(1)); 601 assert (b[4].x1 == 2); 602 } 603 604 605 606 607 /** Use bisection to find the point where the given predicate goes from 608 returning false to returning true. 609 610 Params: 611 f = The function. 612 predicate = The predicate, which must take a point and 613 the function value at that point and return 614 a boolean. 615 xTrue = A point where the predicate is true. 616 xFalse = A point where the predicate is false. 617 fTrue = (optional) The value of f at xTrue. 618 fFalse = (optional) The value of f at xFalse. 619 xTolerance = Success: When the absolute distance between 620 xTrue and xFalse is less than this number, 621 the function returns. 622 maxIterations = Failure: When the algorithm has failed to 623 produce a result after maxIterations bisections, 624 an exception is thrown. 625 626 Returns: 627 A tuple containing values named xTrue, xFalse, fTrue, and fFalse, which 628 satisfy 629 --- 630 f(xTrue) == fTrue 631 f(xFalse) == fFalse 632 predicate(xTrue, fTrue) == true 633 predicate(xFalse, fFalse) == false 634 abs(xTrue-xFalse) <= xTolerance 635 --- 636 637 Example: 638 --- 639 // Find a root by bisection 640 auto r = bisect( 641 (real x) { return x^^3; }, 642 (real x, real fx) { return fx < 0; }, 643 -1.0L, 1.5L, 1e-10L 644 ); 645 646 // Let's check if we got the right answer. 
647 enum root = 0.0L; 648 assert (abs(r.xTrue - root) <= 1e-10); 649 assert (abs(r.xFalse - root) <= 1e-10); 650 651 assert (r.fTrue < 0); 652 assert (r.fFalse >= 0); 653 assert (abs(r.xTrue - r.xFalse) <= 1e-10); 654 assert (r.xNaN < 0); 655 assert (abs(r.xValid - r.xNan) <= 1e-6); 656 --- 657 */ 658 Tuple!(T, "xTrue", T, "xFalse", R, "fTrue", R, "fFalse") 659 bisect(F, T, R = ReturnType!F) 660 (scope F f, bool delegate(T, R) predicate, T xTrue, T xFalse, 661 T xTolerance, int maxIterations=40) 662 { 663 return bisect(f, predicate, xTrue, xFalse, f(xTrue), f(xFalse), 664 xTolerance, maxIterations); 665 } 666 667 668 /// ditto 669 Tuple!(T, "xTrue", T, "xFalse", R, "fTrue", R, "fFalse") 670 bisect(F, T, R = ReturnType!F) 671 (scope F f, bool delegate(T, R) predicate, T xTrue, T xFalse, 672 R fTrue, R fFalse, T xTolerance, int maxIterations=40) 673 if (isFloatingPoint!T && isFloatingPoint!R) 674 in 675 { 676 assert (predicate(xTrue, fTrue) == true, "Predicate is false at xTrue"); 677 assert (predicate(xFalse, fFalse) == false, "Predicate is true at xFalse"); 678 assert (xTolerance > 0, "xTolerance must be positive"); 679 } 680 body 681 { 682 foreach (i; 0 .. 
maxIterations) 683 { 684 if (fabs(xTrue-xFalse) <= xTolerance) 685 return typeof(return)(xTrue, xFalse, fTrue, fFalse); 686 687 immutable xMid = (xTrue + xFalse) / 2; 688 immutable fMid = f(xMid); 689 if (predicate(xMid, fMid)) 690 { 691 xTrue = xMid; 692 fTrue = fMid; 693 } 694 else 695 { 696 xFalse = xMid; 697 fFalse = fMid; 698 } 699 } 700 701 throw new Exception("The maximum number of iterations was reached"); 702 } 703 704 705 unittest 706 { 707 auto r = bisect( 708 (real x) { return x^^3; }, 709 (real x, real fx) { return fx < 0; }, 710 -1.0L, 1.5L, 1e-10L 711 ); 712 713 enum root = 0.0L; 714 assert (abs(r.xTrue - root) <= 1e-10); 715 assert (abs(r.xFalse - root) <= 1e-10); 716 717 assert (r.fTrue == r.xTrue^^3); 718 assert (r.fFalse == r.xFalse^^3); 719 assert (r.fTrue < 0); 720 assert (r.fFalse >= 0); 721 assert (abs(r.xTrue - r.xFalse) <= 1e-10); 722 }