1 USING: accessors alien alien.c-types alien.complex
2 alien.data arrays byte-arrays combinators
3 combinators.short-circuit fry kernel locals macros math
4 math.blas.ffi math.blas.vectors math.blas.vectors.private
5 math.complex math.functions math.order functors words
6 sequences sequences.merged sequences.private shuffle
7 parser prettyprint.backend prettyprint.custom ascii
9 FROM: alien.c-types => float ;
10 SPECIALIZED-ARRAY: float
11 SPECIALIZED-ARRAY: double
12 SPECIALIZED-ARRAY: complex-float
13 SPECIALIZED-ARRAY: complex-double
14 IN: math.blas.matrices
! Common base class for the typed BLAS matrix classes.
! Storage is column-major: element (row, col) lives ld*col+row elements
! past the start of underlying. A true transpose slot makes the logical
! matrix the transpose of what is stored (width and height swap).
TUPLE: blas-matrix-base
    underlying ld rows cols transpose ;
18 : Mtransposed? ( matrix -- ? )
! Logical number of columns: the stored column count, or the stored row
! count when the matrix is flagged as transposed.
: Mwidth ( matrix -- width )
    dup Mtransposed? not [ cols>> ] [ rows>> ] if ; inline
! Logical number of rows: the stored row count, or the stored column
! count when the matrix is flagged as transposed.
: Mheight ( matrix -- height )
    dup Mtransposed? not [ rows>> ] [ cols>> ] if ; inline
! Destructive BLAS-backed primitives. As the y= / A= / C= output names
! say, the last input receives the result and is returned; the non-!
! wrappers elsewhere in this vocabulary clone that input first.
GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+beta*y )
GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
! The BLAS transpose-operation code for a matrix: the string "T" when
! its transpose slot is set, "N" otherwise.
: (blas-transpose) ( matrix -- string )
    transpose>> [ "T" ] [ "N" ] if ;

! Build a matrix of the same concrete class as exemplar from raw storage
! and layout values (methods are defined per type in the functor below).
GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
37 : (validate-gemv) ( A x y -- )
39 [ drop [ Mwidth ] [ length>> ] bi* = ]
40 [ nip [ Mheight ] [ length>> ] bi* = ]
42 [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
46 ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
62 : (validate-ger) ( x y A -- )
64 [ [ length>> ] [ drop ] [ Mheight ] tri* = ]
65 [ [ drop ] [ length>> ] [ Mwidth ] tri* = ]
67 [ "Mismatched vertices and matrix in vector outer product" throw ]
71 ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
83 A f >>transpose ; inline
85 : (validate-gemm) ( A B C -- )
87 [ [ Mwidth ] [ Mheight ] [ drop ] tri* = ]
88 [ [ Mheight ] [ drop ] [ Mheight ] tri* = ]
89 [ [ drop ] [ Mwidth ] [ Mwidth ] tri* = ]
91 [ "Mismatched matrices in matrix multiplication" throw ]
95 ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
111 C f >>transpose ; inline
! Turn a sequence of equal-length row sequences into column-major storage.
! The rows are interleaved with <merged> (first element of each row, then
! the second, ...) and converted with >c-array. ld and rows are both the
! row count, cols is the length of the first row, and transpose is f.
: (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
    [ dup <merged> ] dip call swap
    [ length dup ] [ first length ] bi
    f ; inline
118 ! XXX should do a dense clone
119 M: blas-matrix-base clone
125 [ element-type heap-size ]
126 } cleave * * memory>byte-array ]
134 ] keep (blas-matrix-like) ;
136 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
137 : <empty-matrix> ( rows cols exemplar -- matrix )
138 [ element-type heap-size * * <byte-array> ]
140 [ [ f ] dip (blas-matrix-like) ] 3tri ;
142 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
144 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
! Non-destructive form of n*V(*)Vconj+M!: the update is applied to a
! copy of A, so the caller's matrix is left unchanged.
: n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
    clone
    n*V(*)Vconj+M! ;
148 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
151 : n*M.V ( alpha A x -- alpha*A.x )
152 1.0 2over [ Mheight ] dip <empty-vector>
156 [ 1.0 ] 2dip n*M.V ; inline
158 : n*V(*)V ( alpha x y -- alpha*x(*)y )
159 2dup [ length>> ] bi@ pick <empty-matrix>
161 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
162 2dup [ length>> ] bi@ pick <empty-matrix>
! Outer product of x and y, i.e. n*V(*)V with alpha fixed at 1.0.
: V(*) ( x y -- x(*)y )
    1.0 -rot n*V(*)V ; inline
! Conjugated outer product of x and y, i.e. n*V(*)Vconj with alpha 1.0.
: V(*)conj ( x y -- x(*)yconj )
    1.0 -rot n*V(*)Vconj ; inline
! Scaled matrix product alpha*A.B in a newly allocated result.
! The result is Mheight(A) x Mwidth(B) and uses A as the type exemplar.
! It is accumulated into with beta = 1.0.
! NOTE(review): this presumably relies on <empty-matrix> returning
! zero-filled storage (it allocates via <byte-array>).
:: n*M.M ( alpha A B -- alpha*A.B )
    alpha A B 1.0
    A Mheight B Mwidth A <empty-matrix>
    n*M.M+n*M! ;
175 [ 1.0 ] 2dip n*M.M ; inline
177 :: (Msub) ( matrix row col height width -- data ld rows cols )
178 matrix ld>> col * row + matrix element-type heap-size *
179 matrix underlying>> <displaced-alien>
! A height x width window of matrix starting at logical (row, col).
! The window shares storage: its data is a displaced alien into matrix's
! underlying memory. For a transposed matrix the row/column arguments are
! swapped before addressing storage, and the result keeps the same
! transpose flag and concrete class as matrix.
! NOTE(review): no bounds check on row/col/height/width is visible here.
:: Msub ( matrix row col height width -- sub )
    matrix transpose>> :> flipped
    matrix
    flipped [ col row width height ] [ row col height width ] if
    (Msub)
    flipped matrix (blas-matrix-like) ;
! Virtual sequence over the stored rows or columns of a matrix.
! Slots, in constructor order:
!   parent        - the matrix being viewed
!   inc           - step between consecutive elements of one row/column
!   rowcol-length - number of elements in each row/column
!   rowcol-jump   - offset between consecutive rows/columns
!   length        - number of rows/columns in the sequence
TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
INSTANCE: blas-matrix-rowcol-sequence sequence
196 M: blas-matrix-rowcol-sequence length
198 M: blas-matrix-rowcol-sequence nth-unsafe
202 [ parent>> element-type heap-size ]
203 [ parent>> underlying>> ] tri
204 [ * * ] dip <displaced-alien>
209 } cleave (blas-vector-like) ;
! The stored (untransposed) columns of A. Each column has rows elements
! at step 1, successive columns start ld apart, and there are cols of them.
:: (Mcols) ( A -- columns )
    A 1 A rows>> A ld>> A cols>>
    <blas-matrix-rowcol-sequence> ;
! The stored (untransposed) rows of A. Each row has cols elements at
! step ld, successive rows start 1 apart, and there are rows of them.
:: (Mrows) ( A -- rows )
    A A ld>> A cols>> 1 A rows>>
    <blas-matrix-rowcol-sequence> ;
! Logical rows of A: the stored rows, or the stored columns when A is
! flagged as transposed.
: Mrows ( A -- rows )
    dup transpose>> not [ (Mrows) ] [ (Mcols) ] if ;
! Logical columns of A: the stored columns, or the stored rows when A is
! flagged as transposed.
: Mcols ( A -- cols )
    dup transpose>> not [ (Mcols) ] [ (Mrows) ] if ;
! Scale every element of A by n in place and return A.
! Each stored column is scaled with n*V!. The transpose flag is
! irrelevant here because every element is multiplied by the same factor.
:: n*M! ( n A -- A=n*A )
    A (Mcols) [ n swap n*V! drop ] each
    A ;
232 recip swap n*M ; inline
234 : Mtranspose ( matrix -- matrix^T )
240 } cleave ] keep (blas-matrix-like) ;
242 M: blas-matrix-base equal?
245 [ [ Mcols ] bi@ [ = ] 2all? ]
250 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
252 VECTOR IS ${TYPE}-blas-vector
253 <VECTOR> IS <${TYPE}-blas-vector>
259 MATRIX DEFINES-CLASS ${TYPE}-blas-matrix
260 <MATRIX> DEFINES <${TYPE}-blas-matrix>
261 >MATRIX DEFINES >${TYPE}-blas-matrix
265 XMATRIX{ DEFINES ${t}matrix{
269 TUPLE: MATRIX < blas-matrix-base ;
270 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
273 M: MATRIX element-type
275 M: MATRIX (blas-matrix-like)
277 M: VECTOR (blas-matrix-like)
279 M: MATRIX (blas-vector-like)
282 : >MATRIX ( arrays -- matrix )
283 [ TYPE >c-array underlying>> ] (>matrix) <MATRIX> ;
286 (prepare-gemv) [ XGEMV ] dip ;
288 (prepare-gemm) [ XGEMM ] dip ;
290 (prepare-ger) [ XGERU ] dip ;
291 M: MATRIX n*V(*)Vconj+M!
292 (prepare-ger) [ XGERC ] dip ;
294 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
296 M: MATRIX pprint-delims
297 drop \ XMATRIX{ \ } ;
! Instantiate the matrix functor for one element type and BLAS type prefix.
! The last two functor arguments are empty for real types and "U" / "C"
! for complex types. These are presumably the suffixes that select the
! unconjugated and conjugated routine variants such as XGERU / XGERC.
: define-real-blas-matrix ( TYPE T -- )
    "" dup (define-blas-matrix) ;
: define-complex-blas-matrix ( TYPE T -- )
    "U" "C" (define-blas-matrix) ;

float          "S" define-real-blas-matrix
double         "D" define-real-blas-matrix
complex-float  "C" define-complex-blas-matrix
complex-double "Z" define-complex-blas-matrix
! Prettyprint matrices as the sequence of their logical rows, using the
! generic object printer.
M: blas-matrix-base >pprint-sequence
    Mrows ;
M: blas-matrix-base pprint*
    pprint-object ;