]> gitweb.factorcode.org Git - factor.git/blob - extra/math/blas/matrices/matrices.factor
Merge branch 'master' of git://repo.or.cz/factor/jcg
[factor.git] / extra / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.lib combinators.short-circuit fry kernel locals macros
3 math math.blas.cblas math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order multi-methods qualified
5 sequences sequences.merged sequences.private shuffle symbols ;
6 QUALIFIED: syntax
7 IN: math.blas.matrices
8
9 TUPLE: blas-matrix-base data ld rows cols transpose ;
10 TUPLE: float-blas-matrix < blas-matrix-base ;
11 TUPLE: double-blas-matrix < blas-matrix-base ;
12 TUPLE: float-complex-blas-matrix < blas-matrix-base ;
13 TUPLE: double-complex-blas-matrix < blas-matrix-base ;
14
15 C: <float-blas-matrix> float-blas-matrix
16 C: <double-blas-matrix> double-blas-matrix
17 C: <float-complex-blas-matrix> float-complex-blas-matrix
18 C: <double-complex-blas-matrix> double-complex-blas-matrix
19
20 METHOD: element-type { float-blas-matrix }
21     drop "float" ;
22 METHOD: element-type { double-blas-matrix }
23     drop "double" ;
24 METHOD: element-type { float-complex-blas-matrix }
25     drop "CBLAS_C" ;
26 METHOD: element-type { double-complex-blas-matrix }
27     drop "CBLAS_Z" ;
28
29 : Mtransposed? ( matrix -- ? )
30     transpose>> ; inline
31 : Mwidth ( matrix -- width )
32     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
33 : Mheight ( matrix -- height )
34     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
35
36 <PRIVATE
37
38 : (blas-transpose) ( matrix -- integer )
39     transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
40
41 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
42
43 METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
44     drop <float-blas-matrix> ;
45 METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
46     drop <double-blas-matrix> ;
47 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
48     drop <float-complex-blas-matrix> ;
49 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
50     drop <double-complex-blas-matrix> ;
51
52 METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
53     drop <float-blas-matrix> ;
54 METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
55     drop <double-blas-matrix> ;
56 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
57     drop <float-complex-blas-matrix> ;
58 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
59     drop <double-complex-blas-matrix> ;
60
61 METHOD: (blas-vector-like) { object object object float-blas-matrix }
62     drop <float-blas-vector> ;
63 METHOD: (blas-vector-like) { object object object double-blas-matrix }
64     drop <double-blas-vector> ;
65 METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
66     drop <float-complex-blas-vector> ;
67 METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
68     drop <double-complex-blas-vector> ;
69
70 : (validate-gemv) ( A x y -- )
71     {
72         [ drop [ Mwidth  ] [ length>> ] bi* = ]
73         [ nip  [ Mheight ] [ length>> ] bi* = ]
74     } 3&&
75     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
76
77 :: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
78     A x y (validate-gemv)
79     CblasColMajor
80     A (blas-transpose)
81     A rows>>
82     A cols>>
83     alpha >c-arg call
84     A data>>
85     A ld>>
86     x data>>
87     x inc>>
88     beta >c-arg call
89     y data>>
90     y inc>>
91     y ; inline
92
93 : (validate-ger) ( x y A -- )
94     {
95         [ nip  [ length>> ] [ Mheight ] bi* = ]
96         [ nipd [ length>> ] [ Mwidth  ] bi* = ]
97     } 3&&
98     [ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
99
100 :: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
101     x y A (validate-ger)
102     CblasColMajor
103     A rows>>
104     A cols>>
105     alpha >c-arg call
106     x data>>
107     x inc>>
108     y data>>
109     y inc>>
110     A data>>
111     A ld>>
112     A f >>transpose ; inline
113
114 : (validate-gemm) ( A B C -- )
115     {
116         [ drop [ Mwidth  ] [ Mheight ] bi* = ]
117         [ nip  [ Mheight ] bi@ = ]
118         [ nipd [ Mwidth  ] bi@ = ]
119     } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
120
121 :: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
122     A B C (validate-gemm)
123     CblasColMajor
124     A (blas-transpose)
125     B (blas-transpose)
126     C rows>>
127     C cols>>
128     A Mwidth
129     alpha >c-arg call
130     A data>>
131     A ld>>
132     B data>>
133     B ld>>
134     beta >c-arg call
135     C data>>
136     C ld>>
137     C f >>transpose ; inline
138
139 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
140     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
141
142 PRIVATE>
143
144 : >float-blas-matrix ( arrays -- matrix )
145     [ >c-float-array ] (>matrix) <float-blas-matrix> ;
146 : >double-blas-matrix ( arrays -- matrix )
147     [ >c-double-array ] (>matrix) <double-blas-matrix> ;
148 : >float-complex-blas-matrix ( arrays -- matrix )
149     [ (flatten-complex-sequence) >c-float-array ] (>matrix)
150     <float-complex-blas-matrix> ;
151 : >double-complex-blas-matrix ( arrays -- matrix )
152     [ (flatten-complex-sequence) >c-double-array ] (>matrix)
153     <double-complex-blas-matrix> ;
154
155 GENERIC: n*M.V+n*V-in-place ( alpha A x beta y -- y=alpha*A.x+b*y )
156 GENERIC: n*V(*)V+M-in-place ( alpha x y A -- A=alpha*x(*)y+A )
157 GENERIC: n*V(*)Vconj+M-in-place ( alpha x y A -- A=alpha*x(*)yconj+A )
158 GENERIC: n*M.M+n*M-in-place ( alpha A B beta C -- C=alpha*A.B+beta*C )
159
160 METHOD: n*M.V+n*V-in-place { real float-blas-matrix float-blas-vector real float-blas-vector }
161     [ ] (prepare-gemv) [ cblas_sgemv ] dip ;
162 METHOD: n*M.V+n*V-in-place { real double-blas-matrix double-blas-vector real double-blas-vector }
163     [ ] (prepare-gemv) [ cblas_dgemv ] dip ;
164 METHOD: n*M.V+n*V-in-place { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
165     [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
166 METHOD: n*M.V+n*V-in-place { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
167     [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
168
169 METHOD: n*V(*)V+M-in-place { real float-blas-vector float-blas-vector float-blas-matrix }
170     [ ] (prepare-ger) [ cblas_sger ] dip ;
171 METHOD: n*V(*)V+M-in-place { real double-blas-vector double-blas-vector double-blas-matrix }
172     [ ] (prepare-ger) [ cblas_dger ] dip ;
173 METHOD: n*V(*)V+M-in-place { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
174     [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
175 METHOD: n*V(*)V+M-in-place { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
176     [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
177
178 METHOD: n*V(*)Vconj+M-in-place { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
179     [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
180 METHOD: n*V(*)Vconj+M-in-place { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
181     [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
182
183 METHOD: n*M.M+n*M-in-place { real float-blas-matrix float-blas-matrix real float-blas-matrix }
184     [ ] (prepare-gemm) [ cblas_sgemm ] dip ;
185 METHOD: n*M.M+n*M-in-place { real double-blas-matrix double-blas-matrix real double-blas-matrix }
186     [ ] (prepare-gemm) [ cblas_dgemm ] dip ;
187 METHOD: n*M.M+n*M-in-place { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
188     [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
189 METHOD: n*M.M+n*M-in-place { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
190     [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
191
192 ! XXX should do a dense clone
193 syntax:M: blas-matrix-base clone
194     [ 
195         [
196             { data>> ld>> cols>> element-type } get-slots
197             heap-size * * memory>byte-array
198         ] [ { ld>> rows>> cols>> transpose>> } get-slots ] bi
199     ] keep (blas-matrix-like) ;
200
201 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
202 : <empty-matrix> ( rows cols exemplar -- matrix )
203     [ element-type [ * ] dip <c-array> ]
204     [ 2drop ]
205     [ f swap (blas-matrix-like) ] 3tri ;
206
207 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
208     clone n*M.V+n*V-in-place ;
209 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
210     clone n*V(*)V+M-in-place ;
211 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
212     clone n*V(*)Vconj+M-in-place ;
213 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
214     clone n*M.M+n*M-in-place ;
215
216 : n*M.V ( alpha A x -- alpha*A.x )
217     1.0 2over [ Mheight ] dip <empty-vector>
218     n*M.V+n*V-in-place ; inline
219
220 : M.V ( A x -- A.x )
221     1.0 -rot n*M.V ; inline
222
223 : n*V(*)V ( n x y -- n*x(*)y )
224     2dup [ length>> ] bi@ pick <empty-matrix>
225     n*V(*)V+M-in-place ;
226 : n*V(*)Vconj ( n x y -- n*x(*)yconj )
227     2dup [ length>> ] bi@ pick <empty-matrix>
228     n*V(*)Vconj+M-in-place ;
229
230 : V(*) ( x y -- x(*)y )
231     1.0 -rot n*V(*)V ; inline
232 : V(*)conj ( x y -- x(*)yconj )
233     1.0 -rot n*V(*)Vconj ; inline
234
235 : n*M.M ( n A B -- n*A.B )
236     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
237     1.0 swap n*M.M+n*M-in-place ;
238
239 : M. ( A B -- A.B )
240     1.0 -rot n*M.M ; inline
241
242 :: (Msub) ( matrix row col height width -- data ld rows cols )
243     matrix ld>> col * row + matrix element-type heap-size *
244     matrix data>> <displaced-alien>
245     matrix ld>>
246     height
247     width ;
248
249 : Msub ( matrix row col height width -- submatrix )
250     5 npick dup transpose>>
251     [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
252     swap (blas-matrix-like) ;
253
254 TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
255 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
256
257 INSTANCE: blas-matrix-rowcol-sequence sequence
258
259 syntax:M: blas-matrix-rowcol-sequence length
260     length>> ;
261 syntax:M: blas-matrix-rowcol-sequence nth-unsafe
262     {
263         [
264             [ rowcol-jump>> ]
265             [ parent>> element-type heap-size ]
266             [ parent>> data>> ] tri
267             [ * * ] dip <displaced-alien>
268         ]
269         [ rowcol-length>> ]
270         [ inc>> ]
271         [ parent>> ]
272     } cleave (blas-vector-like) ;
273
274 : (Mcols) ( A -- columns )
275     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
276     <blas-matrix-rowcol-sequence> ;
277 : (Mrows) ( A -- rows )
278     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
279     <blas-matrix-rowcol-sequence> ;
280
281 : Mrows ( A -- rows )
282     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
283 : Mcols ( A -- rows )
284     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
285
286 : n*M-in-place ( n A -- A=n*A )
287     [ (Mcols) [ n*V-in-place drop ] with each ] keep ;
288
289 : n*M ( n A -- n*A )
290     clone n*M-in-place ; inline
291
292 : M*n ( A n -- A*n )
293     swap n*M ; inline
294 : M/n ( A n -- A/n )
295     recip swap n*M ; inline
296
297 : Mtranspose ( matrix -- matrix^T )
298     [ { data>> ld>> rows>> cols>> transpose>> } get-slots not ] keep (blas-matrix-like) ;
299
300 syntax:M: blas-matrix-base equal?
301     {
302         [ [ Mwidth ] bi@ = ]
303         [ [ Mcols ] bi@ [ = ] 2all? ]
304     } 2&& ;
305