]> gitweb.factorcode.org Git - factor.git/blob - extra/arrays/shaped/shaped.factor
factor: trim more using lists.
[factor.git] / extra / arrays / shaped / shaped.factor
1 ! Copyright (C) 2012 Doug Coleman.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: accessors arrays assocs combinators.short-circuit
4 grouping kernel math math.functions math.order math.vectors
5 parser prettyprint.custom sequences sequences.deep ;
6 IN: arrays.shaped
7
8 : flat? ( array -- ? ) [ sequence? ] none? ; inline
9
10 GENERIC: array-replace ( object -- shape )
11
12 M: f array-replace ;
13
14 M: object array-replace drop f ;
15
16 M: sequence array-replace
17     dup flat? [
18         length
19     ] [
20         [ array-replace ] map
21     ] if ;
22
23 TUPLE: uniform-shape shape ;
24 C: <uniform-shape> uniform-shape
25
26 TUPLE: abnormal-shape shape ;
27 C: <abnormal-shape> abnormal-shape
28
29 GENERIC: wrap-shape ( object -- shape )
30
31 M: integer wrap-shape
32     1array <uniform-shape> ;
33
34 M: sequence wrap-shape
35     dup all-equal? [
36         [ length ] [ first ] bi 2array <uniform-shape>
37     ] [
38         <abnormal-shape>
39     ] if ;
40
41 GENERIC: shape ( array -- shape )
42
43 M: sequence shape array-replace wrap-shape ;
44
45 : ndim ( array -- n ) shape length ;
46
47 ERROR: no-negative-shape-components shape ;
48
49 : check-shape-domain ( seq -- seq )
50     dup [ 0 < ] any? [ no-negative-shape-components ] when ;
51
52 GENERIC: shape-capacity ( shape -- n )
53
54 M: sequence shape-capacity check-shape-domain product ;
55
56 M: uniform-shape shape-capacity
57     shape>> product ;
58
59 M: abnormal-shape shape-capacity
60     shape>> 0 swap [
61         [ dup sequence? [ drop ] [ + ] if ] [ 1 + ] if*
62     ] deep-each ;
63
64 ERROR: underlying-shape-mismatch underlying shape ;
65
66 ERROR: no-abnormally-shaped-arrays underlying shape ;
67
68 GENERIC: check-underlying-shape ( underlying shape -- underlying shape )
69
70 M: abnormal-shape check-underlying-shape
71     no-abnormally-shaped-arrays ;
72
73 M: uniform-shape check-underlying-shape
74     shape>> check-underlying-shape ;
75
76 M: sequence check-underlying-shape
77     2dup [ length ] [ shape-capacity ] bi*
78     = [ underlying-shape-mismatch ] unless ; inline
79
80 ERROR: shape-mismatch shaped0 shaped1 ;
81
82 DEFER: >shaped-array
83
84 : check-shape ( shaped-array shaped-array -- shaped-array shaped-array )
85     [ >shaped-array ] bi@
86     2dup [ shape>> ] bi@
87     sequence= [ shape-mismatch ] unless ;
88
89 TUPLE: shaped-array underlying shape ;
90 TUPLE: row-array < shaped-array ;
91 TUPLE: col-array < shaped-array ;
92
93 M: shaped-array length underlying>> length ; inline
94
95 M: shaped-array shape shape>> ;
96
97 : make-shaped-array ( underlying shape class -- shaped-array )
98     [ check-underlying-shape ] dip new
99         swap >>shape
100         swap >>underlying ; inline
101
102 : <shaped-array> ( underlying shape -- shaped-array )
103     shaped-array make-shaped-array ; inline
104
105 : <row-array> ( underlying shape -- shaped-array )
106     row-array make-shaped-array ; inline
107
108 : <col-array> ( underlying shape -- shaped-array )
109     col-array make-shaped-array ; inline
110
111 GENERIC: >shaped-array ( array -- shaped-array )
112 GENERIC: >row-array ( array -- shaped-array )
113 GENERIC: >col-array ( array -- shaped-array )
114
115 M: sequence >shaped-array
116     [ { } flatten-as ] [ shape ] bi <shaped-array> ;
117
118 M: shaped-array >shaped-array ;
119
120 M: shaped-array >row-array
121     [ underlying>> ] [ shape>> ] bi <row-array> ;
122
123 M: shaped-array >col-array
124     [ underlying>> ] [ shape>> ] bi <col-array> ;
125
126 M: sequence >col-array
127     [ flatten ] [ shape ] bi <col-array> ;
128
129 : shaped-unary-op ( shaped quot -- )
130     [ >shaped-array ] dip
131     [ underlying>> ] prepose
132     [ shape>> clone ] bi shaped-array boa ; inline
133
134 : shaped-shaped-binary-op ( shaped0 shaped1 quot -- c )
135     [ check-shape ] dip
136     [ [ underlying>> ] bi@ ] prepose
137     [ drop shape>> clone ] 2bi shaped-array boa ; inline
138
139 : shaped+ ( a b -- c ) [ v+ ] shaped-shaped-binary-op ;
140 : shaped- ( a b -- c ) [ v- ] shaped-shaped-binary-op ;
141 : shaped*. ( a b -- c ) [ v* ] shaped-shaped-binary-op ;
142
143 : shaped*n ( a b -- c ) [ v*n ] curry shaped-unary-op ;
144 : n*shaped ( a b -- c ) swap shaped*n ;
145
146 : shaped-cos ( a -- b ) [ [ cos ] map ] shaped-unary-op ;
147 : shaped-sin ( a -- b ) [ [ sin ] map ] shaped-unary-op ;
148
149 : shaped-array>array ( shaped-array -- array )
150     [ underlying>> ] [ shape>> ] bi
151     dup [ zero? ] any? [
152         2drop { }
153     ] [
154         [ rest-slice reverse [ group ] each ] unless-empty
155     ] if ;
156
157 : reshape ( shaped-array shape -- array )
158     check-underlying-shape
159     [ >shaped-array ] dip >>shape ;
160
161 : shaped-like ( shaped-array shape -- array )
162     [ underlying>> clone ] dip <shaped-array> ;
163
164 : repeated-shaped ( shape element -- shaped-array )
165     [ [ shape-capacity ] dip <array> ]
166     [ drop 1 1 pad-head ] 2bi <shaped-array> ;
167
168 : zeros ( shape -- shaped-array ) 0 repeated-shaped ;
169
170 : ones ( shape -- shaped-array ) 1 repeated-shaped ;
171
172 : increasing ( shape -- shaped-array )
173     [ shape-capacity <iota> >array ] [ ] bi <shaped-array> ;
174
175 : decreasing ( shape -- shaped-array )
176     [ shape-capacity <iota> <reversed> >array ] [ ] bi <shaped-array> ;
177
178 : row-length ( shape -- n ) rest-slice product ; inline
179
180 : column-length ( shape -- n ) first ; inline
181
182 : each-row ( shaped-array quot -- )
183     [ [ underlying>> ] [ shape>> row-length <groups> ] bi ] dip
184     each ; inline
185
186 TUPLE: transposed shaped-array ;
187
188 : transposed-shape ( shaped-array -- shape )
189     shape>> <reversed> ;
190
191 TUPLE: row-traverser shaped-array index ;
192
193 GENERIC: next-index ( object -- index )
194
195 SYNTAX: sa{ \ } [ >shaped-array ] parse-literal ;
196
197 ! M: row-array pprint* shaped-array>array pprint* ;
198 ! M: col-array pprint* shaped-array>array flip pprint* ;
199 M: shaped-array pprint-delims drop \ sa{ \ } ;
200 M: shaped-array >pprint-sequence shaped-array>array ;
201 M: shaped-array pprint* pprint-object ;
202 M: shaped-array pprint-narrow? drop f ;
203
204 ERROR: shaped-bounds-error seq shape ;
205
206 : shaped-bounds-check ( seq shaped -- seq shaped )
207     2dup shape [ < ] 2all? [ shaped-bounds-error ] unless ;
208
209 ! Inefficient
210 : calculate-row-major-index ( seq shape -- i )
211     1 [ * ] accumulate nip reverse vdot ;
212
213 : calculate-column-major-index ( seq shape -- i )
214     1 [ * ] accumulate nip vdot ;
215
216 : get-shaped-row-major ( seq shaped -- elt )
217     shaped-bounds-check [ shape calculate-row-major-index ] [ underlying>> ] bi nth ;
218
219 : set-shaped-row-major ( obj seq shaped -- )
220     shaped-bounds-check [ shape calculate-row-major-index ] [ underlying>> ] bi set-nth ;
221
222 : get-shaped-column-major ( seq shaped -- elt )
223     shaped-bounds-check [ shape calculate-column-major-index ] [ underlying>> ] bi nth ;
224
225 : set-shaped-column-major ( obj seq shaped -- )
226     shaped-bounds-check [ shape calculate-column-major-index ] [ underlying>> ] bi set-nth ;
227
228 ! Matrices
229 : 2d? ( shape -- ? ) length 2 = ;
230 ERROR: 2d-expected shaped ;
231 : check-2d ( shaped -- shaped ) dup shape>> 2d? [ 2d-expected ] unless ;
232
233 : diagonal? ( coord -- ? ) { [ 2d? ] [ first2 = ] } 1&& ;
234
235 ! : definite? ( sa -- ? )
236
237 : shaped-each ( .. sa quot -- )
238     [ underlying>> ] dip each ; inline
239
240 ! : set-shaped-where ( .. elt sa quot -- )
241     ! [
242         ! [ underlying>> [ length <iota> ] keep zip ]
243         ! [ ] bi
244     ! ] dip '[ _ [ _ set- ] @ ] assoc-each ; inline
245
246 : shaped-map! ( .. sa quot -- sa )
247     '[ _ map ] change-underlying ; inline
248
249 : shaped-map ( .. sa quot -- sa' )
250     [ [ underlying>> ] dip map ]
251     [ drop shape>> ] 2bi <shaped-array> ; inline
252
253 : pad-shapes ( sa0 sa1 -- sa0' sa1' )
254     2dup [ shape>> ] bi@
255     2dup longer length '[ _ 1 pad-head ] bi@
256     [ shaped-like ] bi-curry@ bi* ;
257
258 : output-shape ( sa0 sa1 -- shape )
259     [ shape>> ] bi@
260     [ 2dup [ zero? ] either? [ max ] [ 2drop 0 ] if ] 2map ;
261
262 : broadcast-shape-matches? ( sa broadcast-shape -- ? )
263     [
264         { [ drop 1 = ] [ = ] } 2||
265     ] 2all? ;
266
267 : broadcastable? ( sa0 sa1 -- ? )
268     pad-shapes
269     [ [ shape>> ] bi@ ] [ output-shape ] 2bi
270     '[ _ broadcast-shape-matches? ] both? ;
271
272 TUPLE: block-array shaped shape ;
273
274 : <block-array> ( underlying shape -- obj )
275     block-array boa ;
276
277 : iteration-indices ( shaped -- seq )
278     [ <iota> ] [
279         cartesian-product concat
280         [ dup first array? [ first2 suffix ] when ] map
281     ] map-reduce ;
282
283 : map-shaped-index ( shaped quot -- shaped )
284     over [
285         [ [ underlying>> ] [ shape>> iteration-indices ] bi zip ] dip map
286     ] dip swap >>underlying ; inline
287
288 : identity-matrix ( n -- shaped )
289     dup 2array zeros [ second first2 = 1 0 ? ] map-shaped-index ;
290
291 : map-strict-lower ( shaped quot -- shaped )
292     [ check-2d ] dip
293     '[ first2 first2 > _ when ] map-shaped-index ; inline
294
295 : map-lower ( shaped quot -- shaped )
296     [ check-2d ] dip
297     '[ first2 first2 >= _ when ] map-shaped-index ; inline
298
299 : map-strict-upper ( shaped quot -- shaped )
300     [ check-2d ] dip
301     '[ first2 first2 < _ when ] map-shaped-index ; inline
302
303 : map-upper ( shaped quot -- shaped )
304     [ check-2d ] dip
305     '[ first2 first2 <= _ when ] map-shaped-index ; inline
306
307 : map-diagonal ( shaped quot -- shaped )
308     [ check-2d ] dip
309     '[ first2 first2 = _ when ] map-shaped-index ; inline
310
311 : upper ( shape obj -- shaped )
312     [ zeros check-2d ] dip '[ drop _ ] map-upper ;
313
314 : strict-upper ( shape obj -- shaped )
315     [ zeros check-2d ] dip '[ drop _ ] map-strict-upper ;
316
317 : lower ( shape obj -- shaped )
318     [ zeros check-2d ] dip '[ drop _ ] map-lower ;
319
320 : strict-lower ( shape obj -- shaped )
321     [ zeros check-2d ] dip '[ drop _ ] map-strict-lower ;