1 /** Various useful utilities that don't really fit anywhere else.
2 
3     Authors:    Lars Tandle Kyllingstad
4     Copyright:  Copyright (c) 2009-2010, Lars T. Kyllingstad. All rights reserved.
5     License:    Boost License 1.0
6 */
7 module scid.util;
8 
9 
10 import std.algorithm;
11 import std.complex;
12 import std.math;
13 import std.range;
14 import std.traits;
15 import std.typetuple: allSatisfy;
16 
17 import scid.core.traits;
18 
19 
20 
21 
/** Tests whether lhs and rhs agree to a given number of significant
    digits.

    Each argument may be a floating-point number, a complex number, or
    an input range of floating-point or complex numbers.  Complex
    numbers are compared component by component; when a real number is
    compared with a complex one, the real number is treated as having
    a zero imaginary part.
    ---
    assert (matchDigits(0.1234567, 0.1234568));
    ---

    Two real numbers are considered a match when feqrel() reports more
    than significantDigits*log2(10) equal mantissa bits, or when
    their absolute difference does not exceed maxAbsDiff.  The latter
    criterion exists because numbers very close to zero should normally
    be compared by absolute difference.  It defaults to 1e-20; pass
    zero to rely on the relative criterion alone.

    Params:
        lhs               = first value
        rhs               = second value
        significantDigits = number of decimal digits that must agree
        maxAbsDiff        = largest absolute difference that still
                            counts as a match; must be non-negative
*/
bool matchDigits(L, R)
    (L lhs, R rhs, uint significantDigits = 6, real maxAbsDiff = 1e-20)
    if ((isFloatingPoint!L || is(L T == Complex!T)) &&
        (isFloatingPoint!R || is(R U == Complex!U)))
in
{
    assert (maxAbsDiff >= 0);
    static if (isFloatingPoint!L && isFloatingPoint!R)
    {
        // The requested digits must fit in the mantissa of the type
        // in which the comparison is carried out.
        assert (significantDigits*LOG2T < CommonType!(L, R).mant_dig,
            "The requested precision is too high for the given type(s)");
    }
}
body
{
    // The template constraint guarantees that each argument is either
    // a floating-point number or a Complex, so these four branches
    // cover every combination.
    static if (isFloatingPoint!L && isFloatingPoint!R)
    {
        // Real against real: relative test first, absolute as fallback.
        immutable bitsNeeded = significantDigits*LOG2T;
        if (feqrel!(CommonType!(L, R))(lhs, rhs) > bitsNeeded) return true;
        return abs(lhs - rhs) <= maxAbsDiff;
    }
    else static if (isFloatingPoint!R)
    {
        // Complex against real: the imaginary part must match zero.
        if (!matchDigits(lhs.re, rhs, significantDigits, maxAbsDiff))
            return false;
        return matchDigits(lhs.im, 0.0, significantDigits, maxAbsDiff);
    }
    else static if (isFloatingPoint!L)
    {
        // Real against complex: the imaginary part must match zero.
        if (!matchDigits(lhs, rhs.re, significantDigits, maxAbsDiff))
            return false;
        return matchDigits(0.0, rhs.im, significantDigits, maxAbsDiff);
    }
    else
    {
        // Complex against complex: both components must match.
        if (!matchDigits(lhs.re, rhs.re, significantDigits, maxAbsDiff))
            return false;
        return matchDigits(lhs.im, rhs.im, significantDigits, maxAbsDiff);
    }
}
82 
83 
unittest
{
    // Real against real.
    assert (matchDigits(1.0, 1.0, 15));
    assert (matchDigits(0.1234567, 0.1234568, 6));
    assert (!matchDigits(0.1234567, 0.1234568, 7));

    // Real against a complex number on the real axis, in both
    // argument orders.
    auto onRealAxis = Complex!real(0.1234567, 0.0);
    assert (matchDigits(onRealAxis, 0.1234568, 6));
    assert (matchDigits(0.1234568, onRealAxis, 6));

    // Complex against complex, differing in the seventh digit of the
    // real part, the imaginary part, or both.
    auto reference = Complex!real(0.1234567, 0.1234567);
    auto reShifted = Complex!real(0.1234568, 0.1234567);
    auto imShifted = Complex!real(0.1234567, 0.1234568);
    assert (matchDigits(reference, reShifted, 6));
    assert (matchDigits(reference, imShifted, 6));
    assert (matchDigits(reShifted, imShifted, 6));
}
101 
102 
103 
104 
/// ditto
bool matchDigits(L, R)
    (L lhs, R rhs, uint significantDigits = 6, real maxAbsDiff = 1e-20)
    if (isInputRange!L || isInputRange!R)
{
    static if (!isInputRange!L)
    {
        // Number against range: swap the arguments so that the
        // range-against-number case below does the work.
        return matchDigits(rhs, lhs, significantDigits, maxAbsDiff);
    }
    else static if (!isInputRange!R)
    {
        // Range against number: every element must match the number.
        while (!lhs.empty)
        {
            if (!matchDigits(lhs.front, rhs, significantDigits, maxAbsDiff))
                return false;
            lhs.popFront();
        }
        return true;
    }
    else
    {
        // Range against range: compare element by element for as long
        // as both ranges have elements left.
        while (!lhs.empty && !rhs.empty)
        {
            if (!matchDigits(lhs.front, rhs.front, significantDigits, maxAbsDiff))
                return false;
            lhs.popFront();
            rhs.popFront();
        }

        // The ranges only match if they were exhausted simultaneously.
        return lhs.empty && rhs.empty;
    }
}
140 
141 
unittest
{
    double[] scattered = [0.1234566, 0.1234567, 0.1234568];
    double[] constant = [0.1234567, 0.1234567, 0.1234567];

    // Range against number, in both argument orders.
    assert (matchDigits(scattered, 0.1234567, 6));
    assert (matchDigits(0.1234567, scattered, 6));

    // Range against range.
    assert (matchDigits(scattered, constant, 6));
}
149 
150 
151 
152 
/** Returns exactly zero for real numbers that are close to zero, and
    the number itself otherwise.
    ---
    assert (chop(1.0) == 1.0);
    assert (chop(1e-20) == 0.0);
    ---

    Params:
        x         = the number to chop
        threshold = x is replaced by zero when its absolute value is
                    strictly below this value
*/
Real chop(Real)(Real x, real threshold = 1e-10L) pure nothrow
    if (isFloatingPoint!(Real))
{
    immutable bool negligible = fabs(x) < threshold;
    return negligible ? cast(Real) 0.0 : x;
}
165 
unittest
{
    // Mirrors the documentation example: a number well above the
    // default threshold is kept, one well below it becomes zero.
    immutable kept = chop(1.0);
    immutable chopped = chop(1e-20);
    assert (kept == 1.0);
    assert (chopped == 0.0);
}
171 
172 
/** Replaces all numbers in the given array that are close to zero
    by exactly zero. To chop the array in-place, pass the same array
    as both x and buffer.
    ---
    double[] a = [1.0, 1e-20, 2.0];
    double[] x = [1.0, 0.0, 2.0];
    auto b = chop(a);
    assert (b == x);
    chop(a, 1e-12L, a);
    assert (a == x);
    ---

    Params:
        x         = the numbers to chop
        threshold = elements whose absolute value is strictly below
                    this value are replaced by zero
        buffer    = optional storage for the result; if it is shorter
                    than x, it is resized to the length of x

    Returns:
        A slice with exactly x.length elements holding the chopped
        values.  When buffer is longer than x, only its first x.length
        elements are written and returned.
*/
Real[] chop(Real) (Real[] x, real threshold = 1e-10L, Real[] buffer=null)
    nothrow
    if (isFloatingPoint!Real)
{
    if (buffer.length < x.length) buffer.length = x.length;

    // Restrict the result to the part that corresponds to x, so that
    // an oversized buffer does not leak unrelated trailing elements
    // into the return value.
    auto result = buffer[0 .. x.length];

    foreach (i, xi; x)
        result[i] = (fabs(xi) < threshold) ? cast(Real) 0.0 : xi;

    return result;
}
197 
unittest
{
    real[] expected = [1.0L, 0.0, 2.0];
    real[] data = [1.0L, real.min_normal, 2.0];

    // Out-of-place: the result is written to a newly allocated array.
    auto chopped = chop(data);
    assert (chopped == expected);

    // In-place: the input array doubles as the output buffer.
    chop(data, 1e-12L, data);
    assert (data == expected);
}
209 
210 
211 
212 
/** Create a static array literal without any heap allocation.

    staticArray() automatically deduces its type from the arguments,
    while staticArrayOf() lets you specify the type explicitly.

    Params:
        elements = the values to store, in order

    Returns:
        A static array with one element per argument.  For
        staticArray() the element type is the common type of all
        arguments.
*/
CommonType!(T)[T.length] staticArray(T...)(T elements)
    @safe pure nothrow
    if (!is(CommonType!T == void))
{
    // Inspired by code posted by David Simcha on the Phobos
    // developers' mailing list.
    //
    // The array is default-initialized rather than void-initialized:
    // void initializers are not allowed in @safe functions when the
    // element type contains pointers (strings, slices, class
    // references, ...), and assigning to an uninitialized element is
    // only valid for plain data anyway.
    typeof(return) a;
    foreach (i, e; elements)  a[i] = e;
    return a;
}
228 
229 
/// ditto
T[U.length] staticArrayOf(T, U...)(U elements)
    @safe pure nothrow
    if (allConvertibleTo!(T, U))
{
    // Default-initialized rather than void-initialized: void
    // initializers are not allowed in @safe functions when T contains
    // pointers, and assigning to an uninitialized element is only
    // valid for plain data anyway.
    typeof(return) a;

    // Each argument is implicitly converted to T on assignment.
    foreach (i, e; elements)  a[i] = e;
    return a;
}
239 
240 
unittest
{
    double[3] expected = [0.0, 1.0, 2.0];

    // Element type deduced from the arguments.
    auto deduced = staticArray(0.0, 1.0, 2.0);
    assert (deduced == expected);

    // Element type given explicitly; float, double and real arguments
    // are all converted to double.
    auto converted = staticArrayOf!double(0.0F, 1.0, 2.0L);
    assert (converted == expected);
}
249 
250 
251 
252 
/** Returns a range that iterates over n equally-spaced floating-point
    numbers in the inclusive interval [a,b].

    This is similar to std.range.iota, except that it allows
    you to specify the number of steps it takes rather than the step size,
    and that the last point is exactly equal to b (unless n = 1, in
    which case a is the first and last point).  This makes it
    more useful than iota for iterating over floating-point numbers.

    The returned range is a forward range with a length property.
    If n is zero or negative, the range is empty.

    Params:
        a = the first point
        b = the last point
        n = the number of points

    Example:
    ---
    int i = 0;
    foreach (x; steps(0.0, 9.0, 10))
    {
        assert (x == i);
        ++i;
    }
    ---
*/
auto steps(T, U)(T a, U b, int n)
    if (isFloatingPoint!T && isFloatingPoint!U)
{
    alias CommonType!(T, U) V;

    // Declared static because it uses nothing from the enclosing
    // function's stack frame; this keeps a hidden frame pointer out
    // of the returned value.
    static struct Steps
    {
    private:
        // _i counts down from _resolution - 1 to 0 and is negative
        // once the range is exhausted.
        int _i, _resolution;
        V _delta, _start, _stop;

    public:
        @property bool empty() const { return _i < 0; }

        @property V front() const
        {
            assert (!empty);

            // The first point is returned as-is, and all others are
            // measured backwards from _stop, so that both end points
            // are hit exactly.
            if (_i == _resolution - 1) return _start;
            return _stop - _i*_delta;
        }

        void popFront()
        {
            assert (!empty);
            --_i;
        }

        /// The number of points that remain to be iterated over.
        @property size_t length() const
        {
            return _i < 0 ? 0 : cast(size_t) _i + 1;
        }

        /// Returns an independent copy of the range.
        @property Steps save() { return this; }
    }

    Steps s;
    s._start = a;
    s._stop = b;

    // With fewer than two points there is no step size; _delta keeps
    // its initial value and is never used by front.
    if (n > 1) s._delta = (b - a) / (n - 1);
    s._i = n - 1;
    s._resolution = n;
    return s;
}
307 
308 
unittest
{
    // A single point yields just the start of the interval.
    auto single = steps(0.0, 1.0, 1);
    assert (equal(single, [0.0]));

    // Eleven points on a symmetric interval, with unit spacing.
    auto symmetric = steps(-5.0, 5.0, 11);
    assert (equal(symmetric, [-5.0, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]));

    // From the example
    int expected = 0;
    foreach (x; steps(0.0, 9.0, 10))
    {
        assert (x == expected);
        ++expected;
    }
}
325 
326 
327 
328 
/** Converts a const(T[]) to a const(T)[], i.e. an array that can
    itself be reassigned or re-sliced, but whose elements remain const.
    The result refers to the same elements as the argument.
*/
const(T)[] tailConst(T)(const T[] a) pure nothrow
{
    // The conversion is implicit; only the static type changes.
    const(T)[] tail = a;
    return tail;
}
334 
335 
336 
337 
/** Put an upper bound on the number of times a function can be called.

    limitCalls() takes a delegate dg and an integer maxCalls, and
    returns a functor with the same parameter and return types as dg.
    The first maxCalls calls to the functor are forwarded to dg, and
    the call after that throws an exception.  This is useful with
    algorithms such as scid.nonlinear.findRoot(), which may require an
    arbitrary number of function calls to complete.
    ---
    void f(int i) { ... }
    auto g = limitCalls(&f, 2);

    g(0);    // succeeds
    g(0);    // succeeds
    g(0);    // throws
    ---

    When the address of f is taken in the example above, the context
    of f may be copied to the heap, which is a potentially expensive
    operation.  The unsafe function scopeLimitCalls() avoids this, but
    it must only be used if you are absolutely sure that the returned
    functor never escapes the current scope.
    ---
    alias void delegate(int) DgType;

    // Here, the delegate escapes by being assigned to a variable
    // outside the function scope.
    DgType someGlobal;
    void badIdea1()
    {
        void f(int i) { ... }
        someGlobal = scopeLimitCalls(&f, 10);
    }

    // Here it escapes because we return it.
    DgType badIdea2()
    {
        void f(int i) { ... }
        return scopeLimitCalls(&f, 10);
    }
    ---
*/
LimitCalls!(R, T) limitCalls(R, T...)
    (R delegate(T) dg, uint maxCalls)
{
    // Start out with a functor that has not yet forwarded any calls,
    // then attach the limit and the target delegate.
    LimitCalls!(R, T) limited;
    limited.calls = 0;
    limited.maxCalls = maxCalls;
    limited.dg = dg;
    return limited;
}
389 
/// ditto
LimitCalls!(R, T) scopeLimitCalls(R, T...)
    (scope R delegate(T) dg, uint maxCalls)
{
    // Same construction as limitCalls(); the only difference is that
    // dg is a scope parameter here.
    LimitCalls!(R, T) limited;
    limited.calls = 0;
    limited.maxCalls = maxCalls;
    limited.dg = dg;
    return limited;
}
400 
/** Functor that forwards a limited number of calls to a delegate.

    Instances are set up by limitCalls() and scopeLimitCalls().
*/
struct LimitCalls(R, T...)
{
private:
    R delegate(T) dg;   // the delegate to which calls are forwarded
    uint maxCalls;      // the number of calls that may be forwarded
    uint calls;         // the number of calls forwarded so far

public:
    /** Forwards the call to the wrapped delegate and returns its
        result.

        Throws:
            Exception if maxCalls calls have already been forwarded.
    */
    R opCall(T params)
    {
        if (calls >= maxCalls)
            throw new Exception("Function called too many times");

        // Only forwarded calls are counted.  Counting rejected calls
        // as well would let the counter wrap around to zero after
        // enough of them, which would silently lift the limit.
        ++calls;
        return dg(params);
    }
}
416 
417 
unittest
{
    int total;
    void accumulate(int i) { total += i; }

    // limitCalls(): two calls go through, the third one throws and
    // leaves the total untouched.
    auto limited = limitCalls(&accumulate, 2);
    limited(3);
    limited(1);
    try { limited(10); assert (false); } catch (Exception e) { assert (true); }
    assert (total == 4);

    // scopeLimitCalls() behaves the same way.
    auto scopeLimited = scopeLimitCalls(&accumulate, 2);
    scopeLimited(2);
    scopeLimited(5);
    try { scopeLimited(10); assert (false); } catch (Exception e) { assert (true); }
    assert (total == 11);
}