1 ! Copyright (C) 2019 HMC Clinic.
2 ! See http://factorcode.org/license.txt for BSD license.
4 USING: accessors alien alien.c-types alien.data arrays combinators
5 grouping kernel math math.functions ranges math.vectors
6 math.vectors.simd multi-methods parser prettyprint.custom sequences sequences.extras
7 sequences.private specialized-arrays typed ;
9 QUALIFIED-WITH: alien.c-types c
10 SPECIALIZED-ARRAY: c:float
11 SPECIALIZED-ARRAY: float-4
14 ! Tensor class definition
! Raised when a requested shape contains a dimension less than 1
ERROR: non-positive-shape-error shape ;
! Raised when two shapes that must agree (reshape, binary ops, stacking) differ
ERROR: shape-mismatch-error shape1 shape2 ;
! Raised when a nested sequence's sub-sequences are not uniformly sized
ERROR: non-uniform-seq-error seq ;
! Raised when an index sequence's length differs from the tensor's rank
ERROR: dimension-mismatch-error tensor-dim index-dim ;
27 ! Check that the shape has only positive values
! Check that the shape has only positive values.
! Throws non-positive-shape-error if any dimension is < 1;
! otherwise leaves the shape on the stack unchanged.
! ( any? is the direct stdlib idiom; the previous map-find/drop
! only used the found flag and discarded the element anyway. )
: check-shape ( shape -- shape )
    dup [ 1 < ] any? [ non-positive-shape-error ] when ;
31 ! Construct a tensor of zeros
32 : <tensor> ( shape seq -- tensor )
35 ! Creates a freshly-allocated float-array with the desired c-type values
36 : >float-array ( seq -- float-array )
! Build a tensor of the given (validated) shape in which every
! element is const
: repetition ( shape const -- tensor )
    swap check-shape dup product rot <repetition>
    >float-array <tensor> ;
45 ! Construct a tensor of zeros
46 : zeros ( shape -- tensor )
49 ! Construct a tensor of ones
50 : ones ( shape -- tensor )
53 ! Construct a one-dimensional tensor with values start, start+step,
54 ! ..., stop (inclusive)
! Construct a one-dimensional tensor holding a, a+step, ...
! up to and including b
: arange ( a b step -- tensor )
    <range> dup length >fixnum 1array swap >float-array <tensor> ;
58 ! Construct a tensor with vec { 0 1 2 ... } and reshape to the desired shape
! Construct a tensor filled with 0, 1, 2, ... shaped as requested
: naturals ( shape -- tensor )
    check-shape dup product <iota> >float-array <tensor> ;
62 ! Construct a tensor without initializing its values
! Construct a tensor of the given shape whose storage is
! allocated but not initialized
: (tensor) ( shape -- tensor )
    [ product (float-array) ] keep swap <tensor> ;
! Ensure both shapes describe the same total number of elements,
! throwing shape-mismatch-error otherwise
: check-reshape ( shape1 shape2 -- shape1 shape2 )
    2dup [ product ] [ product ] bi* =
    [ shape-mismatch-error ] unless ;
73 ! Reshape the tensor to conform to the new shape
! Reshape the tensor to conform to the new shape; the new shape
! must be valid and have the same number of elements
: reshape ( tensor shape -- tensor )
    [ dup shape>> ] dip check-shape check-reshape nip >>shape ;
77 ! Flatten the tensor so that it is only one-dimensional
! Flatten the tensor so that it is only one-dimensional.
! ( 1array says this directly; { } 1sequence built the same
! one-element array indirectly, and the rest of the file
! already uses 1array for this. )
: flatten ( tensor -- tensor )
    dup shape>> product 1array >>shape ;
81 ! outputs the number of dimensions of a tensor
82 : dims ( tensor -- n )
85 ! Turn into Factor ND array form
86 ! Source: shaped-array>array
! Turn into Factor ND array form
! Source: shaped-array>array
TYPED: tensor>array ( tensor: tensor -- seq: array )
    [ vec>> >array ] [ shape>> ] bi
    ! Group the flat data by each dimension after the first, working
    ! from the innermost dimension outwards, rebuilding the nesting;
    ! an empty shape leaves the flat array untouched
    [ rest-slice reverse [ group ] each ] unless-empty ;
92 ! recursively finds shape of nested array
93 ! assumes properly shaped array (all sub-arrays are same size)
94 :: find-shape ( seq shape -- shape' )
95 seq empty? [ { 0 } ] [
96 ! add length of seq element to shape
97 shape seq length 1array append :> shape'
98 ! base case: check if the first element is a seq
101 ! is a sequence: recurse on 1st element
102 [ 1st shape' find-shape ]
103 ! not a sequence: return shape'
108 ! turns a nested array into a tensor
109 :: >tensor ( seq -- tensor )
111 seq { } find-shape :> shape
116 ] each-integer :> flatseq
117 ! check that the size is good
118 shape product flatseq length =
119 [ seq non-uniform-seq-error ] unless
121 shape flatseq >float-array <tensor> ;
! Parsing word: t{ ... } reads a literal nested sequence and builds a tensor
SYNTAX: t{ \ } [ >tensor ] parse-literal ;

! Prettyprinter support: tensors print back out in t{ ... } form,
! using the nested-array representation for the elements
syntax:M: tensor pprint-delims drop \ t{ \ } ;
syntax:M: tensor >pprint-sequence tensor>array ;
syntax:M: tensor pprint* pprint-object ;
132 ! turns a shape into a list of things by which to multiply
133 ! indices to get a full index (e.g. { 2 3 4 } -> { 12 4 1 })
! Turn a shape into per-dimension strides: multiplying each index
! by its stride and summing gives the flat index
! (e.g. { 2 3 4 } -> { 12 4 1 }).
! ( accumulate is the stdlib word for exactly the running-product
! pattern the old code threaded through map by hand. )
: ind-mults ( shape -- seq )
    <reversed> 1 [ * ] accumulate nip reverse ;
137 ! turns a num/seq index & tensor into num index & tensor
! also throws a dimension-mismatch error if the seq and tensor shape>> aren't the same length
139 : num-index ( n/seq tensor -- n tensor )
140 ! check form of index (num or seq)
142 [ ! if array, first check if it's a valid index
143 2dup [ shape>> length ] dip length 2dup =
144 [ dimension-mismatch-error ] unless 2drop
146 [ dup shape>> ind-mults ] dip [ * ] 2map-sum
152 ! Sequence protocol implementation
! clone deep-copies both the shape and the storage so the copy
! can be mutated independently of the original
syntax:M: tensor clone [ shape>> clone ] [ vec>> clone ] bi <tensor> ;

! The length of a tensor is the total number of stored elements
syntax:M: tensor length vec>> length ;

! Element access accepts either a flat index or a per-dimension
! index sequence; num-index converts the latter to a flat index
syntax:M: tensor nth num-index vec>> nth ;

syntax:M: tensor nth-unsafe num-index vec>> nth-unsafe ;

syntax:M: tensor set-nth num-index vec>> set-nth ;

syntax:M: tensor set-nth-unsafe num-index vec>> set-nth-unsafe ;
! Allocate a new uninitialized tensor to serve as a sequence
! target: keep the shape when the requested length matches the
! old tensor's size, otherwise fall back to a 1D tensor
syntax:M: tensor new-sequence
    shape>> dup product pick =
    [ nip ] [ drop 1array ] if (tensor) ;
171 syntax:M: tensor like
172 ! If the original sequence is already a tensor, we are done
176 [ dup [ length 1array ] dip <tensor> ] dip
180 2dup [ length ] bi@ = [ shape>> reshape ] [ drop ] if
183 syntax:M: tensor clone-like
184 ! If the original sequence is already a tensor, we just need to clone it
188 2dup [ length ] bi@ = [ shape>> reshape ] [ drop ] if
! Declare tensor an instance of the sequence mixin so all generic
! sequence operations work on tensors
INSTANCE: tensor sequence
! Create a zero-copy view of len floats of arr beginning at index
! start, backed by the same underlying memory as arr
:: make-subseq ( arr start len -- arr )
    ! Byte offset of the starting element
    c:float heap-size start *
    ! Compute the starting pointer
    arr underlying>> <displaced-alien>
    ! Push length and type to create the new array
    len c:float <c-direct-array> ; inline
! Binary elementwise ops require identical shapes; leaves a single
! copy of the shared shape, or throws shape-mismatch-error
: check-bop-shape ( shape1 shape2 -- shape )
    2dup = [ drop ] [ shape-mismatch-error ] if ;
207 ! Apply the binary operator bop to combine the tensors
! Combine two same-shaped tensors elementwise with the quotation
TYPED:: t-bop ( t1: tensor t2: tensor op: ( x y -- z ) -- t3: tensor )
    t1 shape>> t2 shape>> check-bop-shape
    t1 vec>> t2 vec>> op 2map <tensor> ; inline
212 ! Create an array of 4-element SIMD arrays for processing floats
! Split a float-array into a float-4 SIMD view of the largest
! multiple-of-4 prefix plus the leftover slice (f when the length
! is already a multiple of 4)
: simd-for-bop ( array -- simd-array rest-slice/f )
    dup length 4 mod
    [ f ] [ [ dup length ] dip - cut-slice ] if-zero
    [ float-4 cast-array ] dip ; inline
! A float-array partitioned for SIMD processing: an unaligned head
! slice, a middle section viewed as float-4 values, and a leftover tail
220 { first-slice float-array }
221 { simd-slice float-4-array }
222 { end-slice float-array } ;
! View len floats of arr starting at start, or an empty
! float-array when len is zero
:: (simd-slice) ( arr start len -- arr/f )
    len zero? [ float-array{ } ] [ arr start len make-subseq ] if ; inline
227 :: <simd-slice> ( arr start -- simd-slice )
228 ! Compute the beginning
229 arr 0 start (simd-slice)
230 ! Compute the SIMD part
231 arr length start - :> len
233 arr start len end - (simd-slice) float-4 cast-array
235 arr dup length end - end (simd-slice)
236 simd-slice boa ; inline
238 ! Apply the binary operators simd-quot and quot to quickly combine the tensors
! Combine two same-shaped tensors elementwise, processing the bulk
! of the data four floats at a time with simd-op and the leftover
! tail elements with op
:: t-bop-simd ( t1 t2 simd-op: ( x y -- z ) op: ( x y -- z ) -- tensor )
    t1 shape>> t2 shape>> check-bop-shape
    t1 vec>> t2 vec>>
    dup length (float-array) dup :> out
    [ simd-for-bop ] tri@ :> ( sa ra sb rb sc rc )
    ! SIMD portion: four elements per operation, written into out
    sa sb simd-op sc 2map-into
    ! Remaining (non-multiple-of-4) tail elements
    ra rb op rc 2map-into
    out <tensor> ; inline
248 ! Apply the operation to the tensor
! Apply the quotation to every element of the tensor, preserving
! its shape
TYPED:: t-uop ( t: tensor op: ( x -- y ) -- t2: tensor )
    t shape>> t vec>> op map <tensor> ; inline
252 ! Apply the binary operators simd-quot and quot to quickly combine a tensor and
! Combine each element of a tensor with the scalar n, using
! simd-op on float-4 groups and op on the leftover elements
:: t-uop-simd ( t n simd-op: ( x y -- z ) op: ( x y -- z ) -- tensor )
    t shape>> t vec>>
    dup length (float-array) dup :> out
    [ simd-for-bop ] bi@ :> ( sa ra sb rb )
    ! Broadcast n into a float-4 for the SIMD portion
    sa n n n n float-4-boa simd-op curry sb map-into
    ra n op curry rb map-into
    out <tensor> ; inline
! Add a tensor to either another tensor or a scalar.
! Scalar operands are coerced to float and broadcast elementwise;
! in the { number tensor } methods the swaps inside the quotations
! preserve operand order for the underlying operation.
multi-methods:GENERIC: t+ ( x y -- tensor )
METHOD: t+ { tensor tensor } [ v+ ] [ + ] t-bop-simd ;
METHOD: t+ { tensor number } >float [ v+ ] [ + ] t-uop-simd ;
METHOD: t+ { number tensor } swap >float [ swap v+ ] [ swap + ] t-uop-simd ;

! Subtraction between two tensors or a tensor and a scalar
multi-methods:GENERIC: t- ( x y -- tensor )
METHOD: t- { tensor tensor } [ v- ] [ - ] t-bop-simd ;
METHOD: t- { tensor number } >float [ v- ] [ - ] t-uop-simd ;
METHOD: t- { number tensor } swap >float [ swap v- ] [ swap - ] t-uop-simd ;

! Multiply a tensor with either another tensor or a scalar
multi-methods:GENERIC: t* ( x y -- tensor )
METHOD: t* { tensor tensor } [ v* ] [ * ] t-bop-simd ;
METHOD: t* { tensor number } >float [ v* ] [ * ] t-uop-simd ;
METHOD: t* { number tensor } swap >float [ swap v* ] [ swap * ] t-uop-simd ;

! Divide two tensors or a tensor and a scalar
multi-methods:GENERIC: t/ ( x y -- tensor )
METHOD: t/ { tensor tensor } [ v/ ] [ / ] t-bop-simd ;
METHOD: t/ { tensor number } >float [ v/ ] [ / ] t-uop-simd ;
METHOD: t/ { number tensor } swap >float [ swap v/ ] [ swap / ] t-uop-simd ;

! Mod two tensors or a tensor and a scalar; this uses the plain
! (non-SIMD) t-bop/t-uop combinators
multi-methods:GENERIC: t% ( x y -- tensor )
METHOD: t% { tensor tensor } [ mod ] t-bop ;
METHOD: t% { tensor number } >float [ mod ] curry t-uop ;
METHOD: t% { number tensor } [ >float ] dip [ mod ] with t-uop ;
! Sum together all elements in the tensor
! With a start of 0 the head slice is empty, so only the SIMD
! chunks (summed per float-4, then across chunks) and the leftover
! tail contribute to the total
syntax:M: tensor sum vec>> 0 <simd-slice>
    [ simd-slice>> [ sum ] map-sum ]
    [ end-slice>> sum ] bi + ;
301 ! Also converts all elements of the sequence to tensors
302 :: check-concat-shape ( seq -- seq )
303 ! Compute the bottom shape of the first element in the sequence
304 seq first { } >tensor dup :> empty-tensor
305 like shape>> dup :> first-shape rest :> rest-shape
307 ! Compute the bottom shape of this element
308 empty-tensor like dup shape>> rest
309 ! Compare; if they are different, throw an error
310 rest-shape = [ shape>> first-shape swap shape-mismatch-error ] unless
313 ! Also converts all elements of the sequence to tensors
314 :: check-stack-shape ( seq -- seq )
315 ! Compute the bottom shape of the first element in the sequence
316 seq first { } >tensor dup :> empty-tensor
317 like shape>> :> first-shape
319 ! Compute the bottom shape of this element
320 empty-tensor like dup shape>>
321 ! Compare; if they are different, throw an error
322 first-shape = [ shape>> first-shape swap shape-mismatch-error ] unless
325 ! Also converts all elements of the sequence to tensors
326 :: check-hstack-shape ( seq -- seq )
327 ! Compute the top shape of the first element in the sequence
328 seq first { } >tensor dup :> empty-tensor
329 like shape>> dup :> first-shape but-last :> but-last-shape
331 ! Compute the top shape of this element
332 empty-tensor like dup shape>> but-last
333 ! Compare; if they are different, throw an error
334 but-last-shape = [ shape>> first-shape swap shape-mismatch-error ] unless
! Compute the shape of the tensor produced by hstacking seq:
! leading dimensions are shared, and the last dimension is the sum
! of every tensor's last dimension
: final-hstack-shape ( seq -- shape )
    ! Shared leading dimensions come from the first tensor
    dup first shape>> but-last swap
    ! Compute the last part of the shape
    [ shape>> last ] map sum 1array append ;
! Returns a guide for hstacking where the index corresponds to the position
344 ! in the last dimension of the resulting tensor, and the elements are
345 ! { which tensor, len of tensor, index }
346 :: hstack-guide ( seq -- guide )
347 ! Compute the list of last shape parts
348 seq [ shape>> last ] map :> last-dims
349 ! Curr tensor and index in tensor
351 last-dims sum [0..b) [
352 drop :> old-t-ind :> last-dims-i
353 last-dims-i last-dims nth
355 ! If we need to move onto the next tensor
356 [ last-dims-i 1 + 0 ]
357 ! Otherwise, stay with the current tensor
358 [ drop last-dims-i old-t-ind ] if-zero
359 2dup [ dup last-dims nth ] dip 3array
363 ! Given a sequence of tensors, stack them across the last dimension
364 :: hstack-unsafe ( tseq -- tensor )
365 ! Create the final tensor
366 tseq final-hstack-shape (tensor)
367 ! Compute the guide information
368 tseq hstack-guide dup length :> repeat :> guide
371 ! First get the correct tensor
372 i repeat /mod guide nth
374 ! Now find the correct value within that tensor
375 [ [ second ] [ third ] bi -rot * + ] dip nth
378 ! Also converts all elements of the sequence to tensors
379 :: check-vstack-shape ( seq -- seq )
380 ! Compute the shape of the first sequence
381 seq first { } >tensor dup :> empty-tensor
382 like shape>> dup :> first-shape
383 ! Compute the index of the dimension to be stacked across
386 ! Convert this element to a tensor
387 empty-tensor like dup
389 shape>> first-shape [ = ] 2map
391 ! If the shapes differ in anything except the second-to-last dimension
392 ! this sequence cannot be vstacked
393 t [ = ] reduce [ shape>> first-shape swap shape-mismatch-error ] unless
396 ! Compute the shape after the vstack has been completed
397 :: final-vstack-shape ( seq -- shape )
398 ! Compute the new second-to-last dimension
399 seq first dims 2 - :> vdim
400 seq [ shape>> vdim swap nth ] map-sum
401 ! Combine it to create the new shape
402 seq first shape>> clone :> new-shape
403 vdim new-shape set-nth
406 ! Combine the second-to-last and last dimensions of each tensor for stacking
407 :: reshape-for-vstack ( seq -- seq )
408 seq first dims 2 - :> vdim
410 dup shape>> vdim cut product 1array append >>shape
416 ! Concatenation operations
417 ! Concatenate across the last dimension
418 : t-concat ( seq -- tensor )
420 ! Compute the final shape
422 ! Compute the first dimension
423 [ [ shape>> first ] map-sum 1array ]
424 ! Compute the other dimensions
425 [ first shape>> rest ] bi append
427 ! Concatenate all of the float-arrays
428 [ [ vec>> ] map concat ] bi <tensor> ;
430 : stack ( seq -- tensor )
432 ! Compute the new shape
433 [ [ length 1array ] [ first shape>> ] bi append ]
434 ! Concatenate all of the tensors
435 [ [ vec>> ] map concat ] bi <tensor> ;
! Stack a sequence of tensors across their last dimension
: hstack ( seq -- tensor )
    ! Check shape and convert everything to tensors
    check-hstack-shape hstack-unsafe ;
441 : vstack ( seq -- tensor )
442 ! Check shape and convert everything to tensors
444 ! Find the final shape
445 [ final-vstack-shape ]
446 ! Reshape each of the tensors and stack
447 [ reshape-for-vstack hstack-unsafe ] bi
448 ! Finally reshape and return
453 ! Check that the tensor has an acceptable shape for matrix multiplication
! Check that the two tensors have shapes compatible with (batched)
! matrix multiplication: trailing dimensions ( ... m n ) x ( ... n p ),
! with identical leading dimensions; throws shape-mismatch-error
! otherwise
:: check-matmul-shape ( tensor1 tensor2 -- )
    tensor1 shape>> :> s1
    tensor2 shape>> :> s2
    ! Columns of the first must equal rows of the second
    s1 last s2 [ length 2 - ] keep nth =
    ! All leading dimensions must match exactly
    s1 2 head* s2 2 head* = and
    ! If either is false, raise an error
    [ s1 s2 shape-mismatch-error ] unless ;
463 ! Slice out a row from the array
464 : row ( arr n i p -- slice )
465 ! Compute the starting index
467 ! Compute the ending index
472 ! much quicker transpose for 2d tensors
473 TYPED:: 2d-transpose ( tensor: tensor -- tensor': tensor )
474 tensor shape>> :> old-shape
476 old-shape first2 :> ( s1 s2 )
477 ! loop through new tensor
478 old-shape reverse dup product <iota> [
479 ! find y*b val in original tensor
481 ! find x val in original tensor
482 [ s2 /mod ] dip + nip
483 ! get that index in original tensor
485 ] float-array{ } map-as <tensor> ;
! Perform matrix multiplication, multiplying an
488 ! mxn matrix with a nxp matrix
489 TYPED:: 2d-matmul ( vec1: float-array vec2: float-array res: float-array
490 m: fixnum n: fixnum p: fixnum -- )
491 ! For each element in the range, we want to compute the dot product of the
492 ! corresponding row and column
493 ! Transpose vec2 so that we are doing row * row (as opposed to row * col)
494 { n p } vec2 <tensor> 2d-transpose vec>> :> vec2
499 vec1 in n make-subseq
502 vec2 j n * n make-subseq
504 ip j + res set-nth-unsafe
! Perform matrix multiplication, multiplying an
510 ! mxn matrix with a nxp matrix
511 TYPED:: 2d-matmul-mixed ( vec1: float-array vec2: float-array res: float-array
512 m: fixnum n: fixnum p: fixnum start: fixnum -- )
513 ! For each element in the range, we want to compute the dot product of the
514 ! corresponding row and column
515 ! Transpose vec2 so that we are doing row * row (as opposed to row * col)
516 { n p } vec2 <tensor> 2d-transpose vec>> :> vec2
518 ! Compute the location in the float-array each 2D matrix will start at
519 start m n * * :> start1
520 start n p * * :> start2
524 4 4 in start1 + 4 mod - swap mod :> in4m
526 vec1 in n make-subseq :> sub1
527 sub1 in4m <simd-slice> :> slice1
530 4 4 jn 4 mod - swap mod :> jn4m
531 vec2 jn n make-subseq
533 jn4m <simd-slice> slice1 swap
534 2dup [ first-slice>> ] bi@ 0.0 [ * + ] 2reduce
535 [ 2dup [ simd-slice>> ] bi@ ] dip [ vdot + ] 2reduce
536 [ [ end-slice>> ] bi@ ] dip [ * + ] 2reduce
541 ip j + res set-nth-unsafe
! Perform matrix multiplication, multiplying an
546 ! mxn matrix with a nxp matrix
547 ! Should only be called when n is a multiple of 4
548 TYPED:: 2d-matmul-simd ( vec1: float-array vec2: float-array
550 m: fixnum n: fixnum p: fixnum -- )
551 ! For each element in the range, we want to compute the dot product of the
552 ! corresponding row and column
553 ! Transpose vec2 so that we are doing row * row (as opposed to row * col)
554 { n p } vec2 <tensor> 2d-transpose vec>> :> vec2
559 vec1 in n make-subseq float-4 cast-array
562 vec2 j n * n make-subseq float-4 cast-array
563 0.0 [ vdot + ] 2reduce
564 ip j + res set-nth-unsafe
! Perform matrix multiplication, multiplying an
573 ! ...xmxn matrix with a ...xnxp matrix
574 TYPED:: matmul ( tensor1: tensor tensor2: tensor -- tensor3: tensor )
575 ! First check the shape
576 tensor1 tensor2 check-matmul-shape
578 ! Now save all of the sizes
579 tensor1 shape>> unclip-last-slice :> n
580 unclip-last-slice :> m :> top-shape
581 tensor2 shape>> last :> p
582 top-shape product :> top-prod
584 ! Create the shape of the resulting tensor
585 top-shape { m p } append
587 ! Now create the new float array to store the underlying result
588 dup product (float-array) :> vec3
590 ! Now update the tensor3 to contain the multiplied matricies
594 ! Compute vec1 using direct C arrays
595 tensor1 vec>> m n * i * m n * make-subseq
597 ! Compute vec2 and start2
598 tensor2 vec>> n p * i * n p * make-subseq
601 vec3 m p * i * m p * make-subseq
602 ! Push m, n, and p and multiply the arrays
604 { { [ n 4 mod 0 = ] [ 2d-matmul-simd ] }
605 { [ n 4 < ] [ 2d-matmul ] }
606 [ i 2d-matmul-mixed ]
612 ! Transpose an n-dimensional tensor by flipping the axes
613 TYPED:: transpose ( tensor: tensor -- tensor': tensor )
614 tensor shape>> length 2 =
615 [ tensor 2d-transpose ]
616 [ tensor shape>> :> old-shape
618 old-shape reverse :> new-shape
619 old-shape ind-mults :> mults
620 ! loop through new tensor
621 new-shape dup product <iota> [
622 ! find index in original tensor
623 old-shape mults [ [ /mod ] dip * ] 2map-sum nip
624 ! get that index in original tensor
626 ] float-array{ } map-as <tensor>