]> gitweb.factorcode.org Git - factor.git/blob - extra/math/extras/extras.factor
core: rename ensure-non-negative to assert-non-negative
[factor.git] / extra / math / extras / extras.factor
1 ! Copyright (C) 2012 John Benediktsson
2 ! See https://factorcode.org/license.txt for BSD license
3
4 USING: accessors arrays assocs byte-arrays combinators
5 combinators.short-circuit compression.zlib grouping kernel
6 kernel.private math math.bitwise math.combinatorics
7 math.constants math.functions math.order math.primes
8 math.primes.factors math.statistics math.vectors namespaces
9 random random.private ranges ranges.private sequences
10 sequences.extras sequences.private sets sorting sorting.extras ;
11
12 IN: math.extras
13
14 DEFER: stirling
15
16 <PRIVATE
17
18 : (stirling) ( n k -- x )
19     [ [ 1 - ] bi@ stirling ]
20     [ [ 1 - ] dip stirling ]
21     [ nip * + ] 2tri ;
22
23 PRIVATE>
24
25 MEMO: stirling ( n k -- x )
26     2dup { [ = ] [ nip 1 = ] } 2||
27     [ 2drop 1 ] [ (stirling) ] if ;
28
29 :: ramanujan ( x -- y )
30     pi sqrt x e / x ^ * x 8 * 4 + x * 1 + x * 1/30 + 1/6 ^ * ;
31
32 DEFER: bernoulli
33
34 <PRIVATE
35
36 : (bernoulli) ( p -- n )
37     [ <iota> ] [ 1 + ] bi [
38         0 [ [ nCk ] [ bernoulli * ] bi + ] with reduce
39     ] keep recip neg * ;
40
41 PRIVATE>
42
43 MEMO: bernoulli ( p -- n )
44     [ 1 ] [ (bernoulli) ] if-zero ;
45
46 ! From page 4 https://arxiv.org/ftp/arxiv/papers/2201/2201.12601.pdf
47 : bernoulli-estimate-factorial ( n -- n! )
48     [ 2pi swap ^ ] [ bernoulli ] bi * 2 / ;
49
50 : chi2 ( actual expected -- n )
51     0 [ dup 0 > [ [ - sq ] keep / + ] [ 2drop ] if ] 2reduce ;
52
53 <PRIVATE
54
55 : df-check ( df -- )
56     even? [ "odd degrees of freedom" throw ] unless ;
57
58 : (chi2P) ( chi/2 df/2 -- p )
59     [1..b) dupd n/v cum-product swap neg e^ [ v*n sum ] keep + ;
60
61 PRIVATE>
62
63 : chi2P ( chi df -- p )
64     dup df-check [ 2.0 / ] [ 2 /i ] bi* (chi2P) 1.0 min ;
65
66 <PRIVATE
67
68 : check-jacobi ( m -- m )
69     dup { [ integer? ] [ 0 > ] [ odd? ] } 1&&
70     [ "modulus must be odd positive integer" throw ] unless ;
71
72 : mod' ( x y -- n )
73     [ mod ] keep over zero? [ drop ] [
74         2dup [ sgn ] same? [ drop ] [ + ] if
75     ] if ;
76
77 PRIVATE>
78
79 : jacobi ( a m -- n )
80     check-jacobi [ mod' ] keep 1
81     [ pick zero? ] [
82         [ pick even? ] [
83             [ 2 / ] 2dip
84             over 8 mod' { 3 5 } member? [ neg ] when
85         ] while swapd
86         2over [ 4 mod' 3 = ] both? [ neg ] when
87         [ [ mod' ] keep ] dip
88     ] until [ nip 1 = ] dip 0 ? ;
89
90 <PRIVATE
91
92 : check-legendere ( m -- m )
93     dup prime? [ "modulus must be prime positive integer" throw ] unless ;
94
95 PRIVATE>
96
97 : legendere ( a m -- n )
98     check-legendere jacobi ;
99
100 : moving-average ( seq n -- newseq )
101     <clumps> [ mean ] map ;
102
103 : exponential-moving-average ( seq a -- newseq )
104     [ 1 ] 2dip '[ dupd swap - _ * + dup ] map nip ;
105
106 : moving-median ( u n -- v )
107     <clumps> [ median ] map ;
108
109 : moving-supremum ( u n -- v )
110     <clumps> [ supremum ] map ;
111
112 : moving-infimum ( u n -- v )
113     <clumps> [ infimum ] map ;
114
115 : moving-sum ( u n -- v )
116     <clumps> [ sum ] map ;
117
118 : moving-count ( ... u n quot: ( ... elt -- ... ? ) -- ... v )
119     [ <clumps> ] [ '[ _ count ] map ] bi* ; inline
120
121 : nonzero ( seq -- seq' )
122     [ zero? ] reject ;
123
124 : bartlett ( n -- seq )
125     dup 1 <= [ 1 = [ 1 1array ] [ { } ] if ] [
126         [ <iota> ] [ 1 - 2 / ] bi [
127             [ recip * ] [ >= ] 2bi [ 2 swap - ] when
128         ] curry map
129     ] if ;
130
131 : [0,2pi] ( n -- seq )
132     [ <iota> ] [ 1 - 2pi swap / ] bi v*n ;
133
134 : hanning ( n -- seq )
135     dup 1 <= [ 1 = [ 1 1array ] [ { } ] if ] [
136         [0,2pi] [ cos -0.5 * 0.5 + ] map!
137     ] if ;
138
139 : hamming ( n -- seq )
140     dup 1 <= [ 1 = [ 1 1array ] [ { } ] if ] [
141         [0,2pi] [ cos -0.46 * 0.54 + ] map!
142     ] if ;
143
144 : blackman ( n -- seq )
145     dup 1 <= [ 1 = [ 1 1array ] [ { } ] if ] [
146         [0,2pi] [
147             [ cos -0.5 * ] [ 2 * cos 0.08 * ] bi + 0.42 +
148         ] map
149     ] if ;
150
151 : nan-sum ( seq -- n )
152     0 [ dup fp-nan? [ drop ] [ + ] if ] binary-reduce ;
153
154 : nan-min ( seq -- n )
155     [ fp-nan? ] reject infimum ;
156
157 : nan-max ( seq -- n )
158     [ fp-nan? ] reject supremum ;
159
160 : fill-nans ( seq -- newseq )
161     [ first ] keep [
162         dup fp-nan? [ drop dup ] [ nip dup ] if
163     ] map nip ;
164
165 : sinc ( x -- y )
166     [ 1 ] [ pi * [ sin ] [ / ] bi ] if-zero ;
167
168 : cum-reduce ( seq identity quot: ( prev elt -- next ) -- result cum-result )
169     [ dup rot ] dip dup '[ _ curry dip dupd @ ] each ; inline
170
171 <PRIVATE
172
173 :: (gini) ( seq -- x )
174     seq sort :> sorted
175     seq length :> len
176     sorted 0 [ + ] cum-reduce :> ( a b )
177     b len a * / :> c
178     1 len recip + 2 c * - ;
179
180 PRIVATE>
181
182 : gini ( seq -- x )
183     dup length 1 <= [ drop 0 ] [ (gini) ] if ;
184
185 : concentration-coefficient ( seq -- x )
186     dup length 1 <= [
187         drop 0
188     ] [
189         [ (gini) ] [ length [ ] [ 1 - ] bi / ] bi *
190     ] if ;
191
192 : herfindahl ( seq -- x )
193     [ sum-of-squares ] [ sum sq ] bi / ;
194
195 : normalized-herfindahl ( seq -- x )
196     [ herfindahl ] [ length recip ] bi
197     [ - ] [ 1 swap - / ] bi ;
198
199 : exponential-index ( seq -- x )
200     dup sum '[ _ / dup ^ ] map-product ;
201
202 : weighted-random ( histogram -- obj )
203     unzip cum-sum [ last >float random ] keep bisect-left swap nth ;
204
205 : weighted-randoms ( length histogram -- seq )
206     unzip cum-sum swap
207     [ [ last >float random-generator get ] keep ] dip
208     '[ _ _ random* _ bisect-left _ nth ] replicate ;
209
210 : unique-indices ( seq -- unique indices )
211     [ members ] keep over dup length <iota>
212     H{ } zip-as '[ _ at ] map ;
213
214 : digitize] ( seq bins -- seq' )
215     '[ _ bisect-left ] map ;
216
217 : digitize) ( seq bins -- seq' )
218     '[ _ bisect-right ] map ;
219
220 <PRIVATE
221
222 : steps ( a b length -- a b step )
223     [ 2dup swap - ] dip / ; inline
224
225 PRIVATE>
226
227 : linspace[a..b) ( a b length -- seq )
228     steps ..b) <range> ;
229
230 : linspace[a..b] ( a b length -- seq )
231     {
232         { [ dup 1 < ] [ 3drop { } ] }
233         { [ dup 1 = ] [ 2drop 1array ] }
234         [ 1 - steps <range> ]
235     } cond ;
236
237 : logspace[a..b) ( a b length base -- seq )
238     [ linspace[a..b) ] dip swap n^v ;
239
240 : logspace[a..b] ( a b length base -- seq )
241     [ linspace[a..b] ] dip swap n^v ;
242
243 : majority ( seq -- elt/f )
244     [ f 0 ] dip [
245         over zero? [ 2nip 1 ] [
246             pick = [ 1 + ] [ 1 - ] if
247         ] if
248     ] each zero? [ drop f ] when ;
249
250 : compression-lengths ( a b -- len(a+b) len(a) len(b) )
251     [ append ] 2keep [ >byte-array compress data>> length ] tri@ ;
252
253 : compression-distance ( a b -- n )
254     compression-lengths sort-pair [ - ] [ / ] bi* ;
255
256 : compression-dissimilarity ( a b -- n )
257     compression-lengths + / ;
258
259 : round-to-decimal ( x n -- y )
260     10^ [ * 0.5 over 0 > [ + ] [ - ] if truncate ] [ / ] bi ;
261
262 : round-to-step ( x step -- y )
263     [ [ / round ] [ * ] bi ] unless-zero ;
264
265 GENERIC: round-away-from-zero ( x -- y )
266
267 M: integer round-away-from-zero ; inline
268
269 M: real round-away-from-zero
270     dup 0 < [ floor ] [ ceiling ] if ;
271
272 : monotonic-count ( seq quot: ( elt1 elt2 -- ? ) -- newseq )
273     over empty? [ 2drop { } ] [
274         [ 0 swap unclip-slice swap ] dip '[
275             [ @ [ 1 + ] [ drop 0 ] if ] keep over
276         ] { } map-as 2nip 0 prefix
277     ] if ; inline
278
279 : max-monotonic-count ( seq quot: ( elt1 elt2 -- ? ) -- n )
280     over empty? [ 2drop 0 ] [
281         [ 0 swap unclip-slice swap 0 ] dip '[
282             [ swapd @ [ 1 + ] [ max 0 ] if ] keep swap
283         ] reduce nip max
284     ] if ; inline
285
286 <PRIVATE
287
288 : kahan+ ( c sum elt -- c' sum' )
289     rot - 2dup + [ -rot [ - ] bi@ ] keep ; inline
290
291 PRIVATE>
292
293 : kahan-sum ( seq -- n )
294     [ 0.0 0.0 ] dip [ kahan+ ] each nip ;
295
296 : map-kahan-sum ( ... seq quot: ( ... elt -- ... n ) -- ... n )
297     [ 0.0 0.0 ] 2dip [ 2dip rot kahan+ ] curry
298     [ -rot ] prepose each nip ; inline
299
300 <PRIVATE
301
302 ! Adaptive Precision Floating-Point Arithmetic and Fast Robust Geometric Predicates
303 ! www-2.cs.cmu.edu/afs/cs/project/quake/public/papers/robust-arithmetic.ps
304
305 : sort-partial ( x y -- x' y' )
306     2dup [ abs ] bi@ < [ swap ] when ; inline
307
308 :: partial+ ( x y -- hi lo )
309     x y + dup x - :> yr y yr - ; inline
310
311 :: partial-sums ( seq -- seq' )
312     V{ } clone :> partials
313     seq [
314         0 partials [
315             swapd sort-partial partial+ swapd
316             [ over partials set-nth 1 + ] unless-zero
317         ] each :> i
318         i partials shorten
319         [ i partials set-nth ] unless-zero
320     ] each partials ;
321
322 :: sum-exact ( partials -- n )
323     partials [ 0.0 ] [
324         ! sum from the top, stop when sum becomes inexact
325         [ 0.0 0.0 ] dip [
326             nip partial+ dup 0.0 = not
327         ] find-last drop :> ( lo n )
328
329         ! make half-even rounding work across multiple partials
330         n [ 0 > ] [ f ] if* [
331             n 1 - partials nth
332             [ 0.0 < lo 0.0 < and ]
333             [ 0.0 > lo 0.0 > and ] bi or [
334                 lo 2.0 * :> y
335                 dup y + :> x
336                 x over - :> yr
337                 y yr = [ drop x ] when
338             ] when
339         ] when
340     ] if-empty ;
341
342 PRIVATE>
343
344 : sum-floats ( seq -- n )
345     partial-sums sum-exact ;
346
347 : mobius ( n -- x )
348     group-factors values [ 1 ] [
349         dup [ 1 > ] any?
350         [ drop 0 ] [ length even? 1 -1 ? ] if
351     ] if-empty ;
352
353 : kelly ( winning-probability odds -- fraction )
354     [ 1 + * 1 - ] [ / ] bi ;
355
356 :: integer-sqrt ( m -- n )
357     m [ 0 ] [
358         assert-non-negative
359         bit-length 1 - 2 /i :> c
360         1 :> a!
361         0 :> d!
362         c bit-length <iota> <reversed> [| s |
363             d :> e
364             c s neg shift d!
365             a d e - 1 - shift
366             m 2 c * e - d - 1 + neg shift a /i + a!
367         ] each
368         a a sq m > [ 1 - ] when
369     ] if-zero ;
370
371 <PRIVATE
372
373 : reduce-evens ( value u v -- value' u' v' )
374     [ 2dup [ even? ] both? ] [ [ 2 * ] [ 2/ ] [ 2/ ] tri* ] while ;
375
376 : reduce-odds ( value u v -- value' u' v' )
377     [
378         [ [ dup even? ] [ 2/ ] while ] bi@
379         2dup <=> {
380             { +eq+ [ over '[ _ * ] 2dip f ] }
381             { +lt+ [ swap [ - ] keep t ] }
382             { +gt+ [ [ - ] keep t ] }
383         } case
384     ] loop ;
385
386 PRIVATE>
387
388 : stein ( u v -- w )
389     2dup [ zero? ] both? [ "gcd for zeros is undefined" throw ] when
390     [ dup 0 < [ neg ] when ] bi@
391     [ 1 ] 2dip reduce-evens reduce-odds 2drop ;
392
393 TUPLE: vose
394     { n fixnum }
395     { items array }
396     { probs array }
397     { alias array } ;
398
399 :: <vose> ( dist -- vose )
400     V{ } clone :> small
401     V{ } clone :> large
402     dist assoc-size :> n
403     n f <array> :> alias
404
405     dist unzip dup [ length ] [ sum ] bi / v*n :> ( items probs )
406     probs [ swap 1 < small large ? push ] each-index
407
408     [ small empty? large empty? or ] [
409         small pop :> s
410         large pop :> l
411         l s alias set-nth
412         l dup probs [ s probs nth + 1 - dup ] change-nth
413         1 < small large ? push
414     ] until
415
416     1 large [ probs set-nth ] with each
417     1 small [ probs set-nth ] with each
418
419     n items probs alias vose boa ;
420
421 M:: vose random* ( obj rnd -- elt )
422     obj n>> rnd random* { fixnum } declare
423     dup obj probs>> nth-unsafe rnd (random-unit) >=
424     [ obj alias>> nth-unsafe ] unless
425     obj items>> nth-unsafe ;