1 ! Copyright (C) 2012 Doug Coleman.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: accessors arrays combinators.short-circuit constructors
4 fry grouping kernel math math.vectors sequences sequences.deep
5 math.order parser assocs math.combinatorics ;
! Answers t when none of the array's elements is itself a sequence
! (an empty array is therefore flat).
: flat? ( array -- ? )
    [ sequence? not ] all? ; inline
! Generic word taking one object and returning a shape description.
! Behaviour is chosen per class by the M: methods that follow.
10 GENERIC: array-replace ( object -- shape )
! Fallback method: for an object with no more specific method the input
! is discarded and f is returned.
14 M: object array-replace drop f ;
16 M: sequence array-replace
! Wrapper tuple with a single slot, shape.
23 TUPLE: uniform-shape shape ;
! By-order-of-arguments constructor: ( shape -- uniform-shape ).
24 C: <uniform-shape> uniform-shape
! Wrapper tuple with a single slot, shape. check-underlying-shape refuses
! this class by throwing no-abnormally-shaped-arrays.
26 TUPLE: abnormal-shape shape ;
! By-order-of-arguments constructor: ( shape -- abnormal-shape ).
27 C: <abnormal-shape> abnormal-shape
! Generic word that turns the result of array-replace into a shape
! wrapper tuple; see its use in M: sequence shape.
29 GENERIC: wrap-shape ( object -- shape )
32 1array <uniform-shape> ;
34 M: sequence wrap-shape
36 [ length ] [ first ] bi 2array <uniform-shape>
! Generic word returning the shape of its argument.
41 GENERIC: shape ( array -- shape )
! For a plain sequence the shape is computed: array-replace builds the
! raw description and wrap-shape wraps it.
43 M: sequence shape array-replace wrap-shape ;
! Number of dimensions: the length of whatever shape returns.
! NOTE(review): for plain sequences shape goes through wrap-shape, which
! builds a wrapper tuple -- confirm that length is defined for that case.
45 : ndim ( array -- n ) shape length ;
! Error carrying the offending shape; thrown by check-shape-domain.
47 ERROR: no-negative-shape-components shape ;
! Leaves the shape on the stack untouched, unless at least one of its
! components is below zero; then no-negative-shape-components is thrown
! with the shape as payload.
: check-shape-domain ( seq -- seq )
    [ [ 0 < ] any? ] keep swap
    [ throw-no-negative-shape-components ] when ;
! Generic word mapping a shape to a number n; the sequence method
! returns the product of the components.
52 GENERIC: shape-capacity ( shape -- n )
! Plain sequence shape: reject negative components first, then multiply
! all components together.
54 M: sequence shape-capacity check-shape-domain product ;
56 M: uniform-shape shape-capacity
59 M: abnormal-shape shape-capacity
61 [ dup sequence? [ drop ] [ + ] if ] [ 1 + ] if*
! Thrown by the sequence method of check-underlying-shape when the
! underlying length differs from the shape's capacity.
64 ERROR: underlying-shape-mismatch underlying shape ;
! Thrown by the abnormal-shape method of check-underlying-shape.
66 ERROR: no-abnormally-shaped-arrays underlying shape ;
! Generic validation word dispatching on the shape (top of stack).
! Both inputs remain on the stack when validation succeeds.
68 GENERIC: check-underlying-shape ( underlying shape -- underlying shape )
! An abnormal-shape is never accepted: this method always throws
! no-abnormally-shaped-arrays with both inputs as payload.
70 M: abnormal-shape check-underlying-shape
71 throw-no-abnormally-shaped-arrays ;
! Unwrap the uniform-shape and re-dispatch on the shape it holds. The
! value left on the stack is the unwrapped shape, not the wrapper tuple.
73 M: uniform-shape check-underlying-shape
74 shape>> check-underlying-shape ;
! Raw sequence shape: the number of underlying elements has to match the
! capacity of the shape exactly. Both inputs are kept on the stack; on a
! mismatch they become the payload of underlying-shape-mismatch.
M: sequence check-underlying-shape
    2dup [ length ] dip shape-capacity =
    [ throw-underlying-shape-mismatch ] unless ; inline
! Error carrying the two shaped arrays whose shapes disagree; thrown by
! check-shape.
80 ERROR: shape-mismatch shaped0 shaped1 ;
! Verify that two shaped arrays have element-wise equal shapes.
! Both arrays are left on the stack on success; on failure they become
! the payload of the shape-mismatch error.
! The inputs are duplicated and their shape slots extracted before the
! comparison: sequence= consumes its operands, so comparing the arrays
! directly would neither compare shapes nor honour the declared effect.
: check-shape ( shaped-array shaped-array -- shaped-array shaped-array )
    2dup [ shape>> ] bi@
    sequence= [ throw-shape-mismatch ] unless ;
! Base tuple: underlying holds the element storage, shape holds the
! shape it is viewed with.
86 TUPLE: shaped-array underlying shape ;
! Subclasses of shaped-array that add no slots of their own.
87 TUPLE: row-array < shaped-array ;
88 TUPLE: col-array < shaped-array ;
! The length of a shaped-array is the length of its underlying storage.
90 M: shaped-array length underlying>> length ; inline
! A shaped-array already stores its shape; just read the slot.
92 M: shaped-array shape shape>> ;
! Build an instance of class (shaped-array or a subclass) from the
! given storage and shape.
!   underlying - element storage
!   shape      - shape to view the storage with; validated first by
!                check-underlying-shape, which may throw
!   class      - tuple class to instantiate with new
! Both slots are filled: first the validated shape, then the underlying.
! Setting only one slot would leave a value behind on the stack and put
! the shape into the underlying slot.
: make-shaped-array ( underlying shape class -- shaped-array )
    [ check-underlying-shape ] dip new
        swap >>shape
        swap >>underlying ; inline
! Constructor for the base class: delegates to make-shaped-array with
! the shaped-array class.
99 : <shaped-array> ( underlying shape -- shaped-array )
100 shaped-array make-shaped-array ; inline
! Constructor for row-array: delegates to make-shaped-array with the
! row-array class.
102 : <row-array> ( underlying shape -- shaped-array )
103 row-array make-shaped-array ; inline
! Constructor for col-array: delegates to make-shaped-array with the
! col-array class.
105 : <col-array> ( underlying shape -- shaped-array )
106 col-array make-shaped-array ; inline
! Conversion generics: each takes one object and returns a shaped-array
! of the class named in the word.
108 GENERIC: >shaped-array ( array -- shaped-array )
109 GENERIC: >row-array ( array -- shaped-array )
110 GENERIC: >col-array ( array -- shaped-array )
! Convert a (possibly nested) sequence: its elements are flattened into
! a fresh array used as storage, and its computed shape is attached.
M: sequence >shaped-array
    [ { } flatten-as ] keep shape <shaped-array> ;
! Already a shaped-array: the empty body returns the input unchanged.
115 M: shaped-array >shaped-array ;
! Re-wrap an existing shaped-array as a row-array. The same underlying
! and shape objects are reused (no copy is made here).
M: shaped-array >row-array
    [ underlying>> ] keep shape>> <row-array> ;
! Re-wrap an existing shaped-array as a col-array. The same underlying
! and shape objects are reused (no copy is made here).
M: shaped-array >col-array
    [ underlying>> ] keep shape>> <col-array> ;
! Convert a (possibly nested) sequence into a col-array: flatten gives
! the storage, shape gives the shape.
M: sequence >col-array
    [ flatten ] keep shape <col-array> ;
! Element-wise sum of two shaped arrays.
!   a b - shaped arrays with identical shapes
!   c   - new shaped-array holding the v+ of both underlyings and a
!         clone of a's shape
! Throws shape-mismatch (payload: a b) when the shapes differ. Without
! this guard v+ would silently work on mismatched storage and the
! result would still be labelled with a's shape.
: shaped+ ( a b -- c )
    2dup [ shape>> ] bi@ sequence= [ throw-shape-mismatch ] unless
    [ [ underlying>> ] bi@ v+ ]
    [ drop shape>> clone ] 2bi shaped-array boa ;
131 : shaped-array>array ( shaped-array -- array )
132 [ underlying>> ] [ shape>> ] bi
136 [ rest-slice reverse [ group ] each ] unless-empty
! Give an existing shaped-array a new shape, in place.
! check-underlying-shape receives the shaped-array itself in the
! "underlying" position, so the sequence method compares the array's
! length against the new shape's capacity. On success >>shape stores
! the shape and the same (mutated) shaped-array is returned.
139 : reshape ( shaped-array shape -- array )
140 check-underlying-shape >>shape ;
! New shaped-array with the given shape over a clone of the input's
! underlying storage; the input is not modified. The <shaped-array>
! constructor validates the shape against the cloned storage.
: shaped-like ( shaped-array shape -- array )
    swap underlying>> clone swap <shaped-array> ;
! Shaped array of the given shape with every slot set to element.
! Storage is an array of shape-capacity copies of element; the shape
! stored is the input shape left-padded with 1 up to a minimum length
! of one.
: repeated-shaped ( shape element -- shaped-array )
    [ [ shape-capacity ] dip <array> ] 2keep
    drop 1 1 pad-head <shaped-array> ;
! Shaped array of the given shape filled with 0 (see repeated-shaped).
149 : zeros ( shape -- shaped-array ) 0 repeated-shaped ;
! Shaped array of the given shape filled with 1 (see repeated-shaped).
151 : ones ( shape -- shaped-array ) 1 repeated-shaped ;
! Shaped array of the given shape whose storage counts up from 0 to
! capacity - 1.
: increasing ( shape -- shaped-array )
    dup shape-capacity iota >array swap <shaped-array> ;
! Shaped array of the given shape whose storage counts down from
! capacity - 1 to 0.
: decreasing ( shape -- shaped-array )
    dup shape-capacity iota <reversed> >array swap <shaped-array> ;
! Product of every shape component after the first one.
: row-length ( shape -- n )
    1 tail-slice product ; inline
! The first component of the shape.
161 : column-length ( shape -- n ) first ; inline
163 : each-row ( shaped-array quot -- )
164 [ [ underlying>> ] [ shape>> row-length <groups> ] bi ] dip
! Tuple with a single slot holding a shaped-array.
167 TUPLE: transposed shaped-array ;
169 : transposed-shape ( shaped-array -- shape )
! Tuple pairing a shaped-array with an index slot.
172 TUPLE: row-traverser shaped-array index ;
! Generic word taking one object and returning an index. No methods are
! visible in this part of the file.
174 GENERIC: next-index ( object -- index )
! Literal syntax sa{ ... }: the elements up to the closing } are parsed
! as a literal and converted with >shaped-array at parse time.
176 SYNTAX: sa{ \ } [ >shaped-array ] parse-literal ;
178 USE: prettyprint.custom
179 ! M: row-array pprint* shaped-array>array pprint* ;
180 ! M: col-array pprint* shaped-array>array flip pprint* ;
! Prettyprinter protocol methods for shaped-array.
! Delimiters used when printing: sa{ and }.
181 M: shaped-array pprint-delims drop \ sa{ \ } ;
! The printed contents are whatever shaped-array>array returns.
182 M: shaped-array >pprint-sequence shaped-array>array ;
! Print using the generic object printer, which consults the two
! methods above.
183 M: shaped-array pprint* pprint-object ;
! Always answers f for pprint-narrow?.
184 M: shaped-array pprint-narrow? drop f ;
! Error for an index sequence that is out of range for a shape; see
! shaped-bounds-check.
186 ERROR: shaped-bounds-error seq shape ;
! Check a multi-dimensional index against a shaped array.
!   seq    - one index per dimension
!   shaped - the array being indexed
! Each index must be strictly less than the matching shape component.
! Both inputs stay on the stack on success; otherwise they become the
! payload of shaped-bounds-error.
! The error is raised through its throw- word, like every other error
! in this file; the bare class word would not throw.
: shaped-bounds-check ( seq shaped -- seq shaped )
    2dup shape [ < ] 2all? [ throw-shaped-bounds-error ] unless ;
! Flat offset of a multi-dimensional index in row-major order (last
! dimension varies fastest).
!   seq   - one index per dimension
!   shape - the dimensions
!   i     - sum of index * stride over all dimensions
! The stride of a dimension is the product of all dimensions to its
! right. The shape is therefore reversed before the running product and
! the result reversed back: for shape { 2 3 } the strides are { 3 1 }.
! Accumulating from the left instead would yield { 2 1 } and map
! { 0 2 } and { 1 0 } to the same offset.
: calculate-row-major-index ( seq shape -- i )
    reverse 1 [ * ] accumulate nip reverse v* sum ;
! Flat offset of a multi-dimensional index in column-major order.
! Strides are the running product of the shape from the left, starting
! at 1; the offset is the sum of index * stride.
: calculate-column-major-index ( seq shape -- i )
    1 [ * ] accumulate nip
    [ * ] 2map sum ;
! Store obj at the multi-dimensional index seq, using row-major
! addressing. The index is bounds-checked first; the flat offset is
! then written into the array's underlying storage.
: set-shaped-row-major ( obj seq shaped -- )
    shaped-bounds-check
    [ shape calculate-row-major-index ] keep
    underlying>> set-nth ;
! Store obj at the multi-dimensional index seq, using column-major
! addressing. The index is bounds-checked first; the flat offset is
! then written into the array's underlying storage.
: set-shaped-column-major ( obj seq shaped -- )
    shaped-bounds-check
    [ shape calculate-column-major-index ] keep
    underlying>> set-nth ;
! True when the shape has exactly two components.
205 : 2d? ( shape -- ? ) length 2 = ;
! Error carrying a shaped array that was required to be two-dimensional;
! see check-2d.
206 ERROR: 2d-expected shaped ;
! Pass a shaped array through unchanged if its shape has exactly two
! components; otherwise throw 2d-expected with the array as payload.
! The error is raised through its throw- word, like every other error
! in this file; the bare class word would not throw.
: check-2d ( shaped -- shaped )
    dup shape>> 2d? [ throw-2d-expected ] unless ;
! True for a two-component coordinate whose components are equal;
! f for anything that is not two components long.
: diagonal? ( coord -- ? )
    dup 2d? [ first2 = ] [ drop f ] if ;
211 ! : definite? ( sa -- ? )
! Run quot on every element of the array's underlying storage, in
! storage order.
: shaped-each ( .. sa quot -- )
    swap underlying>> swap each ; inline
216 ! : set-shaped-where ( .. elt sa quot -- )
218 ! [ underlying>> [ length iota ] keep zip ]
220 ! ] dip '[ _ [ _ set- ] @ ] assoc-each ; inline
! Replace the array's underlying storage with the result of mapping
! quot over it; the same shaped-array object is returned.
: shaped-map! ( .. sa quot -- sa )
    [ map ] curry change-underlying ; inline
! Map quot over the underlying storage and wrap the result in a new
! shaped-array that reuses the input's shape. The input is unchanged.
: shaped-map ( .. sa quot -- sa' )
    [ [ underlying>> ] dip map ] 2keep
    drop shape>> <shaped-array> ; inline
! Give two shaped arrays shapes of equal length.
! The shorter shape is left-padded with 1 to the length of the longer
! one, and each array is rebuilt with shaped-like over its padded shape.
! The shapes are extracted from the arrays first: the final bi-curry@
! bi* step needs both arrays and both padded shapes on the stack.
: pad-shapes ( sa0 sa1 -- sa0' sa1' )
    2dup [ shape>> ] bi@
    2dup longer length '[ _ 1 pad-head ] bi@
    [ shaped-like ] bi-curry@ bi* ;
! Shape produced by combining two shaped arrays dimension by dimension.
! For each pair of components: 0 if either one is 0, otherwise the
! larger of the two.
! The shapes are read from the arrays before mapping, because
! broadcastable? passes the shaped arrays themselves. The zero branch
! yields 0 and the non-zero branch yields max; the opposite order would
! turn every pair of non-zero dimensions into 0.
: output-shape ( sa0 sa1 -- shape )
    [ shape>> ] bi@
    [ 2dup [ zero? ] either? [ 2drop 0 ] [ max ] if ] 2map ;
238 : broadcast-shape-matches? ( sa broadcast-shape -- ? )
240 { [ drop 1 = ] [ = ] } 2||
! True when both arrays' shapes satisfy broadcast-shape-matches?
! against the combined output-shape of the pair.
: broadcastable? ( sa0 sa1 -- ? )
    [ [ shape>> ] bi@ ] 2keep output-shape
    [ broadcast-shape-matches? ] curry both? ;
! Tuple with two slots, shaped and shape.
248 TUPLE: block-array shaped shape ;
250 : <block-array> ( underlying shape -- obj )
253 : iteration-indices ( shaped -- seq )
255 cartesian-product concat
256 [ dup first array? [ first2 suffix ] when ] map
! Map quot over { element index } pairs and store the results back.
!   shaped - array to update in place (returned)
!   quot   - called with a two-element pair: the underlying element and
!            the matching entry of iteration-indices for the shape
! The array is copied under the quotation with over so that, after the
! mapping, it is still available to receive the new underlying. The
! inner quotation is bracketed so that its closing ] has a partner.
: map-shaped-index ( shaped quot -- shaped )
    over [
        [ [ underlying>> ] [ shape>> iteration-indices ] bi zip ] dip map
    ] dip swap >>underlying ; inline
! n-by-n shaped array with 1 where the two index components are equal
! and 0 everywhere else.
: identity-matrix ( n -- shaped )
    dup 2array zeros
    [ second first2 = [ 1 ] [ 0 ] if ] map-shaped-index ;
! Apply quot to the elements whose first index component is greater
! than the second; other elements are left as they are.
: map-strict-lower ( shaped quot -- shaped )
    [ [ first2 first2 > ] dip when ] curry
    map-shaped-index ; inline
! Apply quot to the elements whose first index component is greater
! than or equal to the second; other elements are left as they are.
: map-lower ( shaped quot -- shaped )
    [ [ first2 first2 >= ] dip when ] curry
    map-shaped-index ; inline
! Apply quot to the elements whose first index component is less than
! the second; other elements are left as they are.
: map-strict-upper ( shaped quot -- shaped )
    [ [ first2 first2 < ] dip when ] curry
    map-shaped-index ; inline
! Apply quot to the elements whose first index component is less than
! or equal to the second; other elements are left as they are.
: map-upper ( shaped quot -- shaped )
    [ [ first2 first2 <= ] dip when ] curry
    map-shaped-index ; inline
! Apply quot to the elements whose two index components are equal;
! other elements are left as they are.
: map-diagonal ( shaped quot -- shaped )
    [ [ first2 first2 = ] dip when ] curry
    map-shaped-index ; inline
! Two-dimensional zero array of the given shape with obj stored at
! every position selected by map-upper. Throws via check-2d when the
! shape is not two-dimensional.
: upper ( shape obj -- shaped )
    [ zeros check-2d ] dip
    [ nip ] curry map-upper ;
! Two-dimensional zero array of the given shape with obj stored at
! every position selected by map-strict-upper. Throws via check-2d when
! the shape is not two-dimensional.
: strict-upper ( shape obj -- shaped )
    [ zeros check-2d ] dip
    [ nip ] curry map-strict-upper ;
! Two-dimensional zero array of the given shape with obj stored at
! every position selected by map-lower. Throws via check-2d when the
! shape is not two-dimensional.
: lower ( shape obj -- shaped )
    [ zeros check-2d ] dip
    [ nip ] curry map-lower ;
! Two-dimensional zero array of the given shape with obj stored at
! every position selected by map-strict-lower. Throws via check-2d when
! the shape is not two-dimensional.
: strict-lower ( shape obj -- shaped )
    [ zeros check-2d ] dip
    [ nip ] curry map-strict-lower ;