-USING: accessors alien alien.c-types arrays byte-arrays combinators
-combinators.lib combinators.short-circuit fry kernel locals macros
-math math.blas.cblas math.blas.vectors math.blas.vectors.private
-math.complex math.functions math.order multi-methods qualified
-sequences sequences.merged sequences.private generalizations
-shuffle symbols speicalized-arrays.float specialized-arrays.double ;
-QUALIFIED: syntax
+USING: accessors alien alien.c-types alien.complex alien.data
+ascii byte-arrays combinators combinators.short-circuit functors
+kernel math math.blas.ffi math.blas.vectors
+math.blas.vectors.private parser prettyprint.custom sequences
+sequences.merged sequences.private specialized-arrays ;
+FROM: alien.c-types => float ;
+SPECIALIZED-ARRAY: float
+SPECIALIZED-ARRAY: double
+SPECIALIZED-ARRAY: complex-float
+SPECIALIZED-ARRAY: complex-double
IN: math.blas.matrices
-TUPLE: blas-matrix-base data ld rows cols transpose ;
-TUPLE: float-blas-matrix < blas-matrix-base ;
-TUPLE: double-blas-matrix < blas-matrix-base ;
-TUPLE: float-complex-blas-matrix < blas-matrix-base ;
-TUPLE: double-complex-blas-matrix < blas-matrix-base ;
-
-C: <float-blas-matrix> float-blas-matrix
-C: <double-blas-matrix> double-blas-matrix
-C: <float-complex-blas-matrix> float-complex-blas-matrix
-C: <double-complex-blas-matrix> double-complex-blas-matrix
-
-METHOD: element-type { float-blas-matrix }
- drop "float" ;
-METHOD: element-type { double-blas-matrix }
- drop "double" ;
-METHOD: element-type { float-complex-blas-matrix }
- drop "CBLAS_C" ;
-METHOD: element-type { double-complex-blas-matrix }
- drop "CBLAS_Z" ;
+! Base class shared by every concrete BLAS matrix class.
+!   underlying - element storage (read through displaced aliens elsewhere in this file)
+!   ld         - number of elements between the starts of consecutive stored columns
+!   rows, cols - stored dimensions
+!   transpose  - flag read by Mtransposed?
+TUPLE: blas-matrix-base underlying ld rows cols transpose ;
+! Returns the value of the matrix's transpose slot.
: Mtransposed? ( matrix -- ? )
transpose>> ; inline
+! Logical height of the matrix: the stored cols when the matrix is flagged
+! as transposed, the stored rows otherwise.
: Mheight ( matrix -- height )
dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
+! In-place operations. Each generic dispatches on its last input, and the
+! output name in the stack effect spells out what that input holds afterwards.
+GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+beta*y )
+GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
+GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
+GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
+
<PRIVATE
-: (blas-transpose) ( matrix -- integer )
- transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
+! Transposition flag handed to the BLAS routines: "T" when the matrix's
+! transpose slot is set, "N" otherwise. The result is a string, so the
+! stack effect names it as one.
+: (blas-transpose) ( matrix -- string )
+ transpose>> [ "T" ] [ "N" ] if ;
+! Builds a matrix from raw slot values; the class of the result is chosen by
+! the exemplar, which may be a matrix or a vector.
GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
-METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
- drop <float-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
- drop <double-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
- drop <float-complex-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
- drop <double-complex-blas-matrix> ;
-
-METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
- drop <float-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
- drop <double-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
- drop <float-complex-blas-matrix> ;
-METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
- drop <double-complex-blas-matrix> ;
-
-METHOD: (blas-vector-like) { object object object float-blas-matrix }
- drop <float-blas-vector> ;
-METHOD: (blas-vector-like) { object object object double-blas-matrix }
- drop <double-blas-vector> ;
-METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
- drop <float-complex-blas-vector> ;
-METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
- drop <double-complex-blas-vector> ;
-
+! Throws unless the width of A equals the length of x and the height of A
+! equals the length of y.
: (validate-gemv) ( A x y -- )
{
[ drop [ Mwidth ] [ length>> ] bi* = ]
[ nip [ Mheight ] [ length>> ] bi* = ]
} 3&&
- [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
+ [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
+ unless ;
-:: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
+! Validates the operands, then pushes the arguments for a GEMV call in order:
+! the transposition flag of A, the stored rows and cols of A, alpha, A and its
+! ld, x and its inc, beta, y and its inc. y is pushed once more at the end so
+! the caller can run the routine under it with dip and be left with y.
+:: (prepare-gemv)
+ ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
+ y )
A x y (validate-gemv)
- CblasColMajor
A (blas-transpose)
A rows>>
A cols>>
- alpha >c-arg call
- A data>>
+ alpha
+ A
A ld>>
- x data>>
+ x
x inc>>
- beta >c-arg call
- y data>>
+ beta
+ y
y inc>>
y ; inline
+! Throws unless the length of x equals the height of A and the length of y
+! equals the width of A. The operands are vectors, so the message says so.
: (validate-ger) ( x y A -- )
{
- [ nip [ length>> ] [ Mheight ] bi* = ]
- [ nipd [ length>> ] [ Mwidth ] bi* = ]
+ [ [ length>> ] [ drop ] [ Mheight ] tri* = ]
+ [ [ drop ] [ length>> ] [ Mwidth ] tri* = ]
} 3&&
- [ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
+ [ "Mismatched vectors and matrix in vector outer product" throw ]
+ unless ;
-:: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
+! Validates the operands, then pushes the arguments for a GER-family call in
+! order: the stored rows and cols of A, alpha, x and its inc, y and its inc,
+! A and its ld. The final value is A itself, after its transpose slot has been
+! set to f.
+:: (prepare-ger)
+ ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
+ A )
x y A (validate-ger)
- CblasColMajor
A rows>>
A cols>>
- alpha >c-arg call
- x data>>
+ alpha
+ x
x inc>>
- y data>>
+ y
y inc>>
- A data>>
+ A
A ld>>
A f >>transpose ; inline
+! Throws unless the width of A equals the height of B, the height of A equals
+! the height of C, and the width of B equals the width of C.
: (validate-gemm) ( A B C -- )
{
- [ drop [ Mwidth ] [ Mheight ] bi* = ]
- [ nip [ Mheight ] bi@ = ]
- [ nipd [ Mwidth ] bi@ = ]
- } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
+ [ [ Mwidth ] [ Mheight ] [ drop ] tri* = ]
+ [ [ Mheight ] [ drop ] [ Mheight ] tri* = ]
+ [ [ drop ] [ Mwidth ] [ Mwidth ] tri* = ]
+ } 3&&
+ [ "Mismatched matrices in matrix multiplication" throw ]
+ unless ;
-:: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
+! Validates the operands, then pushes the arguments for a GEMM call in order:
+! the transposition flags of A and B, the stored rows and cols of C, the width
+! of A, alpha, A and its ld, B and its ld, beta, C and its ld. The final value
+! is C itself, after its transpose slot has been set to f.
+:: (prepare-gemm)
+ ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
+ C )
A B C (validate-gemm)
- CblasColMajor
A (blas-transpose)
B (blas-transpose)
C rows>>
C cols>>
A Mwidth
- alpha >c-arg call
- A data>>
+ alpha
+ A
A ld>>
- B data>>
+ B
B ld>>
- beta >c-arg call
- C data>>
+ beta
+ C
C ld>>
C f >>transpose ; inline
PRIVATE>
-: >float-blas-matrix ( arrays -- matrix )
- [ >float-array underlying>> ] (>matrix) <float-blas-matrix> ;
-: >double-blas-matrix ( arrays -- matrix )
- [ >double-array underlying>> ] (>matrix) <double-blas-matrix> ;
-: >float-complex-blas-matrix ( arrays -- matrix )
- [ (flatten-complex-sequence) >float-array underlying>> ] (>matrix)
- <float-complex-blas-matrix> ;
-: >double-complex-blas-matrix ( arrays -- matrix )
- [ (flatten-complex-sequence) >double-array underlying>> ] (>matrix)
- <double-complex-blas-matrix> ;
-
-GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
-GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
-GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
-GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
-
-METHOD: n*M.V+n*V! { real float-blas-matrix float-blas-vector real float-blas-vector }
- [ ] (prepare-gemv) [ cblas_sgemv ] dip ;
-METHOD: n*M.V+n*V! { real double-blas-matrix double-blas-vector real double-blas-vector }
- [ ] (prepare-gemv) [ cblas_dgemv ] dip ;
-METHOD: n*M.V+n*V! { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
- [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
-METHOD: n*M.V+n*V! { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
- [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
-
-METHOD: n*V(*)V+M! { real float-blas-vector float-blas-vector float-blas-matrix }
- [ ] (prepare-ger) [ cblas_sger ] dip ;
-METHOD: n*V(*)V+M! { real double-blas-vector double-blas-vector double-blas-matrix }
- [ ] (prepare-ger) [ cblas_dger ] dip ;
-METHOD: n*V(*)V+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
- [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
-METHOD: n*V(*)V+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
- [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
-
-METHOD: n*V(*)Vconj+M! { real float-blas-vector float-blas-vector float-blas-matrix }
- [ ] (prepare-ger) [ cblas_sger ] dip ;
-METHOD: n*V(*)Vconj+M! { real double-blas-vector double-blas-vector double-blas-matrix }
- [ ] (prepare-ger) [ cblas_dger ] dip ;
-METHOD: n*V(*)Vconj+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
- [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
-METHOD: n*V(*)Vconj+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
- [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
-
-METHOD: n*M.M+n*M! { real float-blas-matrix float-blas-matrix real float-blas-matrix }
- [ ] (prepare-gemm) [ cblas_sgemm ] dip ;
-METHOD: n*M.M+n*M! { real double-blas-matrix double-blas-matrix real double-blas-matrix }
- [ ] (prepare-gemm) [ cblas_dgemm ] dip ;
-METHOD: n*M.M+n*M! { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
- [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
-METHOD: n*M.M+n*M! { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
- [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
-
! XXX should do a dense clone
-syntax:M: blas-matrix-base clone
- [
- [
- { [ data>> ] [ ld>> ] [ cols>> ] [ element-type heap-size ] } cleave
- * * memory>byte-array
- ] [ { [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> ] } cleave ] bi
+! Copies ld * cols * element-size bytes starting at the underlying storage
+! into a new byte array, then builds a matrix of the same class around the
+! copy with the original ld, rows, cols and transpose values.
+M: blas-matrix-base clone
+ [
+ [ {
+ [ underlying>> ]
+ [ ld>> ]
+ [ cols>> ]
+ [ element-type heap-size ]
+ } cleave * * memory>byte-array ]
+ [ {
+ [ ld>> ]
+ [ rows>> ]
+ [ cols>> ]
+ [ transpose>> ]
+ } cleave ]
+ bi
] keep (blas-matrix-like) ;
! XXX try rounding stride to next 128 bit bound for better vectorizin'
+! Allocates a byte array of rows * cols * element-size bytes (element type
+! taken from the exemplar) and wraps it as a matrix like the exemplar, with
+! ld equal to rows and the transpose slot set to f.
: <empty-matrix> ( rows cols exemplar -- matrix )
- [ element-type [ * ] dip <c-array> ]
+ [ element-type heap-size * * <byte-array> ]
[ 2drop ]
- [ f swap (blas-matrix-like) ] 3tri ;
+ [ [ f ] dip (blas-matrix-like) ] 3tri ;
-: n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
+! Non-destructive form of n*M.V+n*V!: the operation runs on a clone of y.
+! The output name uses beta, matching the input it refers to.
+: n*M.V+n*V ( alpha A x beta y -- alpha*A.x+beta*y )
clone n*M.V+n*V! ;
n*M.V+n*V! ; inline
+! Matrix-vector product: n*M.V with the scale factor fixed at 1.0.
: M.V ( A x -- A.x )
- 1.0 -rot n*M.V ; inline
+ [ 1.0 ] 2dip n*M.V ; inline
+! Scaled outer product into a freshly allocated matrix sized
+! (length of x) by (length of y), created with y as the exemplar.
+! NOTE(review): as shown here the body ends in the conjugating word
+! n*V(*)Vconj+M!, and n*V(*)Vconj (used by V(*)conj) has no visible
+! definition, so lines may be missing from this view between the allocation
+! and the final call. Confirm against the full file that this word ends in
+! n*V(*)V+M!.
: n*V(*)V ( alpha x y -- alpha*x(*)y )
2dup [ length>> ] bi@ pick <empty-matrix>
n*V(*)Vconj+M! ;
+! Outer product: n*V(*)V with the scale factor fixed at 1.0.
: V(*) ( x y -- x(*)y )
- 1.0 -rot n*V(*)V ; inline
+ [ 1.0 ] 2dip n*V(*)V ; inline
+! Conjugating outer product: n*V(*)Vconj with the scale factor fixed at 1.0.
: V(*)conj ( x y -- x(*)yconj )
- 1.0 -rot n*V(*)Vconj ; inline
+ [ 1.0 ] 2dip n*V(*)Vconj ; inline
+! Scaled matrix product into a freshly allocated matrix sized
+! (height of A) by (width of B), created with B as the exemplar; the new
+! matrix is passed to n*M.M+n*M! as C with beta set to 1.0.
: n*M.M ( alpha A B -- alpha*A.B )
- 2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
- 1.0 swap n*M.M+n*M! ;
+ 2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
+ [ 1.0 ] dip n*M.M+n*M! ;
+! Matrix product: n*M.M with the scale factor fixed at 1.0.
: M. ( A B -- A.B )
- 1.0 -rot n*M.M ; inline
+ [ 1.0 ] 2dip n*M.M ; inline
+! Computes the raw slot values for a sub-matrix view: a displaced alien at
+! byte offset (ld * col + row) * element-size into the parent's storage,
+! followed by the parent's ld and the requested height and width.
:: (Msub) ( matrix row col height width -- data ld rows cols )
matrix ld>> col * row + matrix element-type heap-size *
- matrix data>> <displaced-alien>
+ matrix underlying>> <displaced-alien>
matrix ld>>
height
width ;
-: Msub ( matrix row col height width -- sub )
- 5 npick dup transpose>>
- [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
- swap (blas-matrix-like) ;
+! Sub-matrix view of the given matrix. When the matrix is flagged as
+! transposed, row/col and height/width are exchanged before the view is
+! computed; the result has the same class and transpose flag as the parent.
+:: Msub ( matrix row col height width -- sub )
+ matrix
+ matrix transpose>>
+ [ col row width height ] [ row col height width ] if
+ (Msub)
+ matrix [ transpose>> ] keep (blas-matrix-like) ;
-TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
+! Virtual sequence over the rows or columns of a parent matrix.
+!   parent        - the matrix being viewed
+!   rowcol-length - number of elements in each row/column
+!   rowcol-jump   - element offset between the starts of consecutive
+!                   rows/columns (scaled by index and element size in nth-unsafe)
+!   length        - number of rows/columns in the sequence
+!   inc           - presumably the element stride inside one row/column;
+!                   its use is not visible in this view (TODO confirm)
+TUPLE: blas-matrix-rowcol-sequence
+ parent inc rowcol-length rowcol-jump length ;
C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
INSTANCE: blas-matrix-rowcol-sequence sequence
-syntax:M: blas-matrix-rowcol-sequence length
+! The sequence length is stored directly in the length slot.
+M: blas-matrix-rowcol-sequence length
length>> ;
-syntax:M: blas-matrix-rowcol-sequence nth-unsafe
+! Returns the n-th row/column as a vector view whose data is a displaced
+! alien at byte offset n * rowcol-jump * element-size into the parent's
+! underlying storage.
+! NOTE(review): only the data pointer and rowcol-length are visibly produced
+! before (blas-vector-like) is called; any further inputs it needs may come
+! from lines not shown in this view. Confirm against the full file.
+M: blas-matrix-rowcol-sequence nth-unsafe
{
[
[ rowcol-jump>> ]
[ parent>> element-type heap-size ]
- [ parent>> data>> ] tri
+ [ parent>> underlying>> ] tri
[ * * ] dip <displaced-alien>
]
[ rowcol-length>> ]
} cleave (blas-vector-like) ;
+! Sequence over the stored columns of A: inc 1, each column rows elements
+! long, consecutive columns ld elements apart, cols of them in total.
: (Mcols) ( A -- columns )
- { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
- <blas-matrix-rowcol-sequence> ;
+ { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
+ cleave <blas-matrix-rowcol-sequence> ;
+! Sequence over the stored rows of A: inc ld, each row cols elements long,
+! consecutive rows 1 element apart, rows of them in total.
: (Mrows) ( A -- rows )
- { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
- <blas-matrix-rowcol-sequence> ;
+ { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
+ cleave <blas-matrix-rowcol-sequence> ;
+! Logical rows of A: the stored columns when A is flagged as transposed,
+! the stored rows otherwise.
: Mrows ( A -- rows )
dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
recip swap n*M ; inline
+! Returns a matrix of the same class that reuses the original's underlying
+! storage, ld, rows and cols, with the transpose flag negated.
: Mtranspose ( matrix -- matrix^T )
- [ { [ data>> ] [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> not ] } cleave ] keep (blas-matrix-like) ;
-
-syntax:M: blas-matrix-base equal?
+ [ {
+ [ underlying>> ]
+ [ ld>> ] [ rows>> ]
+ [ cols>> ]
+ [ transpose>> not ]
+ } cleave ] keep (blas-matrix-like) ;
+
+! Two matrices are equal when both operands are BLAS matrices, their logical
+! widths match, and their logical columns are pairwise equal. The class check
+! comes first so that comparing a matrix with f or with any other kind of
+! object answers f instead of failing inside Mwidth.
+M: blas-matrix-base equal?
{
+ [ [ blas-matrix-base? ] bi@ and ]
[ [ Mwidth ] bi@ = ]
[ [ Mcols ] bi@ [ = ] 2all? ]
} 2&& ;
+<<
+
+! Functor that defines one concrete matrix class per element type.
+!   TYPE - element C type; also the prefix of the generated names
+!   T    - routine prefix letter, used to look up ${T}GEMV / ${T}GEMM
+!   U, C - suffixes appended to ${T}GER for the plain and the conjugating
+!          outer-product routines
+! For each instantiation it defines the tuple class TYPE-blas-matrix, its boa
+! constructor, the conversion word >TYPE-blas-matrix, the literal syntax
+! ${t}matrix{ (t is T in lower case), element-type and the like-methods, and
+! the methods of the four in-place generics.
+<FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
+
+VECTOR IS ${TYPE}-blas-vector
+<VECTOR> IS <${TYPE}-blas-vector>
+XGEMV IS ${T}GEMV
+XGEMM IS ${T}GEMM
+XGERU IS ${T}GER${U}
+XGERC IS ${T}GER${C}
+
+MATRIX DEFINES-CLASS ${TYPE}-blas-matrix
+<MATRIX> DEFINES <${TYPE}-blas-matrix>
+>MATRIX DEFINES >${TYPE}-blas-matrix
+
+t [ T >lower ]
+
+XMATRIX{ DEFINES ${t}matrix{
+
+WHERE
+
+TUPLE: MATRIX < blas-matrix-base ;
+: <MATRIX> ( underlying ld rows cols transpose -- matrix )
+ MATRIX boa ; inline
+
+M: MATRIX element-type
+ drop TYPE ;
+M: MATRIX (blas-matrix-like)
+ drop <MATRIX> ;
+M: VECTOR (blas-matrix-like)
+ drop <MATRIX> ;
+M: MATRIX (blas-vector-like)
+ drop <VECTOR> ;
+
+: >MATRIX ( arrays -- matrix )
+ [ TYPE >c-array underlying>> ] (>matrix) <MATRIX> ;
+
+M: VECTOR n*M.V+n*V!
+ (prepare-gemv) [ XGEMV ] dip ;
+M: MATRIX n*M.M+n*M!
+ (prepare-gemm) [ XGEMM ] dip ;
+M: MATRIX n*V(*)V+M!
+ (prepare-ger) [ XGERU ] dip ;
+M: MATRIX n*V(*)Vconj+M!
+ (prepare-ger) [ XGERC ] dip ;
+
+SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
+
+M: MATRIX pprint-delims
+ drop \ XMATRIX{ \ } ;
+
+;FUNCTOR>
+
+
+! Instantiates the functor with empty GER suffixes, so the plain and the
+! conjugating outer product both resolve to the same ${T}GER routine.
+: define-real-blas-matrix ( TYPE T -- )
+ "" "" (define-blas-matrix) ;
+! Instantiates the functor with GER suffixes "U" and "C", so the plain outer
+! product resolves to ${T}GERU and the conjugating one to ${T}GERC.
+: define-complex-blas-matrix ( TYPE T -- )
+ "U" "C" (define-blas-matrix) ;
+
+! Define the four concrete matrix classes: float, double, complex-float and
+! complex-double, with routine prefixes S, D, C and Z respectively.
+float "S" define-real-blas-matrix
+double "D" define-real-blas-matrix
+complex-float "C" define-complex-blas-matrix
+complex-double "Z" define-complex-blas-matrix
+
+>>
+
+! The prettyprinter is given the matrix's logical rows as its elements.
+M: blas-matrix-base >pprint-sequence Mrows ;
+! Printing a matrix is delegated to pprint-object.
+M: blas-matrix-base pprint* pprint-object ;