]> gitweb.factorcode.org Git - factor.git/blob - basis/math/blas/matrices/matrices.factor
Specialized array overhaul
[factor.git] / basis / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types arrays byte-arrays combinators
2 combinators.short-circuit fry kernel locals macros
3 math math.blas.ffi math.blas.vectors math.blas.vectors.private
4 math.complex math.functions math.order functors words
5 sequences sequences.merged sequences.private shuffle
6 parser prettyprint.backend prettyprint.custom ascii
7 specialized-arrays ;
8 SPECIALIZED-ARRAY: float
9 SPECIALIZED-ARRAY: double
10 SPECIALIZED-ARRAY: complex-float
11 SPECIALIZED-ARRAY: complex-double
12 IN: math.blas.matrices
13
14 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
15
16 : Mtransposed? ( matrix -- ? )
17     transpose>> ; inline
18 : Mwidth ( matrix -- width )
19     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
20 : Mheight ( matrix -- height )
21     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
22
23 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
24 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
25 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
26 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
27
28 <PRIVATE
29
30 : (blas-transpose) ( matrix -- integer )
31     transpose>> [ "T" ] [ "N" ] if ;
32
33 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
34
35 : (validate-gemv) ( A x y -- )
36     {
37         [ drop [ Mwidth  ] [ length>> ] bi* = ]
38         [ nip  [ Mheight ] [ length>> ] bi* = ]
39     } 3&&
40     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
41     unless ;
42
43 :: (prepare-gemv)
44     ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
45                           y )
46     A x y (validate-gemv)
47     A (blas-transpose)
48     A rows>>
49     A cols>>
50     alpha
51     A
52     A ld>>
53     x
54     x inc>>
55     beta
56     y
57     y inc>>
58     y ; inline
59
60 : (validate-ger) ( x y A -- )
61     {
62         [ [ length>> ] [ drop     ] [ Mheight ] tri* = ]
63         [ [ drop     ] [ length>> ] [ Mwidth  ] tri* = ]
64     } 3&&
65     [ "Mismatched vertices and matrix in vector outer product" throw ]
66     unless ;
67
68 :: (prepare-ger)
69     ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
70                      A )
71     x y A (validate-ger)
72     A rows>>
73     A cols>>
74     alpha
75     x
76     x inc>>
77     y
78     y inc>>
79     A
80     A ld>>
81     A f >>transpose ; inline
82
83 : (validate-gemm) ( A B C -- )
84     {
85         [ [ Mwidth  ] [ Mheight ] [ drop    ] tri* = ]
86         [ [ Mheight ] [ drop    ] [ Mheight ] tri* = ]
87         [ [ drop    ] [ Mwidth  ] [ Mwidth  ] tri* = ]
88     } 3&&
89     [ "Mismatched matrices in matrix multiplication" throw ]
90     unless ;
91
92 :: (prepare-gemm)
93     ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
94                           C )
95     A B C (validate-gemm)
96     A (blas-transpose)
97     B (blas-transpose)
98     C rows>>
99     C cols>>
100     A Mwidth
101     alpha
102     A
103     A ld>>
104     B
105     B ld>>
106     beta
107     C
108     C ld>>
109     C f >>transpose ; inline
110
111 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
112     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
113
114 PRIVATE>
115
116 ! XXX should do a dense clone
117 M: blas-matrix-base clone
118     [ 
119         [ {
120             [ underlying>> ]
121             [ ld>> ]
122             [ cols>> ]
123             [ element-type heap-size ]
124         } cleave * * memory>byte-array ]
125         [ {
126             [ ld>> ]
127             [ rows>> ]
128             [ cols>> ]
129             [ transpose>> ]
130         } cleave ]
131         bi
132     ] keep (blas-matrix-like) ;
133
134 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
135 : <empty-matrix> ( rows cols exemplar -- matrix )
136     [ element-type heap-size * * <byte-array> ]
137     [ 2drop ]
138     [ f swap (blas-matrix-like) ] 3tri ;
139
140 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
141     clone n*M.V+n*V! ;
142 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
143     clone n*V(*)V+M! ;
144 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
145     clone n*V(*)Vconj+M! ;
146 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
147     clone n*M.M+n*M! ;
148
149 : n*M.V ( alpha A x -- alpha*A.x )
150     1.0 2over [ Mheight ] dip <empty-vector>
151     n*M.V+n*V! ; inline
152
153 : M.V ( A x -- A.x )
154     1.0 -rot n*M.V ; inline
155
156 : n*V(*)V ( alpha x y -- alpha*x(*)y )
157     2dup [ length>> ] bi@ pick <empty-matrix>
158     n*V(*)V+M! ;
159 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
160     2dup [ length>> ] bi@ pick <empty-matrix>
161     n*V(*)Vconj+M! ;
162
163 : V(*) ( x y -- x(*)y )
164     1.0 -rot n*V(*)V ; inline
165 : V(*)conj ( x y -- x(*)yconj )
166     1.0 -rot n*V(*)Vconj ; inline
167
168 : n*M.M ( alpha A B -- alpha*A.B )
169     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix> 
170     1.0 swap n*M.M+n*M! ;
171
172 : M. ( A B -- A.B )
173     1.0 -rot n*M.M ; inline
174
175 :: (Msub) ( matrix row col height width -- data ld rows cols )
176     matrix ld>> col * row + matrix element-type heap-size *
177     matrix underlying>> <displaced-alien>
178     matrix ld>>
179     height
180     width ;
181
182 :: Msub ( matrix row col height width -- sub )
183     matrix dup transpose>>
184     [ col row width height ]
185     [ row col height width ] if (Msub)
186     matrix transpose>> matrix (blas-matrix-like) ;
187
188 TUPLE: blas-matrix-rowcol-sequence
189     parent inc rowcol-length rowcol-jump length ;
190 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
191
192 INSTANCE: blas-matrix-rowcol-sequence sequence
193
194 M: blas-matrix-rowcol-sequence length
195     length>> ;
196 M: blas-matrix-rowcol-sequence nth-unsafe
197     {
198         [
199             [ rowcol-jump>> ]
200             [ parent>> element-type heap-size ]
201             [ parent>> underlying>> ] tri
202             [ * * ] dip <displaced-alien>
203         ]
204         [ rowcol-length>> ]
205         [ inc>> ]
206         [ parent>> ]
207     } cleave (blas-vector-like) ;
208
209 : (Mcols) ( A -- columns )
210     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
211     cleave <blas-matrix-rowcol-sequence> ;
212 : (Mrows) ( A -- rows )
213     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
214     cleave <blas-matrix-rowcol-sequence> ;
215
216 : Mrows ( A -- rows )
217     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
218 : Mcols ( A -- cols )
219     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
220
221 : n*M! ( n A -- A=n*A )
222     [ (Mcols) [ n*V! drop ] with each ] keep ;
223
224 : n*M ( n A -- n*A )
225     clone n*M! ; inline
226
227 : M*n ( A n -- A*n )
228     swap n*M ; inline
229 : M/n ( A n -- A/n )
230     recip swap n*M ; inline
231
232 : Mtranspose ( matrix -- matrix^T )
233     [ {
234         [ underlying>> ]
235         [ ld>> ] [ rows>> ]
236         [ cols>> ]
237         [ transpose>> not ]
238     } cleave ] keep (blas-matrix-like) ;
239
240 M: blas-matrix-base equal?
241     {
242         [ [ Mwidth ] bi@ = ]
243         [ [ Mcols ] bi@ [ = ] 2all? ]
244     } 2&& ;
245
246 <<
247
248 FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
249
250 VECTOR      IS ${TYPE}-blas-vector
251 <VECTOR>    IS <${TYPE}-blas-vector>
252 >ARRAY      IS >${TYPE}-array
253 XGEMV       IS ${T}GEMV
254 XGEMM       IS ${T}GEMM
255 XGERU       IS ${T}GER${U}
256 XGERC       IS ${T}GER${C}
257
258 MATRIX      DEFINES-CLASS ${TYPE}-blas-matrix
259 <MATRIX>    DEFINES <${TYPE}-blas-matrix>
260 >MATRIX     DEFINES >${TYPE}-blas-matrix
261
262 t           [ T >lower ]
263
264 XMATRIX{    DEFINES ${t}matrix{
265
266 WHERE
267
268 TUPLE: MATRIX < blas-matrix-base ;
269 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
270     MATRIX boa ; inline
271
272 M: MATRIX element-type
273     drop TYPE ;
274 M: MATRIX (blas-matrix-like)
275     drop <MATRIX> ;
276 M: VECTOR (blas-matrix-like)
277     drop <MATRIX> ;
278 M: MATRIX (blas-vector-like)
279     drop <VECTOR> ;
280
281 : >MATRIX ( arrays -- matrix )
282     [ >ARRAY underlying>> ] (>matrix) <MATRIX> ;
283
284 M: VECTOR n*M.V+n*V!
285     (prepare-gemv) [ XGEMV ] dip ;
286 M: MATRIX n*M.M+n*M!
287     (prepare-gemm) [ XGEMM ] dip ;
288 M: MATRIX n*V(*)V+M!
289     (prepare-ger) [ XGERU ] dip ;
290 M: MATRIX n*V(*)Vconj+M!
291     (prepare-ger) [ XGERC ] dip ;
292
293 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
294
295 M: MATRIX pprint-delims
296     drop \ XMATRIX{ \ } ;
297
298 ;FUNCTOR
299
300
301 : define-real-blas-matrix ( TYPE T -- )
302     "" "" (define-blas-matrix) ;
303 : define-complex-blas-matrix ( TYPE T -- )
304     "U" "C" (define-blas-matrix) ;
305
306 "float"          "S" define-real-blas-matrix
307 "double"         "D" define-real-blas-matrix
308 "complex-float"  "C" define-complex-blas-matrix
309 "complex-double" "Z" define-complex-blas-matrix
310
311 >>
312
313 M: blas-matrix-base >pprint-sequence Mrows ;
314 M: blas-matrix-base pprint* pprint-object ;