]> gitweb.factorcode.org Git - factor.git/blob - basis/math/blas/matrices/matrices.factor
c315021ed4765cbb441c45b82a29ececc9a60905
[factor.git] / basis / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.short-circuit fry kernel locals macros
3 math math.blas.ffi math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order functors words
5 sequences sequences.merged sequences.private shuffle
6 specialized-arrays.float specialized-arrays.double
7 specialized-arrays.complex-float specialized-arrays.complex-double
8 parser prettyprint.backend prettyprint.custom ascii ;
9 IN: math.blas.matrices
10
11 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
12
13 : Mtransposed? ( matrix -- ? )
14     transpose>> ; inline
15 : Mwidth ( matrix -- width )
16     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
17 : Mheight ( matrix -- height )
18     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
19
20 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
21 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
22 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
23 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
24
25 <PRIVATE
26
27 : (blas-transpose) ( matrix -- integer )
28     transpose>> [ "T" ] [ "N" ] if ;
29
30 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
31
32 : (validate-gemv) ( A x y -- )
33     {
34         [ drop [ Mwidth  ] [ length>> ] bi* = ]
35         [ nip  [ Mheight ] [ length>> ] bi* = ]
36     } 3&&
37     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
38     unless ;
39
40 :: (prepare-gemv)
41     ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
42                           y )
43     A x y (validate-gemv)
44     A (blas-transpose)
45     A rows>>
46     A cols>>
47     alpha
48     A
49     A ld>>
50     x
51     x inc>>
52     beta
53     y
54     y inc>>
55     y ; inline
56
57 : (validate-ger) ( x y A -- )
58     {
59         [ [ length>> ] [ drop     ] [ Mheight ] tri* = ]
60         [ [ drop     ] [ length>> ] [ Mwidth  ] tri* = ]
61     } 3&&
62     [ "Mismatched vertices and matrix in vector outer product" throw ]
63     unless ;
64
65 :: (prepare-ger)
66     ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
67                      A )
68     x y A (validate-ger)
69     A rows>>
70     A cols>>
71     alpha
72     x
73     x inc>>
74     y
75     y inc>>
76     A
77     A ld>>
78     A f >>transpose ; inline
79
80 : (validate-gemm) ( A B C -- )
81     {
82         [ [ Mwidth  ] [ Mheight ] [ drop    ] tri* = ]
83         [ [ Mheight ] [ drop    ] [ Mheight ] tri* = ]
84         [ [ drop    ] [ Mwidth  ] [ Mwidth  ] tri* = ]
85     } 3&&
86     [ "Mismatched matrices in matrix multiplication" throw ]
87     unless ;
88
89 :: (prepare-gemm)
90     ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
91                           C )
92     A B C (validate-gemm)
93     A (blas-transpose)
94     B (blas-transpose)
95     C rows>>
96     C cols>>
97     A Mwidth
98     alpha
99     A
100     A ld>>
101     B
102     B ld>>
103     beta
104     C
105     C ld>>
106     C f >>transpose ; inline
107
108 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
109     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
110
111 PRIVATE>
112
113 ! XXX should do a dense clone
114 M: blas-matrix-base clone
115     [ 
116         [ {
117             [ underlying>> ]
118             [ ld>> ]
119             [ cols>> ]
120             [ element-type heap-size ]
121         } cleave * * memory>byte-array ]
122         [ {
123             [ ld>> ]
124             [ rows>> ]
125             [ cols>> ]
126             [ transpose>> ]
127         } cleave ]
128         bi
129     ] keep (blas-matrix-like) ;
130
131 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
132 : <empty-matrix> ( rows cols exemplar -- matrix )
133     [ element-type heap-size * * <byte-array> ]
134     [ 2drop ]
135     [ f swap (blas-matrix-like) ] 3tri ;
136
137 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
138     clone n*M.V+n*V! ;
139 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
140     clone n*V(*)V+M! ;
141 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
142     clone n*V(*)Vconj+M! ;
143 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
144     clone n*M.M+n*M! ;
145
146 : n*M.V ( alpha A x -- alpha*A.x )
147     1.0 2over [ Mheight ] dip <empty-vector>
148     n*M.V+n*V! ; inline
149
150 : M.V ( A x -- A.x )
151     1.0 -rot n*M.V ; inline
152
153 : n*V(*)V ( alpha x y -- alpha*x(*)y )
154     2dup [ length>> ] bi@ pick <empty-matrix>
155     n*V(*)V+M! ;
156 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
157     2dup [ length>> ] bi@ pick <empty-matrix>
158     n*V(*)Vconj+M! ;
159
160 : V(*) ( x y -- x(*)y )
161     1.0 -rot n*V(*)V ; inline
162 : V(*)conj ( x y -- x(*)yconj )
163     1.0 -rot n*V(*)Vconj ; inline
164
165 : n*M.M ( alpha A B -- alpha*A.B )
166     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
167     1.0 swap n*M.M+n*M! ;
168
169 : M. ( A B -- A.B )
170     1.0 -rot n*M.M ; inline
171
172 :: (Msub) ( matrix row col height width -- data ld rows cols )
173     matrix ld>> col * row + matrix element-type heap-size *
174     matrix underlying>> <displaced-alien>
175     matrix ld>>
176     height
177     width ;
178
179 :: Msub ( matrix row col height width -- sub )
180     matrix dup transpose>>
181     [ col row width height ]
182     [ row col height width ] if (Msub)
183     matrix transpose>> matrix (blas-matrix-like) ;
184
185 TUPLE: blas-matrix-rowcol-sequence
186     parent inc rowcol-length rowcol-jump length ;
187 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
188
189 INSTANCE: blas-matrix-rowcol-sequence sequence
190
191 M: blas-matrix-rowcol-sequence length
192     length>> ;
193 M: blas-matrix-rowcol-sequence nth-unsafe
194     {
195         [
196             [ rowcol-jump>> ]
197             [ parent>> element-type heap-size ]
198             [ parent>> underlying>> ] tri
199             [ * * ] dip <displaced-alien>
200         ]
201         [ rowcol-length>> ]
202         [ inc>> ]
203         [ parent>> ]
204     } cleave (blas-vector-like) ;
205
206 : (Mcols) ( A -- columns )
207     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
208     cleave <blas-matrix-rowcol-sequence> ;
209 : (Mrows) ( A -- rows )
210     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
211     cleave <blas-matrix-rowcol-sequence> ;
212
213 : Mrows ( A -- rows )
214     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
215 : Mcols ( A -- cols )
216     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
217
218 : n*M! ( n A -- A=n*A )
219     [ (Mcols) [ n*V! drop ] with each ] keep ;
220
221 : n*M ( n A -- n*A )
222     clone n*M! ; inline
223
224 : M*n ( A n -- A*n )
225     swap n*M ; inline
226 : M/n ( A n -- A/n )
227     recip swap n*M ; inline
228
229 : Mtranspose ( matrix -- matrix^T )
230     [ {
231         [ underlying>> ]
232         [ ld>> ] [ rows>> ]
233         [ cols>> ]
234         [ transpose>> not ]
235     } cleave ] keep (blas-matrix-like) ;
236
237 M: blas-matrix-base equal?
238     {
239         [ [ Mwidth ] bi@ = ]
240         [ [ Mcols ] bi@ [ = ] 2all? ]
241     } 2&& ;
242
243 <<
244
245 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
246
247 VECTOR      IS ${TYPE}-blas-vector
248 <VECTOR>    IS <${TYPE}-blas-vector>
249 >ARRAY      IS >${TYPE}-array
250 XGEMV       IS ${T}GEMV
251 XGEMM       IS ${T}GEMM
252 XGERU       IS ${T}GER${U}
253 XGERC       IS ${T}GER${C}
254
255 MATRIX      DEFINES-CLASS ${TYPE}-blas-matrix
256 <MATRIX>    DEFINES <${TYPE}-blas-matrix>
257 >MATRIX     DEFINES >${TYPE}-blas-matrix
258
259 t           [ T >lower ]
260
261 XMATRIX{    DEFINES ${t}matrix{
262
263 WHERE
264
265 TUPLE: MATRIX < blas-matrix-base ;
266 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
267     MATRIX boa ; inline
268
269 M: MATRIX element-type
270     drop TYPE ;
271 M: MATRIX (blas-matrix-like)
272     drop <MATRIX> ;
273 M: VECTOR (blas-matrix-like)
274     drop <MATRIX> ;
275 M: MATRIX (blas-vector-like)
276     drop <VECTOR> ;
277
278 : >MATRIX ( arrays -- matrix )
279     [ >ARRAY underlying>> ] (>matrix) <MATRIX> ;
280
281 M: VECTOR n*M.V+n*V!
282     (prepare-gemv) [ XGEMV ] dip ;
283 M: MATRIX n*M.M+n*M!
284     (prepare-gemm) [ XGEMM ] dip ;
285 M: MATRIX n*V(*)V+M!
286     (prepare-ger) [ XGERU ] dip ;
287 M: MATRIX n*V(*)Vconj+M!
288     (prepare-ger) [ XGERC ] dip ;
289
290 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
291
292 M: MATRIX pprint-delims
293     drop \ XMATRIX{ \ } ;
294
295 ;FUNCTOR
296
297
298 : define-real-blas-matrix ( TYPE T -- )
299     "" "" (define-blas-matrix) ;
300 : define-complex-blas-matrix ( TYPE T -- )
301     "U" "C" (define-blas-matrix) ;
302
303 "float"          "S" define-real-blas-matrix
304 "double"         "D" define-real-blas-matrix
305 "complex-float"  "C" define-complex-blas-matrix
306 "complex-double" "Z" define-complex-blas-matrix
307
308 >>
309
310 M: blas-matrix-base >pprint-sequence Mrows ;
311 M: blas-matrix-base pprint* pprint-object ;