1 ! Copyright (C) 2019 HMC Clinic.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: accessors alien.c-types alien.data arrays
4 concurrency.combinators grouping kernel locals math.functions
5 math.ranges math.statistics math multi-methods quotations sequences
6 sequences.private specialized-arrays tensors.tensor-slice typed ;
7 QUALIFIED-WITH: alien.c-types c
8 SPECIALIZED-ARRAY: c:float
11 ! Tensor class definition
! Thrown by check-shape when some entry of shape is less than 1.
ERROR: non-positive-shape-error
    shape ;
! Thrown when two shapes that an operation requires to agree do not.
ERROR: shape-mismatch-error
    shape1 shape2 ;
! Leave shape on the stack unchanged, throwing non-positive-shape-error
! if any of its entries is less than 1.
: check-shape ( shape -- shape )
    dup [ 1 < ] any? [ non-positive-shape-error ] when ;
! Construct a tensor from a shape and its flat sequence of element values
27 : <tensor> ( shape seq -- tensor )
30 : >float-array ( seq -- float-array )
! Build a tensor of the given shape (validated by check-shape) in which
! every element is const; the data is converted with >float-array.
:: repetition ( shape const -- tensor )
    shape check-shape :> dims
    dims dims product const <repetition> >float-array <tensor> ;
39 ! Construct a tensor of zeros
40 : zeros ( shape -- tensor )
43 ! Construct a tensor of ones
44 : ones ( shape -- tensor )
! Construct a one-dimensional tensor holding a, a+step, a+2*step, ...
! up to b; b itself is included when the range lands on it exactly.
! The tensor's shape is { length-of-range }.
: arange ( a b step -- tensor )
    <range> [ length 1array ] [ >float-array ] bi <tensor> ;
! Construct a tensor of the given (validated) shape whose flat data is
! { 0 1 2 ... n-1 }, where n is the product of the shape's entries.
: naturals ( shape -- tensor )
    check-shape dup product [0,b) >float-array <tensor> ;
! Leave both shapes on the stack if they describe the same number of
! elements (equal products); otherwise throw shape-mismatch-error.
: check-reshape ( shape1 shape2 -- shape1 shape2 )
    2dup [ product ] bi@ = not
    [ shape-mismatch-error ] when ;
! Give tensor a new shape with the same element count. The new shape is
! validated with check-shape and then compared against the old one with
! check-reshape. The tensor's shape slot is overwritten in place with
! >>shape, and that same tensor object is returned.
:: reshape ( tensor shape -- tensor )
    tensor shape>> shape check-shape check-reshape nip :> new-shape
    tensor new-shape >>shape ;
67 ! Flatten the tensor so that it is only one-dimensional
68 : flatten ( tensor -- tensor )
70 product { } 1sequence >>shape ;
72 ! outputs the number of dimensions of a tensor
73 : dims ( tensor -- n )
! Convert a tensor into nested Factor arrays that follow its shape.
! Adapted from shaped-array>array.
TYPED:: tensor>array ( tensor: tensor -- seq: array )
    tensor vec>> >array
    ! Group by the innermost dimension first, then work outward. The
    ! outermost dimension is simply the number of groups left over.
    tensor shape>> [ rest-slice <reversed> [ group ] each ] unless-empty ;
! Return the common shape of the two operands of an elementwise binary
! operation; throw shape-mismatch-error if the shapes are not equal.
: check-bop-shape ( shape1 shape2 -- shape )
    2dup = [ drop ] [ shape-mismatch-error ] if ;
! Combine two tensors element by element with quot. The shapes must be
! identical (checked by check-bop-shape), and the result has that shape.
TYPED:: t-bop ( tensor1: tensor tensor2: tensor quot: ( x y -- z ) -- tensor: tensor )
    tensor1 shape>> tensor2 shape>> check-bop-shape :> common-shape
    tensor1 vec>> tensor2 vec>> quot 2map :> combined
    common-shape combined <tensor> ; inline
! Map quot over every element of tensor, producing a new tensor with
! the same shape.
TYPED:: t-uop ( tensor: tensor quot: ( x -- y ) -- tensor: tensor )
    tensor vec>> quot map :> mapped
    tensor shape>> mapped <tensor> ; inline
! Elementwise addition of two same-shaped tensors, or of a tensor and a
! scalar (the scalar is added to every element).
multi-methods:GENERIC: t+ ( x y -- tensor )
METHOD: t+ { tensor tensor }
    [ + ] t-bop ;
METHOD: t+ { tensor number }
    [ + ] curry t-uop ;
! `with` keeps the scalar as the left operand of +.
METHOD: t+ { number tensor }
    [ + ] with t-uop ;
! Elementwise subtraction of two same-shaped tensors, or of a tensor
! and a scalar in either order.
multi-methods:GENERIC: t- ( x y -- tensor )
METHOD: t- { tensor tensor }
    [ - ] t-bop ;
! Each element minus the scalar.
METHOD: t- { tensor number }
    [ - ] curry t-uop ;
! The scalar minus each element (`with` keeps the scalar on the left).
METHOD: t- { number tensor }
    [ - ] with t-uop ;
! Elementwise multiplication of two same-shaped tensors, or scaling of
! every element of a tensor by a scalar.
multi-methods:GENERIC: t* ( x y -- tensor )
METHOD: t* { tensor tensor }
    [ * ] t-bop ;
METHOD: t* { tensor number }
    [ * ] curry t-uop ;
! `with` keeps the scalar as the left operand of *.
METHOD: t* { number tensor }
    [ * ] with t-uop ;
! Elementwise division of two same-shaped tensors, or of a tensor and a
! scalar in either order.
multi-methods:GENERIC: t/ ( x y -- tensor )
METHOD: t/ { tensor tensor }
    [ / ] t-bop ;
! Each element divided by the scalar.
METHOD: t/ { tensor number }
    [ / ] curry t-uop ;
! The scalar divided by each element (`with` keeps the scalar on the left).
METHOD: t/ { number tensor }
    [ / ] with t-uop ;
! Elementwise remainder (mod) of two same-shaped tensors, or of a tensor
! and a scalar in either order.
multi-methods:GENERIC: t% ( x y -- tensor )
METHOD: t% { tensor tensor }
    [ mod ] t-bop ;
! Each element mod the scalar.
METHOD: t% { tensor number }
    [ mod ] curry t-uop ;
! The scalar mod each element (`with` keeps the scalar on the left).
METHOD: t% { number tensor }
    [ mod ] with t-uop ;
! Check that tensor1 (...xmxn) and tensor2 (...xnxp) have shapes that
! allow matrix multiplication:
!   - both need at least two dimensions;
!   - the last dimension of tensor1 must equal the second-to-last
!     dimension of tensor2;
!   - all leading dimensions must be identical.
! If any condition fails, throws shape-mismatch-error with both shapes.
:: check-matmul-shape ( tensor1 tensor2 -- )
    tensor1 shape>> :> shape1
    tensor2 shape>> :> shape2
    ! Check the rank first. For shapes with fewer than two entries, the
    ! last/nth/head* calls below would throw a bounds or empty-sequence
    ! error instead of a shape-mismatch-error.
    shape1 length 2 >= shape2 length 2 >= and [
        ! Inner dimensions must agree for the matrices to multiply
        shape1 last shape2 [ length 2 - ] keep nth =
        ! Leading dimensions must be equal
        shape1 2 head* shape2 2 head* = and
    ] [ f ] if
    [ shape1 shape2 shape-mismatch-error ] unless ;
140 ! Slice out a row from the array
141 : row ( arr n i p -- slice )
142 ! Compute the starting index
144 ! Compute the ending index
! Perform matrix multiplication, multiplying an
! m x n matrix by an n x p matrix
151 TYPED:: 2d-matmul ( vec1: slice vec2: slice res: slice n: number p: number -- )
152 ! For each element in the range, we want to compute the dot product of the
153 ! corresponding row and column
157 [ [ vec1 n ] dip p row ]
159 ! [ p mod vec2 swap p every ] bi
160 [ p mod f p vec2 <step-slice> ] bi
161 ! Take the dot product
162 [ * ] [ + ] 2map-reduce
! Perform matrix multiplication, multiplying an
! ... x m x n tensor by an ... x n x p tensor
171 TYPED:: matmul ( tensor1: tensor tensor2: tensor -- tensor3: tensor )
172 ! First check the shape
173 tensor1 tensor2 check-matmul-shape
175 ! Now save all of the sizes
176 tensor1 shape>> unclip-last-slice :> n
177 unclip-last-slice :> m :> top-shape
178 tensor2 shape>> last :> p
179 top-shape product :> rest
181 ! Now create the new tensor with { 0 ... m*p-1 } repeating
182 top-shape { m p } append naturals m p * t% :> tensor3
! Now update tensor3 to contain the multiplied matrices
189 m n * i * dup m n * + tensor1 vec>> <slice>
191 n p * i * dup n p * + tensor2 vec>> <slice>
192 ! Now make the resulting vector
193 m p * i * dup m p * + tensor3 vec>> <slice>
! Push n and p and multiply the slices
! Helper for transpose: turn a shape into the multipliers used to build
! flat indices. For shape { a b c } this returns { 1 c b*c }, i.e. the
! row-major strides of the axes listed from the last axis to the first.
: ind-mults ( shape -- seq )
    rest-slice reverse cum-product { 1 } swap append ;
206 ! helper for transpose: given shape, flat index, & mults for the shape, gives nd index
207 :: trans-index ( ind shape mults -- seq )
208 ! what we use to divide things
! loop through the elements & indices of S (mod by element m)
214 ! we divide by the product of the 1st n elements of S
215 S i head-slice product :> div
216 ! do not mod on the last index
217 i S length 1 - = not :> mod?
218 ! multiply accumulator by mults & sum
219 dup mults [ * ] 2map sum
220 ! subtract from ind & divide
223 mod? [ m mod ] [ ] if
224 ! append to accumulator
225 [ dup ] dip swap push
230 ! Transpose an n-dimensional tensor
231 TYPED:: transpose ( tensor: tensor -- tensor': tensor )
233 tensor shape>> reverse :> newshape
234 ! what we multiply by to get indices in the old tensor
235 tensor shape>> ind-mults :> old-mults
236 ! what we multiply to get indices in new tensor
237 newshape ind-mults :> mults
238 ! new tensor of correct shape
239 newshape naturals dup vec>>
240 [ ! go thru each index
241 ! find index in original tensor
242 newshape mults trans-index old-mults [ * ] 2map sum >fixnum
243 ! get that index in original tensor