/** Various useful utilities that don't really fit anywhere else.

    Authors:    Lars Tandle Kyllingstad
    Copyright:  Copyright (c) 2009-2010, Lars T. Kyllingstad. All rights reserved.
    License:    Boost License 1.0
*/
module scid.util;


import std.algorithm;
import std.complex;
import std.math;
import std.range;
import std.traits;
import std.typetuple: allSatisfy;

import scid.core.traits;




/** Check whether lhs and rhs are equal to within the specified number
    of significant digits.  Both lhs and rhs can be floating-point
    numbers, complex numbers, or input ranges of floating-point or
    complex numbers.
    ---
    assert (matchDigits(0.1234567, 0.1234568));
    ---

    Note that numbers which are very close to zero should normally be
    compared using an absolute difference criterion.  This can be
    specified using an optional parameter, and is by default set to 1e-20.
    Set it to zero if you do not want to use the absolute difference at all.

    Two real numbers match if they agree in more than
    significantDigits*log2(10) binary digits (as measured by
    std.math.feqrel), or if their absolute difference is at most
    maxAbsDiff.  Complex numbers are compared part by part, and a real
    number compared with a complex one is treated as having zero
    imaginary part.  Two ranges match if they have the same length and
    all corresponding elements match; a range and a number match if
    every element of the range matches the number.  Ranges are iterated,
    so ranges with reference semantics will be consumed.

    Params:
        lhs = the first value (number, complex number or input range)
        rhs = the second value (number, complex number or input range)
        significantDigits = the number of significant decimal digits
            that must agree
        maxAbsDiff = two numbers whose absolute difference is at most
            this value always match; must be nonnegative

    Returns:
        true if the values match, false otherwise.
*/
bool matchDigits(L, R)
    (L lhs, R rhs, uint significantDigits = 6, real maxAbsDiff = 1e-20)
    if ((isFloatingPoint!L || is(L T == Complex!T)) &&
        (isFloatingPoint!R || is(R U == Complex!U)))
in
{
    assert (maxAbsDiff >= 0);
    static if (isFloatingPoint!L && isFloatingPoint!R)
    {
        // LOG2T is log2(10), so the left-hand side is the number of
        // mantissa bits needed to hold significantDigits decimal digits.
        assert (significantDigits*LOG2T < CommonType!(L, R).mant_dig,
            "The requested precision is too high for the given type(s)");
    }
}
body
{
    static if (is(L T == Complex!T))
    {
        static if (is(R U == Complex!U))
        {
            // lhs and rhs are complex
            return matchDigits(lhs.re, rhs.re, significantDigits, maxAbsDiff)
                && matchDigits(lhs.im, rhs.im, significantDigits, maxAbsDiff);
        }
        else
        {
            // lhs is complex, rhs is real: treat rhs as rhs + 0i
            return matchDigits(lhs.re, rhs, significantDigits, maxAbsDiff)
                && matchDigits(lhs.im, 0.0, significantDigits, maxAbsDiff);
        }
    }
    else
    {
        static if (is(R U == Complex!U))
        {
            // lhs is real, rhs is complex: treat lhs as lhs + 0i
            return matchDigits(lhs, rhs.re, significantDigits, maxAbsDiff)
                && matchDigits(0.0, rhs.im, significantDigits, maxAbsDiff);
        }
        else
        {
            // lhs and rhs are real.  feqrel() counts the equal mantissa
            // bits; the absolute criterion catches values near zero,
            // where relative precision is not meaningful.
            return
                feqrel!(CommonType!(L, R))(lhs, rhs) > significantDigits*LOG2T
                || abs(lhs - rhs) <= maxAbsDiff;
        }
    }
}


unittest
{
    assert (matchDigits(1.0, 1.0, 15));
    assert (matchDigits(0.1234567, 0.1234568, 6));
    assert (!matchDigits(0.1234567, 0.1234568, 7));

    auto z = Complex!real(0.1234567, 0.0);
    assert (matchDigits(z, 0.1234568, 6));
    assert (matchDigits(0.1234568, z, 6));

    auto u = Complex!real(0.1234567, 0.1234567);
    auto v = Complex!real(0.1234568, 0.1234567);
    auto w = Complex!real(0.1234567, 0.1234568);
    assert (matchDigits(u, v, 6));
    assert (matchDigits(u, w, 6));
    assert (matchDigits(v, w, 6));
}




/// ditto
bool matchDigits(L, R)
    (L lhs, R rhs, uint significantDigits = 6, real maxAbsDiff = 1e-20)
    if (isInputRange!L || isInputRange!R)
{
    static if (isInputRange!L)
    {
        static if (isInputRange!R)
        {
            // Both lhs and rhs are ranges: compare element-wise, and
            // require that both run out at the same time.
            for (;; lhs.popFront(), rhs.popFront())
            {
                if (lhs.empty) return rhs.empty;
                if (rhs.empty) return false;    // lhs is longer than rhs
                if (!matchDigits(lhs.front, rhs.front, significantDigits, maxAbsDiff))
                    return false;
            }
        }
        else
        {
            // lhs is a range, rhs is a number: every element must match.
            for (; !lhs.empty; lhs.popFront())
            {
                if (!matchDigits(lhs.front, rhs, significantDigits, maxAbsDiff))
                    return false;
            }
            return true;
        }
    }
    else
    {
        // lhs is a number, rhs is a range.  The comparison is symmetric,
        // so reuse the range-vs-number case.
        return matchDigits(rhs, lhs, significantDigits, maxAbsDiff);
    }
}


unittest
{
    assert (matchDigits([0.1234566, 0.1234567, 0.1234568], 0.1234567, 6));
    assert (matchDigits(0.1234567, [0.1234566, 0.1234567, 0.1234568], 6));
    assert (matchDigits([0.1234566, 0.1234567, 0.1234568],
        [0.1234567, 0.1234567, 0.1234567], 6));

    // Ranges of different lengths never match.
    assert (!matchDigits([1.0, 2.0], [1.0], 6));
    assert (!matchDigits([1.0], [1.0, 2.0], 6));
}




/** Replaces real numbers that are close to zero by exactly zero.
    A number is replaced if its absolute value is strictly less than
    threshold; otherwise (including for NaN) it is returned unchanged.
    ---
    assert (chop(1.0) == 1.0);
    assert (chop(1e-20) == 0.0);
    ---
*/
Real chop(Real)(Real x, real threshold = 1e-10L) pure nothrow
    if (isFloatingPoint!(Real))
{
    if (fabs(x) < threshold) return 0.0;
    return x;
}

unittest
{
    assert (chop(1.0) == 1.0);
    assert (chop(1e-20) == 0.0);
}


/** Replaces all numbers in the given array that are close to zero
    by exactly zero.  To chop the array in-place, pass the same array
    as both x and buffer.
    ---
    double[] a = [1.0, 1e-20, 2.0];
    double[] x = [1.0, 0.0, 2.0];
    auto b = chop(a);
    assert (b == x);
    chop(a, 1e-12L, a);
    assert (a == x);
    ---

    Params:
        x = the input array
        threshold = elements whose absolute value is strictly less than
            this are replaced by zero
        buffer = storage for the result.  If it is shorter than x, it
            is lengthened (which may reallocate it).

    Returns:
        The first x.length elements of buffer, holding the chopped
        values.  Any surplus elements of a longer buffer are neither
        modified nor included in the result.
*/
Real[] chop(Real) (Real[] x, real threshold = 1e-10L, Real[] buffer=null)
    nothrow
    if (isFloatingPoint!Real)
{
    if (buffer.length < x.length) buffer.length = x.length;

    // Only the first x.length elements are written, so only those may
    // be returned; a longer buffer would otherwise leak stale values.
    auto result = buffer[0 .. x.length];

    foreach (i; 0 .. x.length)
    {
        if (fabs(x[i]) < threshold) result[i] = 0.0;
        else result[i] = x[i];
    }

    return result;
}

unittest
{
    real[] a = [1.0L, real.min_normal, 2.0];
    real[] x = [1.0L, 0.0, 2.0];

    auto b = chop(a);
    assert (b == x);

    chop(a, 1e-12L, a);
    assert (a == x);

    // A buffer longer than the input is reused, but only the part
    // corresponding to the input is returned.
    auto buf = new double[5];
    auto c = chop([1.0, 1e-20], 1e-10L, buf);
    assert (c == [1.0, 0.0]);
    assert (c.ptr == buf.ptr);
}




/** Create a static array literal without any heap allocation.

    staticArray() automatically deduces its type from the arguments,
    while staticArrayOf() lets you specify the type explicitly.

    staticArray() uses the common type of all its arguments as the
    element type, and only accepts arguments that have a common type.
    staticArrayOf!T() accepts arguments that pass
    scid.core.traits.allConvertibleTo!(T, ...).
*/
CommonType!(T)[T.length] staticArray(T...)(T elements)
    @safe pure nothrow
    if (!is(CommonType!T == void))
{
    // Inspired by code posted by David Simcha on the Phobos
    // developers' mailing list.
    // Default initialisation is skipped because every element is
    // assigned below.
    typeof(return) a = void;
    foreach (i, e; elements) a[i] = e;
    return a;
}


/// ditto
T[U.length] staticArrayOf(T, U...)(U elements)
    @safe pure nothrow
    if (allConvertibleTo!(T, U))
{
    typeof(return) a = void;
    foreach (i, e; elements) a[i] = e;
    return a;
}


unittest
{
    auto a = staticArray(0.0, 1.0, 2.0);
    auto b = staticArrayOf!double(0.0F, 1.0, 2.0L);
    double[3] c = [0.0, 1.0, 2.0];
    assert (a == c);
    assert (b == c);
}




/** Returns a range that iterates over n equally-spaced floating-point
    numbers in the inclusive interval [a,b].

    This is similar to std.range.iota, except that it allows
    you to specify the number of points rather than the step size,
    and that the last point is exactly equal to b (unless n = 1, in
    which case a is the first and last point).  The first point is
    always exactly equal to a.  This makes it more useful than iota
    for iterating over floating-point numbers.  If n is zero or
    negative, the range is empty.

    The returned range is a forward range with a length.

    Example:
    ---
    int i = 0;
    foreach (x; steps(0.0, 9.0, 10))
    {
        assert (x == i);
        ++i;
    }
    ---
*/
auto steps(T, U)(T a, U b, int n)
    if (isFloatingPoint!T && isFloatingPoint!U)
{
    alias CommonType!(T, U) V;

    // static: the range needs no access to the enclosing frame, so it
    // should not carry a hidden context pointer.
    static struct Steps
    {
    private:
        // _i counts down from n-1 to 0; it is the number of steps
        // remaining between the current point and _stop.
        int _i, _resolution;
        V _delta, _start, _stop;

    public:
        @property bool empty() { return _i < 0; }

        @property V front()
        {
            assert (!empty);
            // The first point is returned exactly; all others are
            // measured back from _stop, so the last one is exactly _stop.
            if (_i == _resolution - 1) return _start;
            return _stop - _i*_delta;
        }

        void popFront()
        {
            assert (!empty);
            --_i;
        }

        @property Steps save() { return this; }

        @property size_t length()
        {
            return empty ? 0 : cast(size_t) (_i + 1);
        }
    }

    Steps s;
    s._start = a;
    s._stop = b;
    if (n > 1) s._delta = (b - a) / (n - 1);
    s._i = n - 1;
    s._resolution = n;
    return s;
}


unittest
{
    auto s1 = steps(-5.0, 5.0, 11);
    static assert (isForwardRange!(typeof(s1)));
    assert (s1.length == 11);
    assert (equal(s1, [-5.0, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]));

    auto s2 = steps(0.0, 1.0, 1);
    assert (s2.length == 1);
    assert (equal(s2, [0.0]));

    auto s3 = steps(0.0, 1.0, 0);
    assert (s3.empty && s3.length == 0);

    auto s4 = steps(0.0, 1.0, -3);
    assert (s4.empty && s4.length == 0);

    // From the example
    int i = 0;
    foreach (x; steps(0.0, 9.0, 10))
    {
        assert (x == i);
        ++i;
    }
}




/** Trivially convert const(T[]) to const(T)[]. */
const(T)[] tailConst(T)(const T[] a) pure nothrow
{
    return a;
}




/** Limit the number of times a function can be called.

    Given a delegate dg and an integer maxCalls, this function
    returns a functor with the same parameter and return types as dg,
    and that forwards up to maxCalls calls to dg.  On the (maxCalls+1)-th
    call, and on every call after that, it throws an exception.  This is
    useful for algorithms such as scid.nonlinear.findRoot() which may
    require an arbitrary number of function calls to complete.
    ---
    void f(int i) { ... }
    auto g = limitCalls(&f, 2);

    g(0);   // succeeds
    g(0);   // succeeds
    g(0);   // throws
    ---

    In the example above, when we're taking the address of f, its
    context may be copied to the heap.  This is a potentially expensive
    operation.  To avoid it, use the unsafe function scopeLimitCalls()
    instead.  This should only be used if you are absolutely sure that
    the returned functor (and hence the delegate it holds) never
    escapes the current scope.
    ---
    alias LimitCalls!(void, int) FunctorType;

    // Here, the functor escapes by being assigned to a variable
    // outside the function scope.
    FunctorType someGlobal;
    void badIdea1()
    {
        void f(int i) { ... }
        someGlobal = scopeLimitCalls(&f, 10);
    }

    // Here it escapes because we return it.
    FunctorType badIdea2()
    {
        void f(int i) { ... }
        return scopeLimitCalls(&f, 10);
    }
    ---
*/
LimitCalls!(R, T) limitCalls(R, T...)
    (R delegate(T) dg, uint maxCalls)
{
    typeof(return) functor = void;
    functor.dg = dg;
    functor.maxCalls = maxCalls;
    functor.calls = 0;
    return functor;
}

/// ditto
LimitCalls!(R, T) scopeLimitCalls(R, T...)
    (scope R delegate(T) dg, uint maxCalls)
{
    typeof(return) functor = void;
    functor.dg = dg;
    functor.maxCalls = maxCalls;
    functor.calls = 0;
    return functor;
}

/** The functor type returned by limitCalls() and scopeLimitCalls(). */
struct LimitCalls(R, T...)
{
private:
    R delegate(T) dg;
    uint maxCalls;
    uint calls;     // number of calls forwarded to dg so far

public:
    /** Forwards params to the wrapped delegate, or throws an Exception
        if maxCalls calls have already been forwarded.  A call is counted
        before dg runs, so it still counts if dg itself throws.
    */
    R opCall(T params)
    {
        // Check before incrementing, so that calls never exceeds
        // maxCalls and the counter cannot wrap around after many
        // rejected calls.
        if (calls >= maxCalls)
            throw new Exception("Function called too many times");
        ++calls;
        return dg(params);
    }
}


unittest
{
    int sum;
    void add(int i) { sum += i; }

    auto f = limitCalls(&add, 2);
    f(3);
    f(1);
    try { f(10); assert (false); } catch (Exception e) { assert (true); }
    assert (sum == 4);

    auto g = scopeLimitCalls(&add, 2);
    g(2);
    g(5);
    try { g(10); assert (false); } catch (Exception e) { assert (true); }
    assert (sum == 11);

    // Once the limit is reached, every further call throws.
    auto h = limitCalls(&add, 0);
    foreach (k; 0 .. 3)
    {
        bool threw = false;
        try { h(100); } catch (Exception e) { threw = true; }
        assert (threw);
    }
    assert (sum == 11);
}