]> gitweb.factorcode.org Git - factor.git/blob - extra/math/blas/matrices/matrices.factor
factor: trim using lists
[factor.git] / extra / math / blas / matrices / matrices.factor
1 USING: accessors alien alien.c-types alien.complex alien.data
2 ascii byte-arrays combinators combinators.short-circuit functors
3 kernel math math.blas.ffi math.blas.vectors
4 math.blas.vectors.private parser prettyprint.custom sequences
5 sequences.merged sequences.private specialized-arrays ;
6 FROM: alien.c-types => float ;
7 SPECIALIZED-ARRAY: float
8 SPECIALIZED-ARRAY: double
9 SPECIALIZED-ARRAY: complex-float
10 SPECIALIZED-ARRAY: complex-double
11 IN: math.blas.matrices
12
13 TUPLE: blas-matrix-base underlying ld rows cols transpose ;
14
15 : Mtransposed? ( matrix -- ? )
16     transpose>> ; inline
17 : Mwidth ( matrix -- width )
18     dup Mtransposed? [ rows>> ] [ cols>> ] if ; inline
19 : Mheight ( matrix -- height )
20     dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
21
22 GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
23 GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
24 GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
25 GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
26
27 <PRIVATE
28
29 : (blas-transpose) ( matrix -- integer )
30     transpose>> [ "T" ] [ "N" ] if ;
31
32 GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
33
34 : (validate-gemv) ( A x y -- )
35     {
36         [ drop [ Mwidth  ] [ length>> ] bi* = ]
37         [ nip  [ Mheight ] [ length>> ] bi* = ]
38     } 3&&
39     [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
40     unless ;
41
42 :: (prepare-gemv)
43     ( alpha A x beta y -- A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
44                           y )
45     A x y (validate-gemv)
46     A (blas-transpose)
47     A rows>>
48     A cols>>
49     alpha
50     A
51     A ld>>
52     x
53     x inc>>
54     beta
55     y
56     y inc>>
57     y ; inline
58
59 : (validate-ger) ( x y A -- )
60     {
61         [ [ length>> ] [ drop     ] [ Mheight ] tri* = ]
62         [ [ drop     ] [ length>> ] [ Mwidth  ] tri* = ]
63     } 3&&
64     [ "Mismatched vertices and matrix in vector outer product" throw ]
65     unless ;
66
67 :: (prepare-ger)
68     ( alpha x y A -- m n alpha x-data x-inc y-data y-inc A-data A-ld
69                      A )
70     x y A (validate-ger)
71     A rows>>
72     A cols>>
73     alpha
74     x
75     x inc>>
76     y
77     y inc>>
78     A
79     A ld>>
80     A f >>transpose ; inline
81
82 : (validate-gemm) ( A B C -- )
83     {
84         [ [ Mwidth  ] [ Mheight ] [ drop    ] tri* = ]
85         [ [ Mheight ] [ drop    ] [ Mheight ] tri* = ]
86         [ [ drop    ] [ Mwidth  ] [ Mwidth  ] tri* = ]
87     } 3&&
88     [ "Mismatched matrices in matrix multiplication" throw ]
89     unless ;
90
91 :: (prepare-gemm)
92     ( alpha A B beta C -- A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
93                           C )
94     A B C (validate-gemm)
95     A (blas-transpose)
96     B (blas-transpose)
97     C rows>>
98     C cols>>
99     A Mwidth
100     alpha
101     A
102     A ld>>
103     B
104     B ld>>
105     beta
106     C
107     C ld>>
108     C f >>transpose ; inline
109
110 : (>matrix) ( arrays >c-array -- c-array ld rows cols transpose )
111     '[ <merged> @ ] [ length dup ] [ first length ] tri f ; inline
112
113 PRIVATE>
114
115 ! XXX should do a dense clone
116 M: blas-matrix-base clone
117     [
118         [ {
119             [ underlying>> ]
120             [ ld>> ]
121             [ cols>> ]
122             [ element-type heap-size ]
123         } cleave * * memory>byte-array ]
124         [ {
125             [ ld>> ]
126             [ rows>> ]
127             [ cols>> ]
128             [ transpose>> ]
129         } cleave ]
130         bi
131     ] keep (blas-matrix-like) ;
132
133 ! XXX try rounding stride to next 128 bit bound for better vectorizin'
134 : <empty-matrix> ( rows cols exemplar -- matrix )
135     [ element-type heap-size * * <byte-array> ]
136     [ 2drop ]
137     [ [ f ] dip (blas-matrix-like) ] 3tri ;
138
139 : n*M.V+n*V ( alpha A x beta y -- alpha*A.x+b*y )
140     clone n*M.V+n*V! ;
141 : n*V(*)V+M ( alpha x y A -- alpha*x(*)y+A )
142     clone n*V(*)V+M! ;
143 : n*V(*)Vconj+M ( alpha x y A -- alpha*x(*)yconj+A )
144     clone n*V(*)Vconj+M! ;
145 : n*M.M+n*M ( alpha A B beta C -- alpha*A.B+beta*C )
146     clone n*M.M+n*M! ;
147
148 : n*M.V ( alpha A x -- alpha*A.x )
149     1.0 2over [ Mheight ] dip <empty-vector>
150     n*M.V+n*V! ; inline
151
152 : M.V ( A x -- A.x )
153     [ 1.0 ] 2dip n*M.V ; inline
154
155 : n*V(*)V ( alpha x y -- alpha*x(*)y )
156     2dup [ length>> ] bi@ pick <empty-matrix>
157     n*V(*)V+M! ;
158 : n*V(*)Vconj ( alpha x y -- alpha*x(*)yconj )
159     2dup [ length>> ] bi@ pick <empty-matrix>
160     n*V(*)Vconj+M! ;
161
162 : V(*) ( x y -- x(*)y )
163     [ 1.0 ] 2dip n*V(*)V ; inline
164 : V(*)conj ( x y -- x(*)yconj )
165     [ 1.0 ] 2dip n*V(*)Vconj ; inline
166
167 : n*M.M ( alpha A B -- alpha*A.B )
168     2dup [ Mheight ] [ Mwidth ] bi* pick <empty-matrix>
169     [ 1.0 ] dip n*M.M+n*M! ;
170
171 : M. ( A B -- A.B )
172     [ 1.0 ] 2dip n*M.M ; inline
173
174 :: (Msub) ( matrix row col height width -- data ld rows cols )
175     matrix ld>> col * row + matrix element-type heap-size *
176     matrix underlying>> <displaced-alien>
177     matrix ld>>
178     height
179     width ;
180
181 :: Msub ( matrix row col height width -- sub )
182     matrix dup transpose>>
183     [ col row width height ]
184     [ row col height width ] if (Msub)
185     matrix transpose>> matrix (blas-matrix-like) ;
186
187 TUPLE: blas-matrix-rowcol-sequence
188     parent inc rowcol-length rowcol-jump length ;
189 C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
190
191 INSTANCE: blas-matrix-rowcol-sequence sequence
192
193 M: blas-matrix-rowcol-sequence length
194     length>> ;
195 M: blas-matrix-rowcol-sequence nth-unsafe
196     {
197         [
198             [ rowcol-jump>> ]
199             [ parent>> element-type heap-size ]
200             [ parent>> underlying>> ] tri
201             [ * * ] dip <displaced-alien>
202         ]
203         [ rowcol-length>> ]
204         [ inc>> ]
205         [ parent>> ]
206     } cleave (blas-vector-like) ;
207
208 : (Mcols) ( A -- columns )
209     { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
210     cleave <blas-matrix-rowcol-sequence> ;
211 : (Mrows) ( A -- rows )
212     { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
213     cleave <blas-matrix-rowcol-sequence> ;
214
215 : Mrows ( A -- rows )
216     dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
217 : Mcols ( A -- cols )
218     dup transpose>> [ (Mrows) ] [ (Mcols) ] if ;
219
220 : n*M! ( n A -- A=n*A )
221     [ (Mcols) [ n*V! drop ] with each ] keep ;
222
223 : n*M ( n A -- n*A )
224     clone n*M! ; inline
225
226 : M*n ( A n -- A*n )
227     swap n*M ; inline
228 : M/n ( A n -- A/n )
229     recip swap n*M ; inline
230
231 : Mtranspose ( matrix -- matrix^T )
232     [ {
233         [ underlying>> ]
234         [ ld>> ] [ rows>> ]
235         [ cols>> ]
236         [ transpose>> not ]
237     } cleave ] keep (blas-matrix-like) ;
238
239 M: blas-matrix-base equal?
240     {
241         [ and ]
242         [ [ Mwidth ] bi@ = ]
243         [ [ Mcols ] bi@ [ = ] 2all? ]
244     } 2&& ;
245
246 <<
247
248 <FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
249
250 VECTOR      IS ${TYPE}-blas-vector
251 <VECTOR>    IS <${TYPE}-blas-vector>
252 XGEMV       IS ${T}GEMV
253 XGEMM       IS ${T}GEMM
254 XGERU       IS ${T}GER${U}
255 XGERC       IS ${T}GER${C}
256
257 MATRIX      DEFINES-CLASS ${TYPE}-blas-matrix
258 <MATRIX>    DEFINES <${TYPE}-blas-matrix>
259 >MATRIX     DEFINES >${TYPE}-blas-matrix
260
261 t           [ T >lower ]
262
263 XMATRIX{    DEFINES ${t}matrix{
264
265 WHERE
266
267 TUPLE: MATRIX < blas-matrix-base ;
268 : <MATRIX> ( underlying ld rows cols transpose -- matrix )
269     MATRIX boa ; inline
270
271 M: MATRIX element-type
272     drop TYPE ;
273 M: MATRIX (blas-matrix-like)
274     drop <MATRIX> ;
275 M: VECTOR (blas-matrix-like)
276     drop <MATRIX> ;
277 M: MATRIX (blas-vector-like)
278     drop <VECTOR> ;
279
280 : >MATRIX ( arrays -- matrix )
281     [ TYPE >c-array underlying>> ] (>matrix) <MATRIX> ;
282
283 M: VECTOR n*M.V+n*V!
284     (prepare-gemv) [ XGEMV ] dip ;
285 M: MATRIX n*M.M+n*M!
286     (prepare-gemm) [ XGEMM ] dip ;
287 M: MATRIX n*V(*)V+M!
288     (prepare-ger) [ XGERU ] dip ;
289 M: MATRIX n*V(*)Vconj+M!
290     (prepare-ger) [ XGERC ] dip ;
291
292 SYNTAX: XMATRIX{ \ } [ >MATRIX ] parse-literal ;
293
294 M: MATRIX pprint-delims
295     drop \ XMATRIX{ \ } ;
296
297 ;FUNCTOR>
298
299
300 : define-real-blas-matrix ( TYPE T -- )
301     "" "" (define-blas-matrix) ;
302 : define-complex-blas-matrix ( TYPE T -- )
303     "U" "C" (define-blas-matrix) ;
304
305 float          "S" define-real-blas-matrix
306 double         "D" define-real-blas-matrix
307 complex-float  "C" define-complex-blas-matrix
308 complex-double "Z" define-complex-blas-matrix
309
310 >>
311
312 M: blas-matrix-base >pprint-sequence Mrows ;
313 M: blas-matrix-base pprint* pprint-object ;