]> gitweb.factorcode.org Git - factor.git/blob - basis/math/blas/matrices/matrices.factor
Factor source files should not be executable
[factor.git] / basis / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types alien.complex
2 alien.data arrays byte-arrays combinators
3 combinators.short-circuit fry kernel locals macros math
4 math.blas.ffi math.blas.vectors math.blas.vectors.private
5 math.complex math.functions math.order functors words
6 sequences sequences.merged sequences.private shuffle
7 parser prettyprint.backend prettyprint.custom ascii
8 specialized-arrays ;
9 FROM: alien.c-types => float ;
10 SPECIALIZED-ARRAY: float
11 SPECIALIZED-ARRAY: double
12 SPECIALIZED-ARRAY: complex-float
13 SPECIALIZED-ARRAY: complex-double
14 IN: math.blas.matrices
15
16 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
17
18 : Mtransposed? ( matrix -- ? )
19     transpose>> ; inline
20 : Mwidth ( matrix -- width )
21     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
22 : Mheight ( matrix -- height )
23     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
24
25 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
26 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
27 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
28 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
29
30 <PRIVATE
31
32 : (blas-transpose) ( matrix -- integer )
33     transpose>> [ "T" ] [ "N" ] if ;
34
35 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
36
37 : (validate-gemv) ( A x y -- )
38     {
39         [ drop [ Mwidth  ] [ length>> ] bi* = ]
40         [ nip  [ Mheight ] [ length>> ] bi* = ]
41     } 3&&
42     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
43     unless ;
44
45 :: (prepare-gemv)
46     ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
47                           y )
48     A x y (validate-gemv)
49     A (blas-transpose)
50     A rows>>
51     A cols>>
52     alpha
53     A
54     A ld>>
55     x
56     x inc>>
57     beta
58     y
59     y inc>>
60     y ; inline
61
62 : (validate-ger) ( x y A -- )
63     {
64         [ [ length>> ] [ drop     ] [ Mheight ] tri* = ]
65         [ [ drop     ] [ length>> ] [ Mwidth  ] tri* = ]
66     } 3&&
67     [ "Mismatched vertices and matrix in vector outer product" throw ]
68     unless ;
69
70 :: (prepare-ger)
71     ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
72                      A )
73     x y A (validate-ger)
74     A rows>>
75     A cols>>
76     alpha
77     x
78     x inc>>
79     y
80     y inc>>
81     A
82     A ld>>
83     A f >>transpose ; inline
84
85 : (validate-gemm) ( A B C -- )
86     {
87         [ [ Mwidth  ] [ Mheight ] [ drop    ] tri* = ]
88         [ [ Mheight ] [ drop    ] [ Mheight ] tri* = ]
89         [ [ drop    ] [ Mwidth  ] [ Mwidth  ] tri* = ]
90     } 3&&
91     [ "Mismatched matrices in matrix multiplication" throw ]
92     unless ;
93
94 :: (prepare-gemm)
95     ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
96                           C )
97     A B C (validate-gemm)
98     A (blas-transpose)
99     B (blas-transpose)
100     C rows>>
101     C cols>>
102     A Mwidth
103     alpha
104     A
105     A ld>>
106     B
107     B ld>>
108     beta
109     C
110     C ld>>
111     C f >>transpose ; inline
112
113 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
114     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
115
116 PRIVATE>
117
118 ! XXX should do a dense clone
119 M: blas-matrix-base clone
120     [ 
121         [ {
122             [ underlying>> ]
123             [ ld>> ]
124             [ cols>> ]
125             [ element-type heap-size ]
126         } cleave * * memory>byte-array ]
127         [ {
128             [ ld>> ]
129             [ rows>> ]
130             [ cols>> ]
131             [ transpose>> ]
132         } cleave ]
133         bi
134     ] keep (blas-matrix-like) ;
135
136 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
137 : <empty-matrix> ( rows cols exemplar -- matrix )
138     [ element-type heap-size * * <byte-array> ]
139     [ 2drop ]
140     [ f swap (blas-matrix-like) ] 3tri ;
141
142 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
143     clone n*M.V+n*V! ;
144 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
145     clone n*V(*)V+M! ;
146 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
147     clone n*V(*)Vconj+M! ;
148 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
149     clone n*M.M+n*M! ;
150
151 : n*M.V ( alpha A x -- alpha*A.x )
152     1.0 2over [ Mheight ] dip <empty-vector>
153     n*M.V+n*V! ; inline
154
155 : M.V ( A x -- A.x )
156     1.0 -rot n*M.V ; inline
157
158 : n*V(*)V ( alpha x y -- alpha*x(*)y )
159     2dup [ length>> ] bi@ pick <empty-matrix>
160     n*V(*)V+M! ;
161 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
162     2dup [ length>> ] bi@ pick <empty-matrix>
163     n*V(*)Vconj+M! ;
164
165 : V(*) ( x y -- x(*)y )
166     1.0 -rot n*V(*)V ; inline
167 : V(*)conj ( x y -- x(*)yconj )
168     1.0 -rot n*V(*)Vconj ; inline
169
170 : n*M.M ( alpha A B -- alpha*A.B )
171     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
172     1.0 swap n*M.M+n*M! ;
173
174 : M. ( A B -- A.B )
175     1.0 -rot n*M.M ; inline
176
177 :: (Msub) ( matrix row col height width -- data ld rows cols )
178     matrix ld>> col * row + matrix element-type heap-size *
179     matrix underlying>> <displaced-alien>
180     matrix ld>>
181     height
182     width ;
183
184 :: Msub ( matrix row col height width -- sub )
185     matrix dup transpose>>
186     [ col row width height ]
187     [ row col height width ] if (Msub)
188     matrix transpose>> matrix (blas-matrix-like) ;
189
190 TUPLE: blas-matrix-rowcol-sequence
191     parent inc rowcol-length rowcol-jump length ;
192 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
193
194 INSTANCE: blas-matrix-rowcol-sequence sequence
195
196 M: blas-matrix-rowcol-sequence length
197     length>> ;
198 M: blas-matrix-rowcol-sequence nth-unsafe
199     {
200         [
201             [ rowcol-jump>> ]
202             [ parent>> element-type heap-size ]
203             [ parent>> underlying>> ] tri
204             [ * * ] dip <displaced-alien>
205         ]
206         [ rowcol-length>> ]
207         [ inc>> ]
208         [ parent>> ]
209     } cleave (blas-vector-like) ;
210
211 : (Mcols) ( A -- columns )
212     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
213     cleave <blas-matrix-rowcol-sequence> ;
214 : (Mrows) ( A -- rows )
215     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
216     cleave <blas-matrix-rowcol-sequence> ;
217
218 : Mrows ( A -- rows )
219     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
220 : Mcols ( A -- cols )
221     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
222
223 : n*M! ( n A -- A=n*A )
224     [ (Mcols) [ n*V! drop ] with each ] keep ;
225
226 : n*M ( n A -- n*A )
227     clone n*M! ; inline
228
229 : M*n ( A n -- A*n )
230     swap n*M ; inline
231 : M/n ( A n -- A/n )
232     recip swap n*M ; inline
233
234 : Mtranspose ( matrix -- matrix^T )
235     [ {
236         [ underlying>> ]
237         [ ld>> ] [ rows>> ]
238         [ cols>> ]
239         [ transpose>> not ]
240     } cleave ] keep (blas-matrix-like) ;
241
242 M: blas-matrix-base equal?
243     {
244         [ [ Mwidth ] bi@ = ]
245         [ [ Mcols ] bi@ [ = ] 2all? ]
246     } 2&& ;
247
248 <<
249
250 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
251
252 VECTOR      IS ${TYPE}-blas-vector
253 <VECTOR>    IS <${TYPE}-blas-vector>
254 >ARRAY      IS >${TYPE}-array
255 XGEMV       IS ${T}GEMV
256 XGEMM       IS ${T}GEMM
257 XGERU       IS ${T}GER${U}
258 XGERC       IS ${T}GER${C}
259
260 MATRIX      DEFINES-CLASS ${TYPE}-blas-matrix
261 <MATRIX>    DEFINES <${TYPE}-blas-matrix>
262 >MATRIX     DEFINES >${TYPE}-blas-matrix
263
264 t           [ T >lower ]
265
266 XMATRIX{    DEFINES ${t}matrix{
267
268 WHERE
269
270 TUPLE: MATRIX < blas-matrix-base ;
271 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
272     MATRIX boa ; inline
273
274 M: MATRIX element-type
275     drop TYPE ;
276 M: MATRIX (blas-matrix-like)
277     drop <MATRIX> ;
278 M: VECTOR (blas-matrix-like)
279     drop <MATRIX> ;
280 M: MATRIX (blas-vector-like)
281     drop <VECTOR> ;
282
283 : >MATRIX ( arrays -- matrix )
284     [ >ARRAY underlying>> ] (>matrix) <MATRIX> ;
285
286 M: VECTOR n*M.V+n*V!
287     (prepare-gemv) [ XGEMV ] dip ;
288 M: MATRIX n*M.M+n*M!
289     (prepare-gemm) [ XGEMM ] dip ;
290 M: MATRIX n*V(*)V+M!
291     (prepare-ger) [ XGERU ] dip ;
292 M: MATRIX n*V(*)Vconj+M!
293     (prepare-ger) [ XGERC ] dip ;
294
295 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
296
297 M: MATRIX pprint-delims
298     drop \ XMATRIX{ \ } ;
299
300 ;FUNCTOR
301
302
303 : define-real-blas-matrix ( TYPE T -- )
304     "" "" (define-blas-matrix) ;
305 : define-complex-blas-matrix ( TYPE T -- )
306     "U" "C" (define-blas-matrix) ;
307
308 "float"          "S" define-real-blas-matrix
309 "double"         "D" define-real-blas-matrix
310 "complex-float"  "C" define-complex-blas-matrix
311 "complex-double" "Z" define-complex-blas-matrix
312
313 >>
314
315 M: blas-matrix-base >pprint-sequence Mrows ;
316 M: blas-matrix-base pprint* pprint-object ;