! Copyright (C) 2012 Doug Coleman.
! See http://factorcode.org/license.txt for BSD license.
-USING: accessors arrays constructors grouping kernel math
-sequences math.vectors sequences.deep ;
+USING: accessors arrays combinators.short-circuit constructors
+fry grouping kernel math math.vectors sequences sequences.deep
+math.order parser ;
IN: arrays.shaped
: flat? ( array -- ? ) [ sequence? ] any? not ; inline
GENERIC: wrap-shape ( object -- shape )
M: integer wrap-shape
- 0 2array <uniform-shape> ;
+ 1array <uniform-shape> ;
M: sequence wrap-shape
dup all-equal? [
M: sequence shape array-replace wrap-shape ;
+: ndim ( array -- n ) shape length ;
+
+ERROR: no-negative-shape-components shape ;
+
+: check-shape-domain ( seq -- seq )
+ dup [ 0 < ] any? [ no-negative-shape-components ] when ;
+
GENERIC: shape-capacity ( shape -- n )
-M: sequence shape-capacity product ;
+M: sequence shape-capacity check-shape-domain product ;
-M: uniform-shape shape-capacity shape>> product ;
+M: uniform-shape shape-capacity
+ shape>> product ;
M: abnormal-shape shape-capacity
shape>> 0 swap [
GENERIC: >col-array ( array -- shaped-array )
M: sequence >shaped-array
- [ flatten ] [ shape ] bi <shaped-array> ;
+ [ { } flatten-as ] [ shape ] bi <shaped-array> ;
M: shaped-array >shaped-array ;
[ drop shape>> clone ] 2bi shaped-array boa ;
: shaped-array>array ( shaped-array -- array )
- [ underlying>> ] [ shape>> ] bi rest-slice [ group ] each ;
+ [ underlying>> ] [ shape>> ] bi
+ dup [ zero? ] any? [
+ 2drop { }
+ ] [
+ [ rest-slice [ group ] each ] unless-empty
+ ] if ;
: reshape ( shaped-array shape -- array )
check-underlying-shape >>shape ;
[ underlying>> clone ] dip <shaped-array> ;
: repeated-shaped ( shape element -- shaped-array )
- [ [ shape-capacity ] dip <array> ] [ drop ] 2bi <shaped-array> ;
+ [ [ shape-capacity ] dip <array> ]
+ [ drop 1 1 pad-head ] 2bi <shaped-array> ;
: zeros ( shape -- shaped-array ) 0 repeated-shaped ;
GENERIC: next-index ( object -- index )
+SYNTAX: sa{ \ } [ >shaped-array ] parse-literal ;
+
USE: prettyprint.custom
-M: shaped-array pprint* shaped-array>array pprint* ;
-M: row-array pprint* shaped-array>array pprint* ;
-M: col-array pprint* shaped-array>array flip pprint* ;
+! M: row-array pprint* shaped-array>array pprint* ;
+! M: col-array pprint* shaped-array>array flip pprint* ;
+M: shaped-array pprint-delims drop \ sa{ \ } ;
+M: shaped-array >pprint-sequence shaped-array>array ;
+M: shaped-array pprint* pprint-object ;
+M: shaped-array pprint-narrow? drop f ;
+
+: shaped-each ( .. sa quot -- )
+ [ underlying>> ] dip each ; inline
+
+: shaped-map! ( .. sa quot -- sa )
+ '[ _ map ] change-underlying ; inline
+
+: shaped-map ( .. sa quot -- sa' )
+ [ [ underlying>> ] dip map ]
+ [ drop shape>> ] 2bi <shaped-array> ; inline
+
+: pad-shapes ( sa0 sa1 -- sa0' sa1' )
+ 2dup [ shape>> ] bi@
+ 2dup longer length '[ _ 1 pad-head ] bi@
+ [ shaped-like ] bi-curry@ bi* ;
+
+: output-shape ( sa0 sa1 -- shape )
+ [ shape>> ] bi@
+ [ 2dup [ zero? ] either? [ max ] [ 2drop 0 ] if ] 2map ;
+
+: broadcast-shape-matches? ( sa broadcast-shape -- ? )
+ [
+ { [ drop 1 = ] [ = ] } 2||
+ ] 2all? ;
+
+: broadcastable? ( sa0 sa1 -- ? )
+ pad-shapes
+ [ [ shape>> ] bi@ ] [ output-shape ] 2bi
+ '[ _ broadcast-shape-matches? ] both? ;