]> gitweb.factorcode.org Git - factor.git/blob - extra/math/blas/matrices/matrices.factor
Merge branch 'master' into specialized-arrays
[factor.git] / extra / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.lib combinators.short-circuit fry kernel locals macros
3 math math.blas.cblas math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order multi-methods qualified
5 sequences sequences.merged sequences.private generalizations
6 shuffle symbols speicalized-arrays.float specialized-arrays.double ;
7 QUALIFIED: syntax
8 IN: math.blas.matrices
9
10 TUPLE: blas-matrix-base data ld rows cols transpose ;
11 TUPLE: float-blas-matrix < blas-matrix-base ;
12 TUPLE: double-blas-matrix < blas-matrix-base ;
13 TUPLE: float-complex-blas-matrix < blas-matrix-base ;
14 TUPLE: double-complex-blas-matrix < blas-matrix-base ;
15
16 C: <float-blas-matrix> float-blas-matrix
17 C: <double-blas-matrix> double-blas-matrix
18 C: <float-complex-blas-matrix> float-complex-blas-matrix
19 C: <double-complex-blas-matrix> double-complex-blas-matrix
20
21 METHOD: element-type { float-blas-matrix }
22     drop "float" ;
23 METHOD: element-type { double-blas-matrix }
24     drop "double" ;
25 METHOD: element-type { float-complex-blas-matrix }
26     drop "CBLAS_C" ;
27 METHOD: element-type { double-complex-blas-matrix }
28     drop "CBLAS_Z" ;
29
30 : Mtransposed? ( matrix -- ? )
31     transpose>> ; inline
32 : Mwidth ( matrix -- width )
33     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
34 : Mheight ( matrix -- height )
35     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
36
37 <PRIVATE
38
39 : (blas-transpose) ( matrix -- integer )
40     transpose>> [ CblasTrans ] [ CblasNoTrans ] if ;
41
42 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
43
44 METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
45     drop <float-blas-matrix> ;
46 METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
47     drop <double-blas-matrix> ;
48 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
49     drop <float-complex-blas-matrix> ;
50 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
51     drop <double-complex-blas-matrix> ;
52
53 METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
54     drop <float-blas-matrix> ;
55 METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
56     drop <double-blas-matrix> ;
57 METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
58     drop <float-complex-blas-matrix> ;
59 METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
60     drop <double-complex-blas-matrix> ;
61
62 METHOD: (blas-vector-like) { object object object float-blas-matrix }
63     drop <float-blas-vector> ;
64 METHOD: (blas-vector-like) { object object object double-blas-matrix }
65     drop <double-blas-vector> ;
66 METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
67     drop <float-complex-blas-vector> ;
68 METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
69     drop <double-complex-blas-vector> ;
70
71 : (validate-gemv) ( A x y -- )
72     {
73         [ drop [ Mwidth  ] [ length>> ] bi* = ]
74         [ nip  [ Mheight ] [ length>> ] bi* = ]
75     } 3&&
76     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
77
78 :: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
79     A x y (validate-gemv)
80     CblasColMajor
81     A (blas-transpose)
82     A rows>>
83     A cols>>
84     alpha >c-arg call
85     A data>>
86     A ld>>
87     x data>>
88     x inc>>
89     beta >c-arg call
90     y data>>
91     y inc>>
92     y ; inline
93
94 : (validate-ger) ( x y A -- )
95     {
96         [ nip  [ length>> ] [ Mheight ] bi* = ]
97         [ nipd [ length>> ] [ Mwidth  ] bi* = ]
98     } 3&&
99     [ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
100
101 :: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
102     x y A (validate-ger)
103     CblasColMajor
104     A rows>>
105     A cols>>
106     alpha >c-arg call
107     x data>>
108     x inc>>
109     y data>>
110     y inc>>
111     A data>>
112     A ld>>
113     A f >>transpose ; inline
114
115 : (validate-gemm) ( A B C -- )
116     {
117         [ drop [ Mwidth  ] [ Mheight ] bi* = ]
118         [ nip  [ Mheight ] bi@ = ]
119         [ nipd [ Mwidth  ] bi@ = ]
120     } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
121
122 :: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
123     A B C (validate-gemm)
124     CblasColMajor
125     A (blas-transpose)
126     B (blas-transpose)
127     C rows>>
128     C cols>>
129     A Mwidth
130     alpha >c-arg call
131     A data>>
132     A ld>>
133     B data>>
134     B ld>>
135     beta >c-arg call
136     C data>>
137     C ld>>
138     C f >>transpose ; inline
139
140 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
141     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
142
143 PRIVATE>
144
145 : >float-blas-matrix ( arrays -- matrix )
146     [ >float-array underlying>> ] (>matrix) <float-blas-matrix> ;
147 : >double-blas-matrix ( arrays -- matrix )
148     [ >double-array underlying>> ] (>matrix) <double-blas-matrix> ;
149 : >float-complex-blas-matrix ( arrays -- matrix )
150     [ (flatten-complex-sequence) >float-array underlying>> ] (>matrix)
151     <float-complex-blas-matrix> ;
152 : >double-complex-blas-matrix ( arrays -- matrix )
153     [ (flatten-complex-sequence) >double-array underlying>> ] (>matrix)
154     <double-complex-blas-matrix> ;
155
156 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
157 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
158 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
159 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
160
161 METHOD: n*M.V+n*V! { real float-blas-matrix float-blas-vector real float-blas-vector }
162     [ ] (prepare-gemv) [ cblas_sgemv ] dip ;
163 METHOD: n*M.V+n*V! { real double-blas-matrix double-blas-vector real double-blas-vector }
164     [ ] (prepare-gemv) [ cblas_dgemv ] dip ;
165 METHOD: n*M.V+n*V! { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
166     [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
167 METHOD: n*M.V+n*V! { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
168     [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
169
170 METHOD: n*V(*)V+M! { real float-blas-vector float-blas-vector float-blas-matrix }
171     [ ] (prepare-ger) [ cblas_sger ] dip ;
172 METHOD: n*V(*)V+M! { real double-blas-vector double-blas-vector double-blas-matrix }
173     [ ] (prepare-ger) [ cblas_dger ] dip ;
174 METHOD: n*V(*)V+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
175     [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
176 METHOD: n*V(*)V+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
177     [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
178
179 METHOD: n*V(*)Vconj+M! { real float-blas-vector float-blas-vector float-blas-matrix }
180     [ ] (prepare-ger) [ cblas_sger ] dip ;
181 METHOD: n*V(*)Vconj+M! { real double-blas-vector double-blas-vector double-blas-matrix }
182     [ ] (prepare-ger) [ cblas_dger ] dip ;
183 METHOD: n*V(*)Vconj+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
184     [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
185 METHOD: n*V(*)Vconj+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
186     [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
187
188 METHOD: n*M.M+n*M! { real float-blas-matrix float-blas-matrix real float-blas-matrix }
189     [ ] (prepare-gemm) [ cblas_sgemm ] dip ;
190 METHOD: n*M.M+n*M! { real double-blas-matrix double-blas-matrix real double-blas-matrix }
191     [ ] (prepare-gemm) [ cblas_dgemm ] dip ;
192 METHOD: n*M.M+n*M! { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
193     [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
194 METHOD: n*M.M+n*M! { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
195     [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
196
197 ! XXX should do a dense clone
198 syntax:M: blas-matrix-base clone
199     [ 
200         [
201             { [ data>> ] [ ld>> ] [ cols>> ] [ element-type heap-size ] } cleave
202             * * memory>byte-array
203         ] [ { [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> ] } cleave ] bi
204     ] keep (blas-matrix-like) ;
205
206 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
207 : <empty-matrix> ( rows cols exemplar -- matrix )
208     [ element-type [ * ] dip <c-array> ]
209     [ 2drop ]
210     [ f swap (blas-matrix-like) ] 3tri ;
211
212 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
213     clone n*M.V+n*V! ;
214 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
215     clone n*V(*)V+M! ;
216 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
217     clone n*V(*)Vconj+M! ;
218 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
219     clone n*M.M+n*M! ;
220
221 : n*M.V ( alpha A x -- alpha*A.x )
222     1.0 2over [ Mheight ] dip <empty-vector>
223     n*M.V+n*V! ; inline
224
225 : M.V ( A x -- A.x )
226     1.0 -rot n*M.V ; inline
227
228 : n*V(*)V ( alpha x y -- alpha*x(*)y )
229     2dup [ length>> ] bi@ pick <empty-matrix>
230     n*V(*)V+M! ;
231 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
232     2dup [ length>> ] bi@ pick <empty-matrix>
233     n*V(*)Vconj+M! ;
234
235 : V(*) ( x y -- x(*)y )
236     1.0 -rot n*V(*)V ; inline
237 : V(*)conj ( x y -- x(*)yconj )
238     1.0 -rot n*V(*)Vconj ; inline
239
240 : n*M.M ( alpha A B -- alpha*A.B )
241     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
242     1.0 swap n*M.M+n*M! ;
243
244 : M. ( A B -- A.B )
245     1.0 -rot n*M.M ; inline
246
247 :: (Msub) ( matrix row col height width -- data ld rows cols )
248     matrix ld>> col * row + matrix element-type heap-size *
249     matrix data>> <displaced-alien>
250     matrix ld>>
251     height
252     width ;
253
254 : Msub ( matrix row col height width -- sub )
255     5 npick dup transpose>>
256     [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
257     swap (blas-matrix-like) ;
258
259 TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
260 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
261
262 INSTANCE: blas-matrix-rowcol-sequence sequence
263
264 syntax:M: blas-matrix-rowcol-sequence length
265     length>> ;
266 syntax:M: blas-matrix-rowcol-sequence nth-unsafe
267     {
268         [
269             [ rowcol-jump>> ]
270             [ parent>> element-type heap-size ]
271             [ parent>> data>> ] tri
272             [ * * ] dip <displaced-alien>
273         ]
274         [ rowcol-length>> ]
275         [ inc>> ]
276         [ parent>> ]
277     } cleave (blas-vector-like) ;
278
279 : (Mcols) ( A -- columns )
280     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
281     <blas-matrix-rowcol-sequence> ;
282 : (Mrows) ( A -- rows )
283     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
284     <blas-matrix-rowcol-sequence> ;
285
286 : Mrows ( A -- rows )
287     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
288 : Mcols ( A -- cols )
289     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
290
291 : n*M! ( n A -- A=n*A )
292     [ (Mcols) [ n*V! drop ] with each ] keep ;
293
294 : n*M ( n A -- n*A )
295     clone n*M! ; inline
296
297 : M*n ( A n -- A*n )
298     swap n*M ; inline
299 : M/n ( A n -- A/n )
300     recip swap n*M ; inline
301
302 : Mtranspose ( matrix -- matrix^T )
303     [ { [ data>> ] [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> not ] } cleave ] keep (blas-matrix-like) ;
304
305 syntax:M: blas-matrix-base equal?
306     {
307         [ [ Mwidth ] bi@ = ]
308         [ [ Mcols ] bi@ [ = ] 2all? ]
309     } 2&& ;
310