]> gitweb.factorcode.org Git - factor.git/blob - extra/math/blas/matrices/matrices.factor
Merge branch 'master' of git://factorcode.org/git/factor
[factor.git] / extra / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.lib combinators.short-circuit fry kernel locals macros
3 math math.blas.cblas math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order multi-methods qualified
5 sequences sequences.merged sequences.private generalizations
6 shuffle symbols ;
7 QUALIFIED: syntax
8 IN: math.blas.matrices
9
10 TUPLE: blas-matrix-base data ld rows cols transpose ;
11 TUPLE: float-blas-matrix < blas-matrix-base ;
12 TUPLE: double-blas-matrix < blas-matrix-base ;
13 TUPLE: float-complex-blas-matrix < blas-matrix-base ;
14 TUPLE: double-complex-blas-matrix < blas-matrix-base ;
15
16 C: <float-blas-matrix> float-blas-matrix
17 C: <double-blas-matrix> double-blas-matrix
18 C: <float-complex-blas-matrix> float-complex-blas-matrix
19 C: <double-complex-blas-matrix> double-complex-blas-matrix
20
21 METHOD: element-type { float-blas-matrix }
22     drop "float" ;
23 METHOD: element-type { double-blas-matrix }
24     drop "double" ;
25 METHOD: element-type { float-complex-blas-matrix }
26     drop "CBLAS_C" ;
27 METHOD: element-type { double-complex-blas-matrix }
28     drop "CBLAS_Z" ;
29
30 : Mtransposed? ( matrix -- ? )
31     transpose>> ; inline
32 : Mwidth ( matrix -- width )
33     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
34 : Mheight ( matrix -- height )
35     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
36
37 <PRIVATE
38
39 : (blas-transpose) ( matrix -- integer )
40     transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
41
42 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
43
44 METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
45     drop <float-blas-matrix> ;
46 METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
47     drop <double-blas-matrix> ;
48 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
49     drop <float-complex-blas-matrix> ;
50 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
51     drop <double-complex-blas-matrix> ;
52
53 METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
54     drop <float-blas-matrix> ;
55 METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
56     drop <double-blas-matrix> ;
57 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
58     drop <float-complex-blas-matrix> ;
59 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
60     drop <double-complex-blas-matrix> ;
61
62 METHOD: (blas-vector-like) { object object object float-blas-matrix }
63     drop <float-blas-vector> ;
64 METHOD: (blas-vector-like) { object object object double-blas-matrix }
65     drop <double-blas-vector> ;
66 METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
67     drop <float-complex-blas-vector> ;
68 METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
69     drop <double-complex-blas-vector> ;
70
71 : (validate-gemv) ( A x y -- )
72     {
73         [ drop [ Mwidth  ] [ length>> ] bi* = ]
74         [ nip  [ Mheight ] [ length>> ] bi* = ]
75     } 3&&
76     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
77
78 :: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
79     A x y (validate-gemv)
80     CblasColMajor
81     A (blas-transpose)
82     A rows>>
83     A cols>>
84     alpha >c-arg call
85     A data>>
86     A ld>>
87     x data>>
88     x inc>>
89     beta >c-arg call
90     y data>>
91     y inc>>
92     y ; inline
93
94 : (validate-ger) ( x y A -- )
95     {
96         [ nip  [ length>> ] [ Mheight ] bi* = ]
97         [ nipd [ length>> ] [ Mwidth  ] bi* = ]
98     } 3&&
99     [ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
100
101 :: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
102     x y A (validate-ger)
103     CblasColMajor
104     A rows>>
105     A cols>>
106     alpha >c-arg call
107     x data>>
108     x inc>>
109     y data>>
110     y inc>>
111     A data>>
112     A ld>>
113     A f >>transpose ; inline
114
115 : (validate-gemm) ( A B C -- )
116     {
117         [ drop [ Mwidth  ] [ Mheight ] bi* = ]
118         [ nip  [ Mheight ] bi@ = ]
119         [ nipd [ Mwidth  ] bi@ = ]
120     } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
121
122 :: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
123     A B C (validate-gemm)
124     CblasColMajor
125     A (blas-transpose)
126     B (blas-transpose)
127     C rows>>
128     C cols>>
129     A Mwidth
130     alpha >c-arg call
131     A data>>
132     A ld>>
133     B data>>
134     B ld>>
135     beta >c-arg call
136     C data>>
137     C ld>>
138     C f >>transpose ; inline
139
140 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
141     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
142
143 PRIVATE>
144
145 : >float-blas-matrix ( arrays -- matrix )
146     [ >c-float-array ] (>matrix) <float-blas-matrix> ;
147 : >double-blas-matrix ( arrays -- matrix )
148     [ >c-double-array ] (>matrix) <double-blas-matrix> ;
149 : >float-complex-blas-matrix ( arrays -- matrix )
150     [ (flatten-complex-sequence) >c-float-array ] (>matrix)
151     <float-complex-blas-matrix> ;
152 : >double-complex-blas-matrix ( arrays -- matrix )
153     [ (flatten-complex-sequence) >c-double-array ] (>matrix)
154     <double-complex-blas-matrix> ;
155
156 GENERIC: n*M.V+n*V-in-place ( alpha A x beta y -- y=alpha*A.x+b*y )
157 GENERIC: n*V(*)V+M-in-place ( alpha x y A -- A=alpha*x(*)y+A )
158 GENERIC: n*V(*)Vconj+M-in-place ( alpha x y A -- A=alpha*x(*)yconj+A )
159 GENERIC: n*M.M+n*M-in-place ( alpha A B beta C -- C=alpha*A.B+beta*C )
160
161 METHOD: n*M.V+n*V-in-place { real float-blas-matrix float-blas-vector real float-blas-vector }
162     [ ] (prepare-gemv) [ cblas_sgemv ] dip ;
163 METHOD: n*M.V+n*V-in-place { real double-blas-matrix double-blas-vector real double-blas-vector }
164     [ ] (prepare-gemv) [ cblas_dgemv ] dip ;
165 METHOD: n*M.V+n*V-in-place { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
166     [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
167 METHOD: n*M.V+n*V-in-place { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
168     [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
169
170 METHOD: n*V(*)V+M-in-place { real float-blas-vector float-blas-vector float-blas-matrix }
171     [ ] (prepare-ger) [ cblas_sger ] dip ;
172 METHOD: n*V(*)V+M-in-place { real double-blas-vector double-blas-vector double-blas-matrix }
173     [ ] (prepare-ger) [ cblas_dger ] dip ;
174 METHOD: n*V(*)V+M-in-place { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
175     [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
176 METHOD: n*V(*)V+M-in-place { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
177     [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
178
179 METHOD: n*V(*)Vconj+M-in-place { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
180     [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
181 METHOD: n*V(*)Vconj+M-in-place { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
182     [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
183
184 METHOD: n*M.M+n*M-in-place { real float-blas-matrix float-blas-matrix real float-blas-matrix }
185     [ ] (prepare-gemm) [ cblas_sgemm ] dip ;
186 METHOD: n*M.M+n*M-in-place { real double-blas-matrix double-blas-matrix real double-blas-matrix }
187     [ ] (prepare-gemm) [ cblas_dgemm ] dip ;
188 METHOD: n*M.M+n*M-in-place { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
189     [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
190 METHOD: n*M.M+n*M-in-place { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
191     [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
192
193 ! XXX should do a dense clone
194 syntax:M: blas-matrix-base clone
195     [ 
196         [
197             { data>> ld>> cols>> element-type } get-slots
198             heap-size * * memory>byte-array
199         ] [ { ld>> rows>> cols>> transpose>> } get-slots ] bi
200     ] keep (blas-matrix-like) ;
201
202 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
203 : <empty-matrix> ( rows cols exemplar -- matrix )
204     [ element-type [ * ] dip <c-array> ]
205     [ 2drop ]
206     [ f swap (blas-matrix-like) ] 3tri ;
207
208 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
209     clone n*M.V+n*V-in-place ;
210 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
211     clone n*V(*)V+M-in-place ;
212 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
213     clone n*V(*)Vconj+M-in-place ;
214 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
215     clone n*M.M+n*M-in-place ;
216
217 : n*M.V ( alpha A x -- alpha*A.x )
218     1.0 2over [ Mheight ] dip <empty-vector>
219     n*M.V+n*V-in-place ; inline
220
221 : M.V ( A x -- A.x )
222     1.0 -rot n*M.V ; inline
223
224 : n*V(*)V ( n x y -- n*x(*)y )
225     2dup [ length>> ] bi@ pick <empty-matrix>
226     n*V(*)V+M-in-place ;
227 : n*V(*)Vconj ( n x y -- n*x(*)yconj )
228     2dup [ length>> ] bi@ pick <empty-matrix>
229     n*V(*)Vconj+M-in-place ;
230
231 : V(*) ( x y -- x(*)y )
232     1.0 -rot n*V(*)V ; inline
233 : V(*)conj ( x y -- x(*)yconj )
234     1.0 -rot n*V(*)Vconj ; inline
235
236 : n*M.M ( n A B -- n*A.B )
237     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
238     1.0 swap n*M.M+n*M-in-place ;
239
240 : M. ( A B -- A.B )
241     1.0 -rot n*M.M ; inline
242
243 :: (Msub) ( matrix row col height width -- data ld rows cols )
244     matrix ld>> col * row + matrix element-type heap-size *
245     matrix data>> <displaced-alien>
246     matrix ld>>
247     height
248     width ;
249
250 : Msub ( matrix row col height width -- submatrix )
251     5 npick dup transpose>>
252     [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
253     swap (blas-matrix-like) ;
254
255 TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
256 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
257
258 INSTANCE: blas-matrix-rowcol-sequence sequence
259
260 syntax:M: blas-matrix-rowcol-sequence length
261     length>> ;
262 syntax:M: blas-matrix-rowcol-sequence nth-unsafe
263     {
264         [
265             [ rowcol-jump>> ]
266             [ parent>> element-type heap-size ]
267             [ parent>> data>> ] tri
268             [ * * ] dip <displaced-alien>
269         ]
270         [ rowcol-length>> ]
271         [ inc>> ]
272         [ parent>> ]
273     } cleave (blas-vector-like) ;
274
275 : (Mcols) ( A -- columns )
276     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
277     <blas-matrix-rowcol-sequence> ;
278 : (Mrows) ( A -- rows )
279     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
280     <blas-matrix-rowcol-sequence> ;
281
282 : Mrows ( A -- rows )
283     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
284 : Mcols ( A -- rows )
285     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
286
287 : n*M-in-place ( n A -- A=n*A )
288     [ (Mcols) [ n*V-in-place drop ] with each ] keep ;
289
290 : n*M ( n A -- n*A )
291     clone n*M-in-place ; inline
292
293 : M*n ( A n -- A*n )
294     swap n*M ; inline
295 : M/n ( A n -- A/n )
296     recip swap n*M ; inline
297
298 : Mtranspose ( matrix -- matrix^T )
299     [ { data>> ld>> rows>> cols>> transpose>> } get-slots not ] keep (blas-matrix-like) ;
300
301 syntax:M: blas-matrix-base equal?
302     {
303         [ [ Mwidth ] bi@ = ]
304         [ [ Mcols ] bi@ [ = ] 2all? ]
305     } 2&& ;
306