1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.lib combinators.short-circuit fry kernel locals macros
3 math math.blas.cblas math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order multi-methods qualified
5 sequences sequences.merged sequences.private generalizations
! Tuple classes for BLAS matrices. The base class declares the slots
! data, ld, rows, cols and transpose. The four subclasses add no slots of
! their own; they give each element type (float, double, float-complex,
! double-complex) a distinct class that METHOD: definitions can dispatch on.
! NOTE(review): ld is presumably the BLAS leading dimension, i.e. the element
! stride between consecutive columns -- confirm against the offset
! arithmetic in (Msub).
10 TUPLE: blas-matrix-base data ld rows cols transpose ;
11 TUPLE: float-blas-matrix < blas-matrix-base ;
12 TUPLE: double-blas-matrix < blas-matrix-base ;
13 TUPLE: float-complex-blas-matrix < blas-matrix-base ;
14 TUPLE: double-complex-blas-matrix < blas-matrix-base ;
! Positional constructors, one per concrete matrix class. C: defines a
! by-order-of-arguments constructor, so each of these has the effect
! ( data ld rows cols transpose -- matrix ), following the slot order
! declared on blas-matrix-base.
16 C: <float-blas-matrix> float-blas-matrix
17 C: <double-blas-matrix> double-blas-matrix
18 C: <float-complex-blas-matrix> float-complex-blas-matrix
19 C: <double-complex-blas-matrix> double-complex-blas-matrix
21 METHOD: element-type { float-blas-matrix }
23 METHOD: element-type { double-blas-matrix }
25 METHOD: element-type { float-complex-blas-matrix }
27 METHOD: element-type { double-complex-blas-matrix }
30 : Mtransposed? ( matrix -- ? )
! Logical width (number of columns) of the matrix as the caller sees it:
! the rows slot when Mtransposed? answers true, otherwise the cols slot.
32 : Mwidth ( matrix -- width )
33 dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
! Logical height (number of rows) of the matrix as the caller sees it:
! the cols slot when Mtransposed? answers true, otherwise the rows slot.
34 : Mheight ( matrix -- height )
35 dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
! Translates the matrix's transpose slot into the flag handed to CBLAS:
! CblasTrans when the slot is set, CblasNoTrans when it is f.
39 : (blas-transpose) ( matrix -- integer )
40 transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
! Builds a matrix from raw slot values. The exemplar (last input) is used
! only to choose the concrete class of the result: the methods in this file
! dispatch on its class, drop it, and call the matching constructor.
42 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
! Matrix exemplars: the result has the same concrete class as the exemplar.
! Only the sixth input (the exemplar) is specialized; the five slot values
! are accepted as plain objects. The exemplar is dropped and the remaining
! five values go straight to the constructor.
44 METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
45 drop <float-blas-matrix> ;
46 METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
47 drop <double-blas-matrix> ;
48 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
49 drop <float-complex-blas-matrix> ;
50 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
51 drop <double-complex-blas-matrix> ;
! Vector exemplars: a blas-vector of a given element type selects the matrix
! class with the same element type. The exemplar is dropped and the five
! slot values are passed to that matrix constructor.
53 METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
54 drop <float-blas-matrix> ;
55 METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
56 drop <double-blas-matrix> ;
57 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
58 drop <float-complex-blas-matrix> ;
59 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
60 drop <double-complex-blas-matrix> ;
! Matrix exemplars for (blas-vector-like): a matrix of a given element type
! selects the vector constructor with the same element type. The exemplar
! (fourth input) is dropped and the remaining three values are passed to
! the vector constructor.
! NOTE(review): (blas-vector-like) and the <...-blas-vector> constructors are
! not defined in this file; presumably they come from math.blas.vectors /
! math.blas.vectors.private -- confirm the three-value argument order there.
62 METHOD: (blas-vector-like) { object object object float-blas-matrix }
63 drop <float-blas-vector> ;
64 METHOD: (blas-vector-like) { object object object double-blas-matrix }
65 drop <double-blas-vector> ;
66 METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
67 drop <float-complex-blas-vector> ;
68 METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
69 drop <double-complex-blas-vector> ;
71 : (validate-gemv) ( A x y -- )
73 [ drop [ Mwidth ] [ length>> ] bi* = ]
74 [ nip [ Mheight ] [ length>> ] bi* = ]
76 [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
78 :: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
94 : (validate-ger) ( x y A -- )
96 [ nip [ length>> ] [ Mheight ] bi* = ]
97 [ nipd [ length>> ] [ Mwidth ] bi* = ]
99 [ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
101 :: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
113 A f >>transpose ; inline
115 : (validate-gemm) ( A B C -- )
117 [ drop [ Mwidth ] [ Mheight ] bi* = ]
118 [ nip [ Mheight ] bi@ = ]
119 [ nipd [ Mwidth ] bi@ = ]
120 } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
122 :: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
123 A B C (validate-gemm)
138 C f >>transpose ; inline
! Shared helper for the >...-blas-matrix words. Produces the five slot
! values for a matrix constructor from a sequence of sequences:
!   c-array   -- arrays run through <merged>, then through the supplied
!                >c-array quotation
!   ld, rows  -- both the number of inner sequences (length dup)
!   cols      -- the length of the first inner sequence
!   transpose -- always f
! NOTE(review): first is applied unconditionally, so an empty arrays input
! presumably raises an error -- confirm whether callers rely on that.
! NOTE(review): <merged> presumably interleaves the inner sequences, which
! would lay the data out column by column -- confirm in sequences.merged.
140 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
141 '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
! Converts a sequence of sequences into a float-blas-matrix: (>matrix)
! computes the slot values using >c-float-array to build the data, and the
! float constructor wraps them.
145 : >float-blas-matrix ( arrays -- matrix )
146 [ >c-float-array ] (>matrix) <float-blas-matrix> ;
! Converts a sequence of sequences into a double-blas-matrix: (>matrix)
! computes the slot values using >c-double-array to build the data, and the
! double constructor wraps them.
147 : >double-blas-matrix ( arrays -- matrix )
148 [ >c-double-array ] (>matrix) <double-blas-matrix> ;
! Converts a sequence of sequences into a float-complex-blas-matrix. The
! merged data is first passed through (flatten-complex-sequence) and then
! through >c-float-array before the constructor wraps the slot values.
! NOTE(review): (flatten-complex-sequence) is not defined in this file;
! presumably it expands each complex number into its real and imaginary
! parts -- confirm in math.blas.vectors.private.
149 : >float-complex-blas-matrix ( arrays -- matrix )
150 [ (flatten-complex-sequence) >c-float-array ] (>matrix)
151 <float-complex-blas-matrix> ;
! Converts a sequence of sequences into a double-complex-blas-matrix. The
! merged data is first passed through (flatten-complex-sequence) and then
! through >c-double-array before the constructor wraps the slot values.
! NOTE(review): (flatten-complex-sequence) is not defined in this file;
! presumably it expands each complex number into its real and imaginary
! parts -- confirm in math.blas.vectors.private.
152 : >double-complex-blas-matrix ( arrays -- matrix )
153 [ (flatten-complex-sequence) >c-double-array ] (>matrix)
154 <double-complex-blas-matrix> ;
! Generic words for the BLAS operations implemented by the METHOD:
! definitions that follow. Going by the output names in the stack effects,
! the last input (y, A or C) receives the result:
!   n*M.V+n*V!      y = alpha*A.x + beta*y   (matrix-vector product)
!   n*V(*)V+M!      A = alpha*x(*)y + A      (outer product update)
!   n*V(*)Vconj+M!  A = alpha*x(*)yconj + A  (outer product with conjugated y)
!   n*M.M+n*M!      C = alpha*A.B + beta*C   (matrix-matrix product)
! The "b" in the first effect's output name stands for beta.
156 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
157 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
158 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
159 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
! Matrix-vector product methods, one per element type. Each passes a scalar
! conversion quotation to (prepare-gemv): [ ] (no conversion) for the real
! types, (>c-complex) or (>z-complex) for the complex types. It then runs the
! matching cblas_?gemv underneath the top stack item via dip. Per the
! declared effect of (prepare-gemv), that top item is y, so y is what the
! method leaves on the stack.
! The real methods require real alpha and beta; the complex methods accept
! any number.
161 METHOD: n*M.V+n*V! { real float-blas-matrix float-blas-vector real float-blas-vector }
162 [ ] (prepare-gemv) [ cblas_sgemv ] dip ;
163 METHOD: n*M.V+n*V! { real double-blas-matrix double-blas-vector real double-blas-vector }
164 [ ] (prepare-gemv) [ cblas_dgemv ] dip ;
165 METHOD: n*M.V+n*V! { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
166 [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
167 METHOD: n*M.V+n*V! { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
168 [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
! Outer-product update methods (no conjugation), one per element type.
! Each passes a scalar conversion quotation to (prepare-ger) and then runs
! the BLAS routine underneath the top stack item via dip. Per the declared
! effect of (prepare-ger), that top item is A, so A is what the method
! leaves on the stack.
! The real types call cblas_sger / cblas_dger; the complex types call the
! unconjugated variants cblas_cgeru / cblas_zgeru.
170 METHOD: n*V(*)V+M! { real float-blas-vector float-blas-vector float-blas-matrix }
171 [ ] (prepare-ger) [ cblas_sger ] dip ;
172 METHOD: n*V(*)V+M! { real double-blas-vector double-blas-vector double-blas-matrix }
173 [ ] (prepare-ger) [ cblas_dger ] dip ;
174 METHOD: n*V(*)V+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
175 [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
176 METHOD: n*V(*)V+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
177 [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
! Outer-product update methods with y conjugated, one per element type.
! The real methods call the same routines as n*V(*)V+M! (cblas_sger /
! cblas_dger). The complex methods call the conjugating variants
! cblas_cgerc / cblas_zgerc.
! As with n*V(*)V+M!, the BLAS call runs under the top stack item via dip,
! so the A left on top by (prepare-ger) is the result.
179 METHOD: n*V(*)Vconj+M! { real float-blas-vector float-blas-vector float-blas-matrix }
180 [ ] (prepare-ger) [ cblas_sger ] dip ;
181 METHOD: n*V(*)Vconj+M! { real double-blas-vector double-blas-vector double-blas-matrix }
182 [ ] (prepare-ger) [ cblas_dger ] dip ;
183 METHOD: n*V(*)Vconj+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
184 [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
185 METHOD: n*V(*)Vconj+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
186 [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
! Matrix-matrix product methods, one per element type. Each passes a scalar
! conversion quotation to (prepare-gemm): [ ] for the real types,
! (>c-complex) or (>z-complex) for the complex types. It then runs the
! matching cblas_?gemm underneath the top stack item via dip. Per the
! declared effect of (prepare-gemm), that top item is C, so C is what the
! method leaves on the stack.
188 METHOD: n*M.M+n*M! { real float-blas-matrix float-blas-matrix real float-blas-matrix }
189 [ ] (prepare-gemm) [ cblas_sgemm ] dip ;
190 METHOD: n*M.M+n*M! { real double-blas-matrix double-blas-matrix real double-blas-matrix }
191 [ ] (prepare-gemm) [ cblas_dgemm ] dip ;
192 METHOD: n*M.M+n*M! { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
193 [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
194 METHOD: n*M.M+n*M! { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
195 [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
197 ! XXX: should do a dense clone (this one copies the whole ld * cols buffer and keeps ld)
198 syntax:M: blas-matrix-base clone
201 { [ data>> ] [ ld>> ] [ cols>> ] [ element-type heap-size ] } cleave
202 * * memory>byte-array
203 ] [ { [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> ] } cleave ] bi
204 ] keep (blas-matrix-like) ;
206 ! XXX: try rounding the stride up to the next 128-bit boundary for better vectorization
207 : <empty-matrix> ( rows cols exemplar -- matrix )
208 [ element-type [ * ] dip <c-array> ]
210 [ f swap (blas-matrix-like) ] 3tri ;
212 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
214 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
! Non-destructive form of n*V(*)Vconj+M!: clones the top input (A) and
! applies the destructive word to the copy, so the caller's A is not the
! object passed to the update.
216 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
217 clone n*V(*)Vconj+M! ;
218 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
221 : n*M.V ( alpha A x -- alpha*A.x )
222 1.0 2over [ Mheight ] dip <empty-vector>
226 1.0 -rot n*M.V ; inline
228 : n*V(*)V ( alpha x y -- alpha*x(*)y )
229 2dup [ length>> ] bi@ pick <empty-matrix>
231 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
232 2dup [ length>> ] bi@ pick <empty-matrix>
! Outer product of x and y with no scaling: inserts 1.0 beneath x and y as
! alpha and delegates to n*V(*)V.
236 : V(*) ( x y -- x(*)y )
237 1.0 -rot n*V(*)V ; inline
! Outer product of x and conjugated y with no scaling: inserts 1.0 beneath
! x and y as alpha and delegates to n*V(*)Vconj.
238 : V(*)conj ( x y -- x(*)yconj )
239 1.0 -rot n*V(*)Vconj ; inline
! Scaled matrix product into a newly allocated result.
! Allocates a result matrix C with <empty-matrix>, sized Mheight of A by
! Mwidth of B, using B (fetched by pick) as the exemplar for the element
! type. It then calls n*M.M+n*M! with beta = 1.0 and that C.
! NOTE(review): since beta is 1.0, the result equals alpha*A.B only if
! <empty-matrix> returns zero-filled storage -- confirm that <c-array>
! zero-initializes.
240 : n*M.M ( alpha A B -- alpha*A.B )
241 2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
242 1.0 swap n*M.M+n*M! ;
245 1.0 -rot n*M.M ; inline
247 :: (Msub) ( matrix row col height width -- data ld rows cols )
248 matrix ld>> col * row + matrix element-type heap-size *
249 matrix data>> <displaced-alien>
! Extracts a sub-matrix of the given height and width starting at
! (row, col).
! 5 npick copies the matrix to the top of the stack and its transpose flag
! is read; 2keep saves both across the inner quotation. Inside, when the
! flag is set, row/col are swapped and height/width are swapped, so that
! (Msub) receives coordinates in storage order.
! The four values from (Msub) are then combined with the saved transpose
! flag and the original matrix as exemplar, so the result has the parent's
! class and transpose setting.
! NOTE(review): (Msub) builds its data with <displaced-alien> on the
! parent's data, so the sub-matrix presumably shares storage with the
! parent rather than copying -- confirm.
254 : Msub ( matrix row col height width -- sub )
255 5 npick dup transpose>>
256 [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
257 swap (blas-matrix-like) ;
! Virtual sequence over the rows or columns of a matrix (built by (Mcols)
! and (Mrows)). Slots, in constructor order:
!   parent        -- the matrix being viewed
!   inc           -- 1 for columns, the parent's ld for rows
!   rowcol-length -- the parent's rows for columns, its cols for rows
!   rowcol-jump   -- the parent's ld for columns, 1 for rows
!   length        -- the parent's cols for columns, its rows for rows
! C: defines the positional constructor; INSTANCE: adds the class to the
! sequence mixin.
259 TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
260 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
262 INSTANCE: blas-matrix-rowcol-sequence sequence
264 syntax:M: blas-matrix-rowcol-sequence length
266 syntax:M: blas-matrix-rowcol-sequence nth-unsafe
270 [ parent>> element-type heap-size ]
271 [ parent>> data>> ] tri
272 [ * * ] dip <displaced-alien>
277 } cleave (blas-vector-like) ;
! Sequence of the stored columns of A, ignoring the transpose flag.
! cleave supplies the constructor arguments in slot order:
! parent = A, inc = 1, rowcol-length = rows, rowcol-jump = ld, length = cols.
280 : (Mcols) ( A -- columns )
281 { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
282 <blas-matrix-rowcol-sequence> ;
! Sequence of the stored rows of A, ignoring the transpose flag.
! cleave supplies the constructor arguments in slot order:
! parent = A, inc = ld, rowcol-length = cols, rowcol-jump = 1, length = rows.
282 : (Mrows) ( A -- rows )
283 { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
284 <blas-matrix-rowcol-sequence> ;
! Logical rows of A: the stored columns when the transpose slot is set,
! otherwise the stored rows.
286 : Mrows ( A -- rows )
287 dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
! Logical columns of A: the stored rows when the transpose slot is set,
! otherwise the stored columns.
288 : Mcols ( A -- cols )
289 dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
! Scales A by n and returns A.
! Iterates over the stored columns from (Mcols), calling n*V! with n on
! each column vector and dropping its result; keep then leaves the original
! A on the stack. The transpose flag is not consulted, since every stored
! column is visited either way.
! NOTE(review): n*V! is not defined in this file; presumably it scales the
! vector in place (math.blas.vectors), which is what makes A itself change
! -- confirm.
291 : n*M! ( n A -- A=n*A )
292 [ (Mcols) [ n*V! drop ] with each ] keep ;
300 recip swap n*M ; inline
! Returns a matrix of the same class as the input with the transpose flag
! negated. The data, ld, rows and cols slot values are passed through
! unchanged, so the result refers to the same data object as the input
! rather than a copy.
302 : Mtranspose ( matrix -- matrix^T )
303 [ { [ data>> ] [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> not ] } cleave ] keep (blas-matrix-like) ;
305 syntax:M: blas-matrix-base equal?
308 [ [ Mcols ] bi@ [ = ] 2all? ]