! See http://factorcode.org/license.txt for BSD license.
USING: accessors arrays combinators.short-circuit constructors
fry grouping kernel math math.vectors sequences sequences.deep
-math.order parser ;
+math.order parser assocs math.combinatorics ;
IN: arrays.shaped
: flat? ( array -- ? ) [ sequence? ] any? not ; inline
M: uniform-shape check-underlying-shape
shape>> check-underlying-shape ;
-
+
M: sequence check-underlying-shape
2dup [ length ] [ shape-capacity ] bi*
= [ underlying-shape-mismatch ] unless ; inline
dup [ zero? ] any? [
2drop { }
] [
- [ rest-slice [ group ] each ] unless-empty
+ [ rest-slice reverse [ group ] each ] unless-empty
] if ;
: reshape ( shaped-array shape -- array )
M: shaped-array pprint* pprint-object ;
M: shaped-array pprint-narrow? drop f ;
+ERROR: shaped-bounds-error seq shape ;
+
+: shaped-bounds-check ( seq shaped -- seq shaped )
+ 2dup shape [ < ] 2all? [ shaped-bounds-error ] unless ;
+
+! Inefficient
+: calculate-row-major-index ( seq shape -- i )
+ 1 [ * ] accumulate nip reverse v* sum ;
+
+: calculate-column-major-index ( seq shape -- i )
+ 1 [ * ] accumulate nip v* sum ;
+
+: set-shaped-row-major ( obj seq shaped -- )
+ shaped-bounds-check [ shape calculate-row-major-index ] [ underlying>> ] bi set-nth ;
+
+: set-shaped-column-major ( obj seq shaped -- )
+ shaped-bounds-check [ shape calculate-column-major-index ] [ underlying>> ] bi set-nth ;
+
+! Matrices
+: 2d? ( shape -- ? ) length 2 = ;
+ERROR: 2d-expected shaped ;
+: check-2d ( shaped -- shaped ) dup shape>> 2d? [ 2d-expected ] unless ;
+
+: diagonal? ( coord -- ? ) { [ 2d? ] [ first2 = ] } 1&& ;
+
+! : definite? ( sa -- ? )
+
: shaped-each ( .. sa quot -- )
[ underlying>> ] dip each ; inline
+! : set-shaped-where ( .. elt sa quot -- )
+ ! [
+ ! [ underlying>> [ length iota ] keep zip ]
+ ! [ ] bi
+ ! ] dip '[ _ [ _ set- ] @ ] assoc-each ; inline
+
: shaped-map! ( .. sa quot -- sa )
'[ _ map ] change-underlying ; inline
pad-shapes
[ [ shape>> ] bi@ ] [ output-shape ] 2bi
'[ _ broadcast-shape-matches? ] both? ;
+
+TUPLE: block-array shaped shape ;
+
+: <block-array> ( underlying shape -- obj )
+ block-array boa ;
+
+: iteration-indices ( shaped -- seq )
+ [ iota ] [
+ cartesian-product concat
+ [ dup first array? [ first2 suffix ] when ] map
+ ] map-reduce ;
+
+: map-shaped-index ( shaped quot -- shaped )
+ over [
+ [ [ underlying>> ] [ shape>> iteration-indices ] bi zip ] dip map
+ ] dip swap >>underlying ; inline
+
+: identity-matrix ( n -- shaped )
+ dup 2array zeros [ second first2 = 1 0 ? ] map-shaped-index ;
+
+: map-strict-lower ( shaped quot -- shaped )
+ [ check-2d ] dip
+ '[ first2 first2 > _ when ] map-shaped-index ; inline
+
+: map-lower ( shaped quot -- shaped )
+ [ check-2d ] dip
+ '[ first2 first2 >= _ when ] map-shaped-index ; inline
+
+: map-strict-upper ( shaped quot -- shaped )
+ [ check-2d ] dip
+ '[ first2 first2 < _ when ] map-shaped-index ; inline
+
+: map-upper ( shaped quot -- shaped )
+ [ check-2d ] dip
+ '[ first2 first2 <= _ when ] map-shaped-index ; inline
+
+: map-diagonal ( shaped quot -- shaped )
+ [ check-2d ] dip
+ '[ first2 first2 = _ when ] map-shaped-index ; inline
+
+: upper ( shape obj -- shaped )
+ [ zeros check-2d ] dip '[ drop _ ] map-upper ;
+
+: strict-upper ( shape obj -- shaped )
+ [ zeros check-2d ] dip '[ drop _ ] map-strict-upper ;
+
+: lower ( shape obj -- shaped )
+ [ zeros check-2d ] dip '[ drop _ ] map-lower ;
+
+: strict-lower ( shape obj -- shaped )
+ [ zeros check-2d ] dip '[ drop _ ] map-strict-lower ;