1 ! Copyright (C) 2009 Slava Pestov.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: words kernel make sequences effects sets kernel.private
4 accessors combinators math math.intervals math.vectors
5 math.vectors.conversion.backend
6 namespaces assocs fry splitting classes.algebra generalizations
7 locals compiler.tree.propagation.info ;
8 IN: math.vectors.specialization
! Vocabulary of the schema entries in vector-words. `->` separates the
! input slots of a schema from its output slots (see inputs / outputs);
! the remaining symbols label the kind of value held by each slot.
10 SYMBOLS: -> +vector+ +any-vector+ +scalar+ +boolean+ +nonnegative+ +literal+ ;
! Maps a class to the SIMD parent class it falls under: simd-128 is tested
! first, then simd-256. When neither test succeeds the fallback clause
! throws the string "Not a vector class".
! NOTE(review): the conditional that wraps these clauses is not visible
! here -- confirm it is a plain `cond`.
12 : parent-vector-class ( type -- type' )
14 { [ dup simd-128 class<= ] [ drop simd-128 ] }
15 { [ dup simd-256 class<= ] [ drop simd-256 ] }
16 [ "Not a vector class" throw ]
! Turns schema slots into the classes used as a specialization signature.
! Each clause below runs with ( array-type elt-type ) on the stack:
!   +any-vector+  -> parent-vector-class of array-type
!   +boolean+     -> the class boolean
!   +nonnegative+ -> elt-type
!   +literal+     -> f (specialization-matches? treats an f entry as
!                    "the input must be a literal")
! NOTE(review): the dispatch wrapper and any clauses for other slot kinds
! are not visible here -- confirm against the full definition.
19 : signature-for-schema ( array-type elt-type schema -- signature )
23 { +any-vector+ [ drop parent-vector-class ] }
25 { +boolean+ [ 2drop boolean ] }
26 { +nonnegative+ [ nip ] }
27 { +literal+ [ 2drop f ] }
! Builds a specialized copy of a vector word.
! - The first quotation creates a fresh word (vocabulary f) named
!   "<name>=><n1>,<n2>,...", i.e. one name followed by the comma-joined
!   names of a sequence, glued together with "=>".
! - The second quotation assembles the new definition with `make`: one
!   value, then `declare`, then the contents of the original def>>.
! - The last line hands the result to define-declared.
! NOTE(review): the lines that set up the stack for these quotations are
! not visible here; presumably the declared value is the input signature
! -- confirm.
31 : (specialize-vector-word) ( word array-type elt-type schema -- word' )
33 [ [ name>> ] [ [ name>> ] map "," join ] bi* "=>" glue f <word> ]
34 [ [ , \ declare , def>> % ] [ ] make ]
37 [ define-declared ] [ 2drop ] 3bi ;
! Produces value-info objects describing the outputs of a schema.
! Each clause below runs with ( array-type elt-type ) on the stack:
!   +vector+      -> class info for array-type
!   +any-vector+  -> class info for parent-vector-class of array-type
!   +scalar+      -> class info for elt-type
!   +boolean+     -> class info for boolean
! The last two lines replace a complex class with float and build a
! class/interval info over [0,inf].
! NOTE(review): presumably those two lines form the +nonnegative+ clause;
! its header and the surrounding dispatch are not visible -- confirm.
39 : output-infos ( array-type elt-type schema -- value-infos )
42 { +vector+ [ drop <class-info> ] }
43 { +any-vector+ [ drop parent-vector-class <class-info> ] }
44 { +scalar+ [ nip <class-info> ] }
45 { +boolean+ [ 2drop boolean <class-info> ] }
50 dup complex class<= [ drop float ] when
51 [0,inf] <class/interval-info>
! Stores a quotation under the word's "outputs" property. The visible
! part builds `[ n ndrop ]`, where n is the number of inputs in the word's
! stack effect, and appends a second sequence after it.
! NOTE(review): presumably the appended part pushes the output-infos for
! the schema; the lines producing it are not visible -- confirm.
57 : record-output-signature ( word array-type elt-type schema -- word )
61 [ [ stack-effect in>> length '[ _ ndrop ] ] dip append ] 2tri
62 "outputs" set-word-prop ;
! Table mapping each specializable vector word to its schema. A schema
! lists the kind of every input, then `->`, then the kind of every output,
! using the symbols declared by SYMBOLS: above. It is queried as an assoc
! elsewhere in this file (key?, ?at, keys).
! NOTE(review): `vneg` is listed twice below with the same schema; the
! second entry looks redundant -- confirm and remove one.
! NOTE(review): the literal's opening and closing delimiters are not
! visible here.
64 CONSTANT: vector-words
66 { [v-] { +vector+ +vector+ -> +vector+ } }
67 { distance { +vector+ +vector+ -> +nonnegative+ } }
68 { n*v { +scalar+ +vector+ -> +vector+ } }
69 { n+v { +scalar+ +vector+ -> +vector+ } }
70 { n-v { +scalar+ +vector+ -> +vector+ } }
71 { n/v { +scalar+ +vector+ -> +vector+ } }
72 { norm { +vector+ -> +nonnegative+ } }
73 { norm-sq { +vector+ -> +nonnegative+ } }
74 { normalize { +vector+ -> +vector+ } }
75 { v* { +vector+ +vector+ -> +vector+ } }
76 { vs* { +vector+ +vector+ -> +vector+ } }
77 { v*n { +vector+ +scalar+ -> +vector+ } }
78 { v+ { +vector+ +vector+ -> +vector+ } }
79 { vs+ { +vector+ +vector+ -> +vector+ } }
80 { v+- { +vector+ +vector+ -> +vector+ } }
81 { v+n { +vector+ +scalar+ -> +vector+ } }
82 { v- { +vector+ +vector+ -> +vector+ } }
83 { vneg { +vector+ -> +vector+ } }
84 { vs- { +vector+ +vector+ -> +vector+ } }
85 { v-n { +vector+ +scalar+ -> +vector+ } }
86 { v. { +vector+ +vector+ -> +scalar+ } }
87 { v/ { +vector+ +vector+ -> +vector+ } }
88 { v/n { +vector+ +scalar+ -> +vector+ } }
89 { vceiling { +vector+ -> +vector+ } }
90 { vfloor { +vector+ -> +vector+ } }
91 { vmax { +vector+ +vector+ -> +vector+ } }
92 { vmin { +vector+ +vector+ -> +vector+ } }
93 { vneg { +vector+ -> +vector+ } }
94 { vtruncate { +vector+ -> +vector+ } }
95 { sum { +vector+ -> +scalar+ } }
96 { vabs { +vector+ -> +vector+ } }
97 { vsqrt { +vector+ -> +vector+ } }
98 { vbitand { +vector+ +vector+ -> +vector+ } }
99 { vbitandn { +vector+ +vector+ -> +vector+ } }
100 { vbitor { +vector+ +vector+ -> +vector+ } }
101 { vbitxor { +vector+ +vector+ -> +vector+ } }
102 { vbitnot { +vector+ -> +vector+ } }
103 { vand { +vector+ +vector+ -> +vector+ } }
104 { vandn { +vector+ +vector+ -> +vector+ } }
105 { vor { +vector+ +vector+ -> +vector+ } }
106 { vxor { +vector+ +vector+ -> +vector+ } }
107 { vnot { +vector+ -> +vector+ } }
108 { vlshift { +vector+ +scalar+ -> +vector+ } }
109 { vrshift { +vector+ +scalar+ -> +vector+ } }
110 { hlshift { +vector+ +literal+ -> +vector+ } }
111 { hrshift { +vector+ +literal+ -> +vector+ } }
112 { vshuffle-elements { +vector+ +literal+ -> +vector+ } }
113 { vshuffle-bytes { +vector+ +any-vector+ -> +vector+ } }
114 { vbroadcast { +vector+ +literal+ -> +vector+ } }
115 { (vmerge-head) { +vector+ +vector+ -> +vector+ } }
116 { (vmerge-tail) { +vector+ +vector+ -> +vector+ } }
117 { (v>float) { +vector+ +literal+ -> +vector+ } }
118 { (v>integer) { +vector+ +literal+ -> +vector+ } }
119 { (vpack-signed) { +vector+ +vector+ +literal+ -> +vector+ } }
120 { (vpack-unsigned) { +vector+ +vector+ +literal+ -> +vector+ } }
121 { (vunpack-head) { +vector+ +literal+ -> +vector+ } }
122 { (vunpack-tail) { +vector+ +literal+ -> +vector+ } }
123 { v<= { +vector+ +vector+ -> +vector+ } }
124 { v< { +vector+ +vector+ -> +vector+ } }
125 { v= { +vector+ +vector+ -> +vector+ } }
126 { v> { +vector+ +vector+ -> +vector+ } }
127 { v>= { +vector+ +vector+ -> +vector+ } }
128 { vunordered? { +vector+ +vector+ -> +vector+ } }
129 { vany? { +vector+ -> +boolean+ } }
130 { vall? { +vector+ -> +boolean+ } }
131 { vnone? { +vector+ -> +boolean+ } }
! A vector-word is any word that appears as a key in the vector-words
! table.
134 PREDICATE: vector-word < word vector-words key? ;
! Returns the assoc stored in the word's "specializations" property.
! When the property is unset, a fresh empty vector is installed under
! that property first and then returned, so repeated calls on the same
! word yield the same mutable object.
: specializations ( word -- assoc )
    dup "specializations" word-prop [
        ! Property already present: discard the word, keep the assoc.
        nip
    ] [
        ! Property missing: create it, store it on the word, return it.
        V{ } clone [ "specializations" set-word-prop ] keep
    ] if* ;
! The subwords of a vector word are those of its recorded specializations
! whose value is a word; entries holding anything else are left out.
M: vector-word subwords
    specializations [ nip word? ] assoc-filter values ;
! Records new-word as the specialization of word for the given input
! signature, in the assoc returned by `specializations`.
:: add-specialization ( new-word signature word -- )
    new-word signature word specializations set-at ;
! Thrown by word-schema when a word has no entry in vector-words; the
! offending word is carried in the `word` slot.
145 ERROR: bad-vector-word word ;
! Looks up the schema registered for a word in vector-words.
! Throws bad-vector-word (carrying the word) when there is no entry.
: word-schema ( word -- schema )
    dup vector-words at*
    [ nip ] [ drop bad-vector-word ] if ;
! Input slots of a schema: the first piece obtained by splitting the
! schema at the `->` symbol.
150 : inputs ( schema -- seq ) { -> } split first ;
! Output slots of a schema: the second piece obtained by splitting the
! schema at the `->` symbol.
152 : outputs ( schema -- seq ) { -> } split second ;
! Builds the non-SIMD specialization of a vector word: the input half of
! a schema is passed to (specialize-vector-word) and the output half to
! record-output-signature, both applied through 3bi.
! NOTE(review): the line that puts the schema on the stack is not visible
! here; presumably it comes from word-schema -- confirm.
154 : loop-vector-op ( word array-type elt-type -- word' )
156 [ inputs (specialize-vector-word) ]
157 [ outputs record-output-signature ] 3bi ;
! Picks the specialization of word for the given array and element types.
! If word is a key of the simd assoc, the value stored there is returned
! as-is; otherwise a specialization is built with loop-vector-op.
:: specialize-vector-word ( word array-type elt-type simd -- word/quot' )
    word simd ?at
    ! On a miss ?at leaves the word itself on the stack for loop-vector-op.
    [ array-type elt-type loop-vector-op ] unless ;
! Signature for the input half of word's schema, computed by
! signature-for-schema from the given array and element types.
: input-signature ( word array-type elt-type -- signature )
    ! rot brings the word to the top: ( array-type elt-type word )
    rot word-schema inputs signature-for-schema ;
! Selects which vector words are specialized for a given element type.
! The visible clauses start from every key of vector-words and remove the
! words that make no sense for float, integer and complex element types.
! The trailing list names words that are never specialized here.
! NOTE(review): the conditional wrapper, any default clause, and the code
! that subtracts the trailing list are not visible -- confirm.
165 : vector-words-for-type ( elt-type -- words )
167 ! Can't do shifts on floats
168 { [ dup float class<= ] [ vector-words keys { vlshift vrshift } diff ] }
169 ! Can't divide integers
170 { [ dup integer class<= ] [ vector-words keys { vsqrt n/v v/n v/ normalize } diff ] }
171 ! Can't compute square root of complex numbers (vsqrt uses fsqrt not sqrt)
172 { [ dup complex class<= ] [ vector-words keys { vsqrt } diff ] }
175 ! Don't specialize horizontal shifts, shuffles, and conversions at all; they're only for SIMD
177 hlshift hrshift vshuffle-elements vshuffle-bytes vbroadcast
178 (v>integer) (v>float)
179 (vpack-signed) (vpack-unsigned)
180 (vunpack-head) (vunpack-tail)
! Registers specializations for one array-type / elt-type pair.
! The words processed are the union of vector-words-for-type for elt-type
! and the keys of the simd assoc. For each word the visible quotations
! compute its specialization (specialize-vector-word) and its input
! signature (input-signature); the results go to add-specialization.
! NOTE(review): the third quotation given to `tri` and the end of the loop
! are not visible here -- confirm.
184 :: specialize-vector-words ( array-type elt-type simd -- )
185 elt-type vector-words-for-type simd keys union [
186 [ array-type elt-type simd specialize-vector-word ]
187 [ array-type elt-type input-signature ]
189 tri add-specialization
! Tests a sequence of value infos against a signature, pairwise.
! A non-f signature entry requires the info's class to be a subclass of
! that entry; an f entry requires the info to be a literal.
: specialization-matches? ( value-infos signature -- ? )
    [
        dup
        [ [ class>> ] dip class<= ]
        [ drop literal?>> ]
        if
    ] 2all? ;
! Searches a sequence of pairs for the first one whose first element
! matches the given value infos (specialization-matches?), and returns
! that pair's second element. When `find` comes back with no index the
! result is f.
! NOTE(review): the line that turns the word into the sequence being
! searched is not visible; presumably it calls `specializations`. The
! stack effect names the input `classes`, but the caller in this file
! passes value infos -- confirm.
195 : find-specialization ( classes word -- word/f )
197 [ first specialization-matches? ] with find
198 swap [ second ] when ;
! Custom-inlining hook body: collects the value info of every input of
! the #call node and asks find-specialization for a specialization of the
! called word that matches them.
: vector-word-custom-inlining ( #call -- word/f )
    [ in-d>> [ value-info ] map ] keep
    word>> find-specialization ;
! Stores the quotation [ vector-word-custom-inlining ] under the
! "custom-inlining" word property.
! NOTE(review): the code that supplies the target word(s) is not visible
! here; presumably this runs once per key of vector-words -- confirm.
205 [ vector-word-custom-inlining ]
206 "custom-inlining" set-word-prop