1 ! Copyright (C) 2005, 2010, 2018, 2020 Slava Pestov, Joe Groff, and Cat Stevens.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: accessors arrays classes.singleton columns combinators
4 combinators.short-circuit combinators.smart formatting fry
6 grouping kernel locals math math.bits math.functions math.order
7 math.private math.ranges math.statistics math.vectors
8 math.vectors.private sequences sequences.deep sequences.private
9 slots.private summary ;
11 grouping kernel kernel.private locals math math.bits
12 math.functions math.order math.private math.ranges
13 math.statistics math.vectors math.vectors.private sequences
14 sequences.deep sequences.extras sequences.private slots.private
16 >>>>>>> 0ac3067a26 (vnorm rename)
19 ! defined here because of issue #1943
! Forward declaration of the word defined directly below.
20 DEFER: regular-matrix?
! NOTE(review): this definition looks incomplete in this listing -- only the
! step that takes the length of the first row is visible, and there is no
! closing `;`. Confirm the full body against the upstream file.
21 : regular-matrix? ( object -- ? )
23 dup first-unsafe length
27 ! the MRO (class linearization) is performed in the order the predicates appear here
28 ! except that null-matrix is last (but it is relied upon by zero-matrix)
30 ! sequence > matrix > zero-matrix > square-matrix > zero-square-matrix > null-matrix
32 ! Factor bug that's hard to repro: using `bi and` in these predicates
33 ! instead of 1&& will cause spurious no-method and bounds-error errors in <square-cols>
34 ! and the tests/docs for no apparent reason
! A matrix is a sequence whose elements are all sequences and which also
! passes regular-matrix?. The second test only runs when the first succeeds.
! Keep the 1&& form: the comment above reports spurious errors with `bi and`.
35 PREDICATE: matrix < sequence
36 { [ [ sequence? ] all? ] [ regular-matrix? ] } 1&& ;
! An irregular matrix is a sequence whose elements are all sequences but which
! fails regular-matrix?. The second test only runs when the first succeeds.
! Keep the 1&& form: the comment above reports spurious errors with `bi and`.
38 PREDICATE: irregular-matrix < sequence
39 { [ [ sequence? ] all? ] [ regular-matrix? not ] } 1&& ;
42 ! can't define dim using this predicate for this reason,
43 ! unless we are going to write two versions of dim, one of which is generic
! A matrix is square when the two entries of its dimension pair are equal.
PREDICATE: square-matrix < matrix
    dimension [ first-unsafe ] [ second-unsafe ] bi = ;
! NOTE(review): no body or closing `;` is visible for this predicate in this
! listing; confirm the definition against the upstream file.
47 PREDICATE: null-matrix < matrix
! Null matrices are rejected up front; otherwise every element of the
! flattened matrix must be zero.
PREDICATE: zero-matrix < matrix
    { [ null-matrix? not ] [ flatten [ zero? ] all? ] } 1&& ;
! Holds when the object is both a zero-matrix and a square-matrix.
PREDICATE: zero-square-matrix < square-matrix
    dup zero-matrix? [ square-matrix? ] [ drop f ] if ;
56 ! Benign matrix constructors
! Builds m rows, each a fresh n-element array filled with element.
: <matrix> ( m n element -- matrix )
    [ <array> ] 2curry replicate ; inline
! Builds m rows of n elements, calling quot once for every element.
: <matrix-by> ( m n quot: ( ... -- elt ) -- matrix )
    [ replicate ] 2curry replicate ; inline
! Builds an m-by-n matrix by calling quot with each ( row-index col-index ) pair.
: <matrix-by-indices> ( ... m n quot: ( ... m' n' -- ... elt ) -- ... matrix )
    [ <iota> ] [ <iota> ] [ ] tri* cartesian-map ; inline
! NOTE(review): no body or closing `;` is visible for this word in this
! listing; confirm the definition against the upstream file.
66 : <zero-matrix> ( m n -- matrix )
! Square variant of <zero-matrix>: n is used for both dimensions.
: <zero-square-matrix> ( n -- matrix )
    dup <zero-matrix> ; inline
! Converts an index counted from the end of seq into a normal index:
! ( length - 1 ) - n, written here as length - ( n + 1 ).
: (nth-from-end) ( n seq -- n )
    length swap 1 + - ; inline flushable

! Bounds-checked read of the element n places from the end of seq.
: nth-end ( n seq -- elt )
    dup [ (nth-from-end) ] dip nth ; inline flushable

! Same as nth-end, but reads with nth-unsafe.
: nth-end-unsafe ( n seq -- elt )
    dup [ (nth-from-end) ] dip nth-unsafe ; inline flushable

! Same as nth-end-unsafe, but reads the slot directly (converted index + 2).
: array-nth-end-unsafe ( n seq -- elt )
    [ (nth-from-end) 2 fixnum+fast ] keep swap slot ; inline flushable

! Bounds-checked store of elt n places from the end of seq.
: set-nth-end ( elt n seq -- )
    dup [ (nth-from-end) ] dip set-nth ; inline

! Same as set-nth-end, but stores with set-nth-unsafe.
: set-nth-end-unsafe ( elt n seq -- )
    dup [ (nth-from-end) ] dip set-nth-unsafe ; inline
92 ! main-diagonal matrix
! Allocates a zero square matrix as large as diagonal-seq, then stores the
! i-th element of the sequence at column i of row i.
: <diagonal-matrix> ( diagonal-seq -- matrix )
    dup length <zero-square-matrix>
    [ '[ dup _ nth set-nth-unsafe ] each-index ] keep ; inline

! could also be written slower as: <diagonal-matrix> [ reverse ] map
! Same as above, but element i lands i places from the end of row i.
: <anti-diagonal-matrix> ( diagonal-seq -- matrix )
    dup length <zero-square-matrix>
    [ '[ dup _ nth set-nth-end-unsafe ] each-index ] keep ; inline

! Diagonal matrix built from n repetitions of 1.
: <identity-matrix> ( n -- matrix )
    1 <repetition> <diagonal-matrix> ; inline
! NOTE(review): the quotation handed to cartesian-map (the part that would use
! k and z) is not visible in this listing; confirm against the upstream file.
105 : <eye> ( m n k z -- matrix )
! turn m and n into index ranges while k and z stay on top of the stack
106 [ [ <iota> ] bi@ ] 2dip
108 cartesian-map ; inline
110 ! if m = n and k = 0 then <identity-matrix> is (possibly) more efficient
! Chooses between <identity-matrix> and the general <eye> called with z = 1.
! NOTE(review): the condition consumed by `if` is not visible in this listing;
! the comment above suggests it tests m = n and k = 0 -- confirm upstream.
111 :: <simple-eye> ( m n k -- matrix )
113 [ n <identity-matrix> ]
114 [ m n k 1 <eye> ] if ; inline
! Takes a two-element dim and returns the cartesian product of the two
! index ranges it describes.
: <coordinate-matrix> ( dim -- coordinates )
    first2 [ <iota> ] dip <iota> cartesian-product ; inline

ALIAS: <cartesian-indices> <coordinate-matrix>

! Coordinates of an n-by-n matrix.
: <cartesian-square-indices> ( n -- matrix )
    dup 2array <coordinate-matrix> ; inline

ALIAS: transpose flip
! NOTE(review): `bi and` needs two quotations, but only the per-row check
! ([ array? ] all?) is visible; the first quotation seems to be missing from
! this listing -- confirm against the upstream file.
127 : array-matrix? ( matrix -- ? )
129 [ [ array? ] all? ] bi and ; inline foldable flushable
! Column indices of a matrix, taken from the length of its first row
! (read without bounds checking).
: matrix-cols-iota ( matrix -- cols-iota )
    0 swap nth-unsafe length <iota> ; inline
! Column indices for a matrix whose rows may differ in length.
! Starts from the first row's length and folds `length min` over the rows
! through the low-level (each) (each-integer) pair, then wraps the result
! in <iota>.
! NOTE(review): the literal 1 is presumably the start index handed to
! (each-integer) so the first row is not visited twice -- confirm against
! the sequences vocabulary in use.
134 : unshaped-cols-iota ( matrix -- cols-iota )
135 [ first-unsafe length 1 ] keep
136 [ length min ] (each) (each-integer) <iota> ; inline
! For each column index, walks the rows in reverse order and reads the element
! that many places from the end of the row; results are collected as arrays.
: generic-anti-transpose-unsafe ( cols-iota matrix -- newmatrix )
    '[ _ <reversed> [ nth-end-unsafe ] with { } map-as ] { } map-as ; inline

! Same traversal as above, but reads with array-nth-end-unsafe and collects
! with plain map.
: array-anti-transpose-unsafe ( cols-iota matrix -- newmatrix )
    '[ _ <reversed> [ array-nth-end-unsafe ] with map ] map ; inline
145 ! much faster than [ reverse ] map flip [ reverse ] map
! NOTE(review): parts of this definition are missing from this listing (the
! test that selects between the two helpers and the closing brackets are not
! visible). What can be seen picks matrix-cols-iota for regular matrices and
! unshaped-cols-iota otherwise, then hands off to one of the two
! *-anti-transpose-unsafe helpers. Confirm against the upstream file.
146 : anti-transpose ( matrix -- newmatrix )
148 [ dup regular-matrix?
149 [ matrix-cols-iota ] [ unshaped-cols-iota ] if
153 array-anti-transpose-unsafe
155 generic-anti-transpose-unsafe
159 ALIAS: anti-flip anti-transpose
! NOTE(review): no body or closing `;` is visible for this word in this
! listing; confirm the definition against the upstream file.
161 : row ( n matrix -- row )
! Looks up every index in seq with row.
: rows ( seq matrix -- rows )
    [ row ] curry map ; inline

! Takes element n from every row of the matrix.
: col ( n matrix -- col )
    [ nth ] with map ; inline

! Looks up every index in seq with col; the result holds one column per entry.
: cols ( seq matrix -- cols )
    [ col ] curry map ; inline
! Binds the matrix's dimension pair to x and y, then selects a branch by
! comparing them: when x < y it keeps the first x columns, when x > y it keeps
! the first y rows.
! NOTE(review): the clause list is incomplete in this listing (at least one
! clause and the closing `} cond ;` are not visible) -- confirm upstream.
173 :: >square-matrix ( m -- subset )
174 m dimension first2 :> ( x y ) {
176 { [ x y < ] [ x <iota> m cols transpose ] }
177 { [ x y > ] [ y <iota> m rows ] }
! Builds or coerces a square matrix from a description.
GENERIC: <square-rows> ( desc -- matrix )

! An integer n is treated as the index range 0..n-1.
M: integer <square-rows>
    <iota> <square-rows> ;

! A sequence becomes the row pattern: one fresh clone of it (as an array)
! per element of the sequence.
M: sequence <square-rows>
    >array dup length swap [ clone ] curry { } replicate-as ;

! Already square: returned unchanged.
M: square-matrix <square-rows> ;

! coercing to square is more useful than no-method
M: matrix <square-rows>
    >square-matrix ;
! Column-oriented counterpart of <square-rows>.
189 GENERIC: <square-cols> ( desc -- matrix )
! An integer is converted to an index range and dispatched again.
190 M: integer <square-cols>
191 <iota> <square-cols> ;
! NOTE(review): the body of this method is not visible in this listing.
192 M: sequence <square-cols>
! Already square: returned unchanged.
195 M: square-matrix <square-cols> ;
! NOTE(review): the body of this method is not visible in this listing.
196 M: matrix <square-cols>
199 <PRIVATE ! implementation details of <lower-matrix> and <upper-matrix>
! Returns the coordinate matrix for the matrix's dimension together with the
! range 1..(first entry of the dimension).
: dimension-range ( matrix -- dim range )
    dimension dup <coordinate-matrix> swap first [1,b] ;

! Pairs each coordinate row with a count taken from the reversed range, keeps
! that many coordinates from the end of the row, and concatenates the results.
: upper-matrix-indices ( matrix -- matrix' )
    dimension-range <reversed>
    [ tail-slice* >array ] 2map
    concat ;

! Pairs each coordinate row with a count from the range, keeps that many
! coordinates from the start of the row, and concatenates the results.
: lower-matrix-indices ( matrix -- matrix' )
    dimension-range
    [ head-slice >array ] 2map
    concat ;
DEFER: matrix-set-nths

! Creates an m-by-n zero matrix, stores object at every position returned by
! lower-matrix-indices, and returns the matrix.
: <lower-matrix> ( object m n -- matrix )
    <zero-matrix>
    [ [ lower-matrix-indices ] keep matrix-set-nths ] keep ;

! Creates an m-by-n zero matrix, stores object at every position returned by
! upper-matrix-indices, and returns the matrix.
: <upper-matrix> ( object m n -- matrix )
    <zero-matrix>
    [ [ upper-matrix-indices ] keep matrix-set-nths ] keep ;
218 ! element- and sequence-wise operations, getters and setters
! NOTE(review): this line is the tail of a definition whose header (name and
! stack effect) is not visible in this listing. The visible body reduces a
! sequence of matrices by appending corresponding rows pairwise.
220 [ ] [ [ append ] 2map ] map-reduce ;
! Applies quot to every element, row by row.
: matrix-map ( matrix quot: ( ... elt -- ... elt' ) -- matrix' )
    [ map ] curry map ; inline

! Like matrix-map, but quot also receives the row index i and column index j.
! The outer map-index supplies i; it is curried in front of a `swap` so that
! quot sees its inputs in ( elt i j ) order.
: matrix-map-index ( matrix quot: ( ... elt i j -- ... elt' ) -- matrix' )
    [ swap ] prepose [ curry map-index ] curry map-index ; inline
! Applies quot to each column: transpose, map over the former columns,
! then transpose back.
: column-map ( matrix quot: ( ... col -- ... col' ) -- matrix' )
    swap transpose swap map transpose ; inline
! Reads one element; pair is { row-index column-index }.
: matrix-nth ( pair matrix -- elt )
    [ first2 ] dip swapd nth nth ; inline

! Reads one element for every pair in pairs.
: matrix-nths ( pairs matrix -- elts )
    [ matrix-nth ] curry map ; inline

! Stores obj at the position named by pair ( { row-index column-index } ).
: matrix-set-nth ( obj pair matrix -- )
    [ first2 ] dip swapd nth set-nth ; inline

! Stores the same obj at every position listed in pairs.
: matrix-set-nths ( obj pairs matrix -- )
    [ matrix-set-nth ] curry with each ; inline
243 ! -------------------------------------------
244 ! simple math of matrices follows
! Row-wise negation and absolute value.
: mneg ( m -- m' )
    [ vneg ] map ;
: mabs ( m -- m' )
    [ vabs ] map ;

! Scalar/matrix arithmetic: each word applies the matching scalar/vector
! word from math.vectors to every row, keeping the operand order.
: n+m ( n m -- m )
    swap '[ _ swap n+v ] map ;
: m+n ( m n -- m )
    '[ _ v+n ] map ;
: n-m ( n m -- m )
    swap '[ _ swap n-v ] map ;
: m-n ( m n -- m )
    '[ _ v-n ] map ;
: n*m ( n m -- m )
    swap '[ _ swap n*v ] map ;
: m*n ( m n -- m )
    '[ _ v*n ] map ;
: n/m ( n m -- m )
    swap '[ _ swap n/v ] map ;
: m/n ( m n -- m )
    '[ _ v/n ] map ;
! Elementwise arithmetic on two matrices, one row pair at a time.
: m+ ( m1 m2 -- m )
    [ v+ ] 2map ;
: m- ( m1 m2 -- m )
    [ v- ] 2map ;
: m* ( m1 m2 -- m )
    [ v* ] 2map ;
: m/ ( m1 m2 -- m )
    [ v/ ] 2map ;

! Vector times matrix: dot product of v with every column of m.
: vdotm ( v m -- p )
    flip swap '[ _ swap vdot ] map ;

! Matrix times vector: dot product of every row of m with v.
: mdotv ( m v -- p )
    '[ _ vdot ] map ;

! Matrix product: each row of the first matrix against the columns of the second.
: mdot ( m m -- m )
    flip '[ _ swap mdotv ] map ;

! Row-wise approximate equality within epsilon.
: m~ ( m1 m2 epsilon -- ? )
    '[ _ v~ ] 2all? ;
! Smallest element of the matrix; the running value starts at positive infinity.
: mmin ( m -- n )
    1/0. swap [ [ min ] each ] each ;

! Largest element of the matrix; the running value starts at negative infinity.
: mmax ( m -- n )
    -1/0. swap [ [ max ] each ] each ;
! Largest row sum of absolute values; a zero-matrix gives 0 directly.
: matrix-l-infinity-norm ( m -- n )
    dup zero-matrix? not
    [ [ [ abs ] map-sum ] map supremum ]
    [ drop 0 ] if ; inline foldable

! The l-infinity norm of the transposed matrix; a zero-matrix gives 0 directly.
: matrix-l1-norm ( m -- n )
    dup zero-matrix? not
    [ transpose matrix-l-infinity-norm ]
    [ drop 0 ] if ; inline foldable

! Square root of the sum of the squares of all elements, accumulated row by
! row; a zero-matrix gives 0 directly.
: matrix-l2-norm ( m -- n )
    dup zero-matrix? not
    [ [ [ sq ] map-sum ] map-sum sqrt ]
    [ drop 0 ] if ; inline foldable
! Norm methods for matrices: a zero-matrix answers 0 immediately, any other
! matrix defers to the corresponding matrix-* word.
M: zero-matrix l1-norm
    drop 0 ; inline
M: matrix l1-norm
    matrix-l1-norm ; inline

M: zero-matrix l2-norm
    drop 0 ; inline
M: matrix l2-norm
    matrix-l2-norm ; inline

M: zero-matrix l-infinity-norm
    drop 0 ; inline
M: matrix l-infinity-norm
    matrix-l-infinity-norm ; inline

! Other names for matrix-l2-norm.
ALIAS: frobenius-norm matrix-l2-norm
ALIAS: hilbert-schmidt-norm matrix-l2-norm
! L(p,q) norm: for every row, sum |elt|^p and raise that sum to q/p; then add
! the per-row values and raise the total to 1/q. A zero-matrix gives 0 directly.
! The inner power used to be hard-coded as `sq`, so p only influenced the
! outer exponent and the result was wrong whenever p was not 2. For p = 2 on
! real-valued input the result is the same as before.
! NOTE(review): the inner sums run over the rows of m as stored; if the usual
! column-wise convention is wanted, pass the transposed matrix.
:: matrix-p-q-norm ( m p q -- n )
    m dup zero-matrix? [ drop 0 ] [
        [ [ abs p ^ ] map-sum q p / ^ ] map-sum q recip ^
    ] if ; inline foldable
! Flattens the matrix one level into a vector and hands it, together with p,
! to p-norm-default.
: matrix-p-norm-entrywise ( m p -- n )
    swap flatten1 V{ } like swap p-norm-default ; inline

! A zero-matrix answers 0; any other matrix uses the entrywise computation.
M: zero-matrix p-norm-default
    2drop 0 ; inline
M: matrix p-norm-default
    matrix-p-norm-entrywise ; inline
! p-norm of a matrix: 0 for a zero-matrix; otherwise p = 1, p = 2 and infinite
! p select the dedicated l1 / l2 / l-infinity words, and every other p falls
! through to the entrywise computation.
! NOTE(review): the clause list has no opening `{` and no closing `} cond` in
! this listing -- those lines appear to be missing; confirm upstream.
309 : matrix-p-norm ( m p -- n )
310 over zero-matrix? [ 2drop 0 ] [
312 { [ dup 1 number= ] [ drop matrix-l1-norm ] }
313 { [ dup 2 number= ] [ drop matrix-l2-norm ] }
314 { [ dup fp-infinity? ] [ drop matrix-l-infinity-norm ] }
315 [ matrix-p-norm-entrywise ]
317 ] if ; inline foldable
! A zero-matrix answers 0; any other matrix uses matrix-p-norm.
M: zero-matrix p-norm
    2drop 0 ; inline
M: matrix p-norm
    matrix-p-norm ; inline
! NOTE(review): only the header and the closing `] unless` are visible in this
! listing; the condition and the quotation body are missing -- confirm against
! the upstream file.
322 : matrix-normalize ( m -- m' )
325 ] unless ; inline foldable
327 ! well-defined for square matrices; but works on nonsquare too
! Coerces the input with >square-matrix, then takes element i of row i.
: main-diagonal ( matrix -- seq )
    >square-matrix
    [ swap nth-unsafe ] map-index ; inline

! top right to bottom left; reverse the result if you expected it to start in the lower left
! Coerces the input with >square-matrix, then takes the element i places from
! the end of row i.
: anti-diagonal ( matrix -- seq )
    >square-matrix
    [ swap nth-end-unsafe ] map-index ; inline
! Row indices: an iota over the first entry of the matrix's dimension.
: (rows-iota) ( matrix -- rows-iota )
    dimension 0 swap nth-unsafe <iota> ;

! Column indices: an iota over the second entry of the matrix's dimension.
: (cols-iota) ( matrix -- cols-iota )
    dimension 1 swap nth-unsafe <iota> ;
! Drops the row indices for which `desc quot` answers true (the result is
! built with the matrix as exemplar) and fetches the remaining rows.
:: simple-rows-except ( matrix desc quot -- others )
    matrix (rows-iota) desc quot curry matrix reject-as
    matrix rows ; inline

! Column counterpart of simple-rows-except.
! need to un-transpose the result of cols
:: simple-cols-except ( matrix desc quot -- others )
    matrix (cols-iota) desc quot curry matrix reject-as
    matrix cols transpose ; inline

! Tests used by the *-except words: equality against a single index,
! membership against a sequence of indices.
CONSTANT: scalar-except-quot [ = ]
CONSTANT: sequence-except-quot [ member? ]
! Rows of matrix other than those named by desc (an index or a sequence of
! indices).
GENERIC: rows-except ( matrix desc -- others )
M: integer rows-except
    scalar-except-quot simple-rows-except ;
M: sequence rows-except
    sequence-except-quot simple-rows-except ;

! Columns of matrix other than those named by desc (an index or a sequence
! of indices).
GENERIC: cols-except ( matrix desc -- others )
M: integer cols-except
    scalar-except-quot simple-cols-except ;
M: sequence cols-except
    sequence-except-quot simple-cols-except ;
361 ! well-defined for any regular matrix
! Removes the rows named by the first entry of exclude-pair, then the columns
! named by the second entry.
: matrix-except ( matrix exclude-pair -- submatrix )
    first2 -rot rows-except swap cols-except ;

ALIAS: submatrix-excluding matrix-except

! For every ( row column ) coordinate of the matrix, the submatrix left after
! excluding that row and column; results are laid out like the coordinates.
:: matrix-except-all ( matrix -- submatrices )
    matrix dimension [ <iota> ] map first2-unsafe cartesian-product
    [ matrix swap matrix-except ] matrix-map ;

ALIAS: all-submatrices matrix-except-all
! For a non-empty matrix, returns { row-count length-of-first-row }.
! NOTE(review): `if-empty` needs a quotation for the empty case as well, and
! that quotation is not visible in this listing -- confirm upstream.
373 : dimension ( matrix -- dimension )
375 [ [ length ] [ first-unsafe length ] bi 2array ] if-empty ;