]> gitweb.factorcode.org Git - factor.git/blob - basis/math/blas/matrices/matrices.factor
move some allocation words that don't really have much to do with c types out of...
[factor.git] / basis / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types alien.data arrays
2 byte-arrays combinators combinators.short-circuit fry
3 kernel locals macros math math.blas.ffi math.blas.vectors
4 math.blas.vectors.private math.complex math.functions
5 math.order functors words sequences sequences.merged
6 sequences.private shuffle parser prettyprint.backend
7 prettyprint.custom ascii specialized-arrays ;
8 FROM: alien.c-types => float ;
9 SPECIALIZED-ARRAY: float
10 SPECIALIZED-ARRAY: double
11 SPECIALIZED-ARRAY: complex-float
12 SPECIALIZED-ARRAY: complex-double
13 IN: math.blas.matrices
14
15 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
16
17 : Mtransposed? ( matrix -- ? )
18     transpose>> ; inline
19 : Mwidth ( matrix -- width )
20     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
21 : Mheight ( matrix -- height )
22     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
23
24 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
25 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
26 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
27 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
28
29 <PRIVATE
30
31 : (blas-transpose) ( matrix -- integer )
32     transpose>> [ "T" ] [ "N" ] if ;
33
34 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
35
36 : (validate-gemv) ( A x y -- )
37     {
38         [ drop [ Mwidth  ] [ length>> ] bi* = ]
39         [ nip  [ Mheight ] [ length>> ] bi* = ]
40     } 3&&
41     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
42     unless ;
43
44 :: (prepare-gemv)
45     ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
46                           y )
47     A x y (validate-gemv)
48     A (blas-transpose)
49     A rows>>
50     A cols>>
51     alpha
52     A
53     A ld>>
54     x
55     x inc>>
56     beta
57     y
58     y inc>>
59     y ; inline
60
61 : (validate-ger) ( x y A -- )
62     {
63         [ [ length>> ] [ drop     ] [ Mheight ] tri* = ]
64         [ [ drop     ] [ length>> ] [ Mwidth  ] tri* = ]
65     } 3&&
66     [ "Mismatched vertices and matrix in vector outer product" throw ]
67     unless ;
68
69 :: (prepare-ger)
70     ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
71                      A )
72     x y A (validate-ger)
73     A rows>>
74     A cols>>
75     alpha
76     x
77     x inc>>
78     y
79     y inc>>
80     A
81     A ld>>
82     A f >>transpose ; inline
83
84 : (validate-gemm) ( A B C -- )
85     {
86         [ [ Mwidth  ] [ Mheight ] [ drop    ] tri* = ]
87         [ [ Mheight ] [ drop    ] [ Mheight ] tri* = ]
88         [ [ drop    ] [ Mwidth  ] [ Mwidth  ] tri* = ]
89     } 3&&
90     [ "Mismatched matrices in matrix multiplication" throw ]
91     unless ;
92
93 :: (prepare-gemm)
94     ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
95                           C )
96     A B C (validate-gemm)
97     A (blas-transpose)
98     B (blas-transpose)
99     C rows>>
100     C cols>>
101     A Mwidth
102     alpha
103     A
104     A ld>>
105     B
106     B ld>>
107     beta
108     C
109     C ld>>
110     C f >>transpose ; inline
111
112 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
113     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
114
115 PRIVATE>
116
117 ! XXX should do a dense clone
118 M: blas-matrix-base clone
119     [ 
120         [ {
121             [ underlying>> ]
122             [ ld>> ]
123             [ cols>> ]
124             [ element-type heap-size ]
125         } cleave * * memory>byte-array ]
126         [ {
127             [ ld>> ]
128             [ rows>> ]
129             [ cols>> ]
130             [ transpose>> ]
131         } cleave ]
132         bi
133     ] keep (blas-matrix-like) ;
134
135 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
136 : <empty-matrix> ( rows cols exemplar -- matrix )
137     [ element-type heap-size * * <byte-array> ]
138     [ 2drop ]
139     [ f swap (blas-matrix-like) ] 3tri ;
140
141 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
142     clone n*M.V+n*V! ;
143 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
144     clone n*V(*)V+M! ;
145 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
146     clone n*V(*)Vconj+M! ;
147 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
148     clone n*M.M+n*M! ;
149
150 : n*M.V ( alpha A x -- alpha*A.x )
151     1.0 2over [ Mheight ] dip <empty-vector>
152     n*M.V+n*V! ; inline
153
154 : M.V ( A x -- A.x )
155     1.0 -rot n*M.V ; inline
156
157 : n*V(*)V ( alpha x y -- alpha*x(*)y )
158     2dup [ length>> ] bi@ pick <empty-matrix>
159     n*V(*)V+M! ;
160 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
161     2dup [ length>> ] bi@ pick <empty-matrix>
162     n*V(*)Vconj+M! ;
163
164 : V(*) ( x y -- x(*)y )
165     1.0 -rot n*V(*)V ; inline
166 : V(*)conj ( x y -- x(*)yconj )
167     1.0 -rot n*V(*)Vconj ; inline
168
169 : n*M.M ( alpha A B -- alpha*A.B )
170     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
171     1.0 swap n*M.M+n*M! ;
172
173 : M. ( A B -- A.B )
174     1.0 -rot n*M.M ; inline
175
176 :: (Msub) ( matrix row col height width -- data ld rows cols )
177     matrix ld>> col * row + matrix element-type heap-size *
178     matrix underlying>> <displaced-alien>
179     matrix ld>>
180     height
181     width ;
182
183 :: Msub ( matrix row col height width -- sub )
184     matrix dup transpose>>
185     [ col row width height ]
186     [ row col height width ] if (Msub)
187     matrix transpose>> matrix (blas-matrix-like) ;
188
189 TUPLE: blas-matrix-rowcol-sequence
190     parent inc rowcol-length rowcol-jump length ;
191 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
192
193 INSTANCE: blas-matrix-rowcol-sequence sequence
194
195 M: blas-matrix-rowcol-sequence length
196     length>> ;
197 M: blas-matrix-rowcol-sequence nth-unsafe
198     {
199         [
200             [ rowcol-jump>> ]
201             [ parent>> element-type heap-size ]
202             [ parent>> underlying>> ] tri
203             [ * * ] dip <displaced-alien>
204         ]
205         [ rowcol-length>> ]
206         [ inc>> ]
207         [ parent>> ]
208     } cleave (blas-vector-like) ;
209
210 : (Mcols) ( A -- columns )
211     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
212     cleave <blas-matrix-rowcol-sequence> ;
213 : (Mrows) ( A -- rows )
214     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
215     cleave <blas-matrix-rowcol-sequence> ;
216
217 : Mrows ( A -- rows )
218     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
219 : Mcols ( A -- cols )
220     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
221
222 : n*M! ( n A -- A=n*A )
223     [ (Mcols) [ n*V! drop ] with each ] keep ;
224
225 : n*M ( n A -- n*A )
226     clone n*M! ; inline
227
228 : M*n ( A n -- A*n )
229     swap n*M ; inline
230 : M/n ( A n -- A/n )
231     recip swap n*M ; inline
232
233 : Mtranspose ( matrix -- matrix^T )
234     [ {
235         [ underlying>> ]
236         [ ld>> ] [ rows>> ]
237         [ cols>> ]
238         [ transpose>> not ]
239     } cleave ] keep (blas-matrix-like) ;
240
241 M: blas-matrix-base equal?
242     {
243         [ [ Mwidth ] bi@ = ]
244         [ [ Mcols ] bi@ [ = ] 2all? ]
245     } 2&& ;
246
247 <<
248
249 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
250
251 VECTOR      IS ${TYPE}-blas-vector
252 <VECTOR>    IS <${TYPE}-blas-vector>
253 >ARRAY      IS >${TYPE}-array
254 XGEMV       IS ${T}GEMV
255 XGEMM       IS ${T}GEMM
256 XGERU       IS ${T}GER${U}
257 XGERC       IS ${T}GER${C}
258
259 MATRIX      DEFINES-CLASS ${TYPE}-blas-matrix
260 <MATRIX>    DEFINES <${TYPE}-blas-matrix>
261 >MATRIX     DEFINES >${TYPE}-blas-matrix
262
263 t           [ T >lower ]
264
265 XMATRIX{    DEFINES ${t}matrix{
266
267 WHERE
268
269 TUPLE: MATRIX < blas-matrix-base ;
270 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
271     MATRIX boa ; inline
272
273 M: MATRIX element-type
274     drop TYPE ;
275 M: MATRIX (blas-matrix-like)
276     drop <MATRIX> ;
277 M: VECTOR (blas-matrix-like)
278     drop <MATRIX> ;
279 M: MATRIX (blas-vector-like)
280     drop <VECTOR> ;
281
282 : >MATRIX ( arrays -- matrix )
283     [ >ARRAY underlying>> ] (>matrix) <MATRIX> ;
284
285 M: VECTOR n*M.V+n*V!
286     (prepare-gemv) [ XGEMV ] dip ;
287 M: MATRIX n*M.M+n*M!
288     (prepare-gemm) [ XGEMM ] dip ;
289 M: MATRIX n*V(*)V+M!
290     (prepare-ger) [ XGERU ] dip ;
291 M: MATRIX n*V(*)Vconj+M!
292     (prepare-ger) [ XGERC ] dip ;
293
294 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
295
296 M: MATRIX pprint-delims
297     drop \ XMATRIX{ \ } ;
298
299 ;FUNCTOR
300
301
302 : define-real-blas-matrix ( TYPE T -- )
303     "" "" (define-blas-matrix) ;
304 : define-complex-blas-matrix ( TYPE T -- )
305     "U" "C" (define-blas-matrix) ;
306
307 "float"          "S" define-real-blas-matrix
308 "double"         "D" define-real-blas-matrix
309 "complex-float"  "C" define-complex-blas-matrix
310 "complex-double" "Z" define-complex-blas-matrix
311
312 >>
313
314 M: blas-matrix-base >pprint-sequence Mrows ;
315 M: blas-matrix-base pprint* pprint-object ;