-USING: accessors alien alien.c-types arrays byte-arrays combinators
-combinators.short-circuit fry kernel locals macros
-math math.blas.cblas math.blas.vectors math.blas.vectors.private
-math.complex math.functions math.order functors words
-sequences sequences.merged sequences.private shuffle
-specialized-arrays.direct.float specialized-arrays.direct.double
-specialized-arrays.float specialized-arrays.double ;
+USING: accessors alien alien.c-types alien.complex alien.data
+ascii byte-arrays combinators combinators.short-circuit functors
+kernel math math.blas.ffi math.blas.vectors
+math.blas.vectors.private parser prettyprint.custom sequences
+sequences.merged sequences.private specialized-arrays ;
+FROM: alien.c-types => float ;
+SPECIALIZED-ARRAY: float
+SPECIALIZED-ARRAY: double
+SPECIALIZED-ARRAY: complex-float
+SPECIALIZED-ARRAY: complex-double
IN: math.blas.matrices
TUPLE: blas-matrix-base underlying ld rows cols transpose ;
<PRIVATE
: (blas-transpose) ( matrix -- integer )
- transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
+ transpose>> [ "T" ] [ "N" ] if ;
GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
unless ;
:: (prepare-gemv)
- ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
- y )
+ ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
+ y )
A x y (validate-gemv)
- CblasColMajor
A (blas-transpose)
A rows>>
A cols>>
- alpha >c-arg call
- A underlying>>
+ alpha
+ A
A ld>>
- x underlying>>
+ x
x inc>>
- beta >c-arg call
- y underlying>>
+ beta
+ y
y inc>>
y ; inline
: (validate-ger) ( x y A -- )
{
- [ nip [ length>> ] [ Mheight ] bi* = ]
- [ nipd [ length>> ] [ Mwidth ] bi* = ]
+ [ [ length>> ] [ drop ] [ Mheight ] tri* = ]
+ [ [ drop ] [ length>> ] [ Mwidth ] tri* = ]
} 3&&
[ "Mismatched vertices and matrix in vector outer product" throw ]
unless ;
:: (prepare-ger)
- ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld
- A )
+ ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
+ A )
x y A (validate-ger)
- CblasColMajor
A rows>>
A cols>>
- alpha >c-arg call
- x underlying>>
+ alpha
+ x
x inc>>
- y underlying>>
+ y
y inc>>
- A underlying>>
+ A
A ld>>
A f >>transpose ; inline
: (validate-gemm) ( A B C -- )
{
- [ drop [ Mwidth ] [ Mheight ] bi* = ]
- [ nip [ Mheight ] bi@ = ]
- [ nipd [ Mwidth ] bi@ = ]
+ [ [ Mwidth ] [ Mheight ] [ drop ] tri* = ]
+ [ [ Mheight ] [ drop ] [ Mheight ] tri* = ]
+ [ [ drop ] [ Mwidth ] [ Mwidth ] tri* = ]
} 3&&
[ "Mismatched matrices in matrix multiplication" throw ]
unless ;
:: (prepare-gemm)
- ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
- C )
+ ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
+ C )
A B C (validate-gemm)
- CblasColMajor
A (blas-transpose)
B (blas-transpose)
C rows>>
C cols>>
A Mwidth
- alpha >c-arg call
- A underlying>>
+ alpha
+ A
A ld>>
- B underlying>>
+ B
B ld>>
- beta >c-arg call
- C underlying>>
+ beta
+ C
C ld>>
C f >>transpose ; inline
! XXX should do a dense clone
M: blas-matrix-base clone
- [
+ [
[ {
[ underlying>> ]
[ ld>> ]
! XXX try rounding stride to next 128 bit bound for better vectorizin'
: <empty-matrix> ( rows cols exemplar -- matrix )
- [ element-type [ * ] dip <c-array> ]
+ [ element-type heap-size * * <byte-array> ]
[ 2drop ]
- [ f swap (blas-matrix-like) ] 3tri ;
+ [ [ f ] dip (blas-matrix-like) ] 3tri ;
: n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
clone n*M.V+n*V! ;
n*M.V+n*V! ; inline
: M.V ( A x -- A.x )
- 1.0 -rot n*M.V ; inline
+ [ 1.0 ] 2dip n*M.V ; inline
: n*V(*)V ( alpha x y -- alpha*x(*)y )
2dup [ length>> ] bi@ pick <empty-matrix>
n*V(*)Vconj+M! ;
: V(*) ( x y -- x(*)y )
- 1.0 -rot n*V(*)V ; inline
+ [ 1.0 ] 2dip n*V(*)V ; inline
: V(*)conj ( x y -- x(*)yconj )
- 1.0 -rot n*V(*)Vconj ; inline
+ [ 1.0 ] 2dip n*V(*)Vconj ; inline
: n*M.M ( alpha A B -- alpha*A.B )
- 2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
- 1.0 swap n*M.M+n*M! ;
+ 2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
+ [ 1.0 ] dip n*M.M+n*M! ;
: M. ( A B -- A.B )
- 1.0 -rot n*M.M ; inline
+ [ 1.0 ] 2dip n*M.M ; inline
:: (Msub) ( matrix row col height width -- data ld rows cols )
matrix ld>> col * row + matrix element-type heap-size *
M: blas-matrix-base equal?
{
+ [ and ]
[ [ Mwidth ] bi@ = ]
[ [ Mcols ] bi@ [ = ] 2all? ]
} 2&& ;
<<
-FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
+<FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
VECTOR IS ${TYPE}-blas-vector
<VECTOR> IS <${TYPE}-blas-vector>
->ARRAY IS >${TYPE}-array
-TYPE>ARG IS ${TYPE}>arg
-XGEMV IS cblas_${T}gemv
-XGEMM IS cblas_${T}gemm
-XGERU IS cblas_${T}ger${U}
-XGERC IS cblas_${T}ger${C}
-
-MATRIX DEFINES ${TYPE}-blas-matrix
+XGEMV IS ${T}GEMV
+XGEMM IS ${T}GEMM
+XGERU IS ${T}GER${U}
+XGERC IS ${T}GER${C}
+
+MATRIX DEFINES-CLASS ${TYPE}-blas-matrix
<MATRIX> DEFINES <${TYPE}-blas-matrix>
>MATRIX DEFINES >${TYPE}-blas-matrix
+t [ T >lower ]
+
+XMATRIX{ DEFINES ${t}matrix{
+
WHERE
TUPLE: MATRIX < blas-matrix-base ;
M: MATRIX element-type
drop TYPE ;
M: MATRIX (blas-matrix-like)
- drop <MATRIX> execute ;
+ drop <MATRIX> ;
M: VECTOR (blas-matrix-like)
- drop <MATRIX> execute ;
+ drop <MATRIX> ;
M: MATRIX (blas-vector-like)
- drop <VECTOR> execute ;
+ drop <VECTOR> ;
: >MATRIX ( arrays -- matrix )
- [ >ARRAY execute underlying>> ] (>matrix)
- <MATRIX> execute ;
+ [ TYPE >c-array underlying>> ] (>matrix) <MATRIX> ;
M: VECTOR n*M.V+n*V!
- [ TYPE>ARG execute ] (prepare-gemv)
- [ XGEMV execute ] dip ;
+ (prepare-gemv) [ XGEMV ] dip ;
M: MATRIX n*M.M+n*M!
- [ TYPE>ARG execute ] (prepare-gemm)
- [ XGEMM execute ] dip ;
+ (prepare-gemm) [ XGEMM ] dip ;
M: MATRIX n*V(*)V+M!
- [ TYPE>ARG execute ] (prepare-ger)
- [ XGERU execute ] dip ;
+ (prepare-ger) [ XGERU ] dip ;
M: MATRIX n*V(*)Vconj+M!
- [ TYPE>ARG execute ] (prepare-ger)
- [ XGERC execute ] dip ;
+ (prepare-ger) [ XGERC ] dip ;
-;FUNCTOR
+SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
+
+M: MATRIX pprint-delims
+ drop \ XMATRIX{ \ } ;
+
+;FUNCTOR>
: define-real-blas-matrix ( TYPE T -- )
"" "" (define-blas-matrix) ;
: define-complex-blas-matrix ( TYPE T -- )
- "u" "c" (define-blas-matrix) ;
+ "U" "C" (define-blas-matrix) ;
-"float" "s" define-real-blas-matrix
-"double" "d" define-real-blas-matrix
-"float-complex" "c" define-complex-blas-matrix
-"double-complex" "z" define-complex-blas-matrix
+float "S" define-real-blas-matrix
+double "D" define-real-blas-matrix
+complex-float "C" define-complex-blas-matrix
+complex-double "Z" define-complex-blas-matrix
>>
+
+M: blas-matrix-base >pprint-sequence Mrows ;
+M: blas-matrix-base pprint* pprint-object ;