1 ! Copyright (C) 2009 Slava Pestov.
2 ! See http://factorcode.org/license.txt for BSD license.
3 USING: words kernel make sequences effects sets kernel.private
4 accessors combinators math math.intervals math.vectors
5 math.vectors.conversion.backend
6 namespaces assocs fry splitting classes.algebra generalizations
7 locals compiler.tree.propagation.info ;
8 IN: math.vectors.specialization
! Marker symbols used in the schemas of the vector-words table.
! -> separates a schema's input markers from its output markers (the words
! inputs and outputs split on it); the +...+ symbols tag one stack position each.
10 SYMBOLS: -> +vector+ +scalar+ +boolean+ +nonnegative+ +literal+ ;
! Turns a schema (a sequence of marker symbols) into a signature, given an
! array type and an element type.
! NOTE(review): this listing skips lines inside the definition (numbering
! jumps 12 -> 17 and 19 -> 23), so only three of the per-marker cases are
! visible and the surrounding combinators are not shown.
12 : signature-for-schema ( array-type elt-type schema -- signature )
! Each visible case consumes two stack values -- presumably array-type and
! elt-type; confirm against the elided lines.
! +boolean+ discards both and yields the class boolean.
17 { +boolean+ [ 2drop boolean ] }
! +nonnegative+ keeps only the topmost of the two values.
18 { +nonnegative+ [ nip ] }
! +literal+ discards both and yields f; specialization-matches? treats an f
! signature entry as "this input must be a literal".
19 { +literal+ [ 2drop f ] }
! Defines a new word that wraps an existing word's definition behind a
! declare, and returns the new word.
! NOTE(review): listing lines 24 and 27-28 are not visible, so the values fed
! to the quotations below are inferred from their shape -- confirm.
23 : (specialize-vector-word) ( word array-type elt-type schema -- word' )
! Creates the new word. Its name is the first value's name, then "=>", then
! the comma-joined names of the elements of the second value; the vocabulary
! argument passed to <word> is f.
25 [ [ name>> ] [ [ name>> ] map "," join ] bi* "=>" glue f <word> ]
! Builds the new definition: the top value as a literal, then declare, then
! the def>> of the value beneath it spliced in.
26 [ [ , \ declare , def>> % ] [ ] make ]
! define-declared consumes the three values; 2drop then leaves only the first
! of them on the stack.
29 [ define-declared ] [ 2drop ] 3bi ;
! Computes value-info objects for the output markers of a schema.
! NOTE(review): several interior lines are missing from this listing
! (numbering jumps 31 -> 34, 36 -> 41 and 42 -> 48); the case/map scaffolding
! and the end of the definition are not visible.
31 : output-infos ( array-type elt-type schema -- value-infos )
! Each case runs with two values on the stack -- presumably array-type and
! elt-type; confirm. +vector+ builds a class info from the lower one.
34 { +vector+ [ drop <class-info> ] }
! +scalar+ builds a class info from the upper one.
35 { +scalar+ [ nip <class-info> ] }
! +boolean+ ignores both and uses the class boolean.
36 { +boolean+ [ 2drop boolean <class-info> ] }
! Presumably the body of the +nonnegative+ case (its header is elided):
! a class that is a subclass of complex is replaced by float, and the
! resulting info is restricted to the interval [0,inf].
41 dup complex class<= [ drop float ] when
42 [0,inf] <class/interval-info>
! Stores a quotation under the "outputs" word property and returns the word.
! NOTE(review): listing lines 49-51 are not visible; presumably they compute
! the output infos and set up the values consumed below -- confirm.
48 : record-output-signature ( word array-type elt-type schema -- word )
! Builds a quotation that drops as many values as the word has declared
! inputs (via ndrop), then appends the other value to that quotation.
52 [ [ stack-effect in>> length '[ _ ndrop ] ] dip append ] 2tri
53 "outputs" set-word-prop ;
! Maps each supported vector word to its schema: one marker symbol per input,
! then ->, then one marker symbol per output. Other words in this file query
! it with key?, ?at and keys.
! NOTE(review): the listing numbering jumps 55 -> 57 and resumes at 125 after
! 122, so the literal's opening and closing delimiters are not visible here.
! NOTE(review): vneg is listed twice (listing lines 74 and 84) with the same
! schema; the second entry looks redundant.
55 CONSTANT: vector-words
57 { [v-] { +vector+ +vector+ -> +vector+ } }
58 { distance { +vector+ +vector+ -> +nonnegative+ } }
59 { n*v { +scalar+ +vector+ -> +vector+ } }
60 { n+v { +scalar+ +vector+ -> +vector+ } }
61 { n-v { +scalar+ +vector+ -> +vector+ } }
62 { n/v { +scalar+ +vector+ -> +vector+ } }
63 { norm { +vector+ -> +nonnegative+ } }
64 { norm-sq { +vector+ -> +nonnegative+ } }
65 { normalize { +vector+ -> +vector+ } }
66 { v* { +vector+ +vector+ -> +vector+ } }
67 { vs* { +vector+ +vector+ -> +vector+ } }
68 { v*n { +vector+ +scalar+ -> +vector+ } }
69 { v+ { +vector+ +vector+ -> +vector+ } }
70 { vs+ { +vector+ +vector+ -> +vector+ } }
71 { v+- { +vector+ +vector+ -> +vector+ } }
72 { v+n { +vector+ +scalar+ -> +vector+ } }
73 { v- { +vector+ +vector+ -> +vector+ } }
74 { vneg { +vector+ -> +vector+ } }
75 { vs- { +vector+ +vector+ -> +vector+ } }
76 { v-n { +vector+ +scalar+ -> +vector+ } }
77 { v. { +vector+ +vector+ -> +scalar+ } }
78 { v/ { +vector+ +vector+ -> +vector+ } }
79 { v/n { +vector+ +scalar+ -> +vector+ } }
80 { vceiling { +vector+ -> +vector+ } }
81 { vfloor { +vector+ -> +vector+ } }
82 { vmax { +vector+ +vector+ -> +vector+ } }
83 { vmin { +vector+ +vector+ -> +vector+ } }
84 { vneg { +vector+ -> +vector+ } }
85 { vtruncate { +vector+ -> +vector+ } }
86 { sum { +vector+ -> +scalar+ } }
87 { vabs { +vector+ -> +vector+ } }
88 { vsqrt { +vector+ -> +vector+ } }
89 { vbitand { +vector+ +vector+ -> +vector+ } }
90 { vbitandn { +vector+ +vector+ -> +vector+ } }
91 { vbitor { +vector+ +vector+ -> +vector+ } }
92 { vbitxor { +vector+ +vector+ -> +vector+ } }
93 { vbitnot { +vector+ -> +vector+ } }
94 { vand { +vector+ +vector+ -> +vector+ } }
95 { vandn { +vector+ +vector+ -> +vector+ } }
96 { vor { +vector+ +vector+ -> +vector+ } }
97 { vxor { +vector+ +vector+ -> +vector+ } }
98 { vnot { +vector+ -> +vector+ } }
99 { vlshift { +vector+ +scalar+ -> +vector+ } }
100 { vrshift { +vector+ +scalar+ -> +vector+ } }
101 { hlshift { +vector+ +literal+ -> +vector+ } }
102 { hrshift { +vector+ +literal+ -> +vector+ } }
103 { vshuffle-elements { +vector+ +literal+ -> +vector+ } }
104 { vshuffle-bytes { +vector+ +vector+ -> +vector+ } }
105 { vbroadcast { +vector+ +literal+ -> +vector+ } }
106 { (vmerge-head) { +vector+ +vector+ -> +vector+ } }
107 { (vmerge-tail) { +vector+ +vector+ -> +vector+ } }
108 { (v>float) { +vector+ +literal+ -> +vector+ } }
109 { (v>integer) { +vector+ +literal+ -> +vector+ } }
110 { (vpack-signed) { +vector+ +vector+ +literal+ -> +vector+ } }
111 { (vpack-unsigned) { +vector+ +vector+ +literal+ -> +vector+ } }
112 { (vunpack-head) { +vector+ +literal+ -> +vector+ } }
113 { (vunpack-tail) { +vector+ +literal+ -> +vector+ } }
114 { v<= { +vector+ +vector+ -> +vector+ } }
115 { v< { +vector+ +vector+ -> +vector+ } }
116 { v= { +vector+ +vector+ -> +vector+ } }
117 { v> { +vector+ +vector+ -> +vector+ } }
118 { v>= { +vector+ +vector+ -> +vector+ } }
119 { vunordered? { +vector+ +vector+ -> +vector+ } }
120 { vany? { +vector+ -> +boolean+ } }
121 { vall? { +vector+ -> +boolean+ } }
122 { vnone? { +vector+ -> +boolean+ } }
! Class of words that have an entry in the vector-words table.
PREDICATE: vector-word < word
    vector-words key? ;
! Returns the assoc stored on word under the "specializations" property.
! If the property is not set yet, a fresh empty vector is stored there
! first and that vector is returned.
: specializations ( word -- assoc )
    dup "specializations" word-prop
    [ nip ]
    [ V{ } clone [ "specializations" set-word-prop ] keep ]
    if* ;
! Subwords of a vector word: the values in its specializations assoc that
! are words (any non-word values are skipped).
M: vector-word subwords
    specializations values
    [ word? ] filter ;
! Records new-word under signature in the specializations assoc of word.
: add-specialization ( new-word signature word -- )
    specializations
    set-at ;
! Error thrown by word-schema when the given word has no entry in vector-words.
136 ERROR: bad-vector-word word ;
! Looks up the schema registered for word in vector-words.
! Throws bad-vector-word (carrying the word) when there is no entry.
: word-schema ( word -- schema )
    dup vector-words at*
    [ nip ] [ drop bad-vector-word ] if ;
! The part of a schema that precedes the -> marker.
: inputs ( schema -- seq )
    { -> } split first ;
! The part of a schema that follows the -> marker.
: outputs ( schema -- seq )
    { -> } split second ;
! Builds the specialized variant of word for the given array and element
! types, and records its output signature.
! NOTE(review): listing line 146 is not visible; presumably it pushes the
! word's schema so that 3bi below sees array-type, elt-type and the schema
! with the word underneath -- confirm.
145 : loop-vector-op ( word array-type elt-type -- word' )
! First quotation: specialize using the input half of the schema.
147 [ inputs (specialize-vector-word) ]
! Second quotation: record the "outputs" property from the output half.
148 [ outputs record-output-signature ] 3bi ;
! Picks the implementation to use for word: the entry stored for it in the
! simd assoc when there is one, otherwise the result of loop-vector-op.
:: specialize-vector-word ( word array-type elt-type simd -- word/quot' )
    word simd at*
    [ drop word array-type elt-type loop-vector-op ] unless ;
! Signature for the input half of word's schema, for the given array and
! element types.
: input-signature ( word array-type elt-type -- signature )
    rot word-schema inputs
    signature-for-schema ;
! Selects vector words for a given element type, based on whether the type
! is a subclass of float, integer or complex.
! NOTE(review): the listing omits several lines of this definition
! (numbering jumps 156 -> 158, 163 -> 166 -> 168 and ends at 171), so the
! conditional scaffolding, any default branch and the final combination of
! the two word lists are not visible here.
156 : vector-words-for-type ( elt-type -- words )
158 ! Can't do shifts on floats
159 { [ dup float class<= ] [ vector-words keys { vlshift vrshift } diff ] }
160 ! Can't divide integers
! (the integer exclusion list below also contains vsqrt and normalize)
161 { [ dup integer class<= ] [ vector-words keys { vsqrt n/v v/n v/ normalize } diff ] }
162 ! Can't compute square root of complex numbers (vsqrt uses fsqrt not sqrt)
163 { [ dup complex class<= ] [ vector-words keys { vsqrt } diff ] }
166 ! Don't specialize horizontal shifts, shuffles, and conversions at all, they're only for SIMD
168 hlshift hrshift vshuffle-elements vshuffle-bytes vbroadcast
169 (v>integer) (v>float)
170 (vpack-signed) (vpack-unsigned)
171 (vunpack-head) (vunpack-tail)
! Registers a specialization for each applicable vector word.
! NOTE(review): listing lines 179 and 181+ are not visible, so the third
! quotation given to tri and the end of the loop are missing here.
175 :: specialize-vector-words ( array-type elt-type simd -- )
! Candidates: the words chosen for elt-type, plus every key of simd.
176 elt-type vector-words-for-type simd keys union [
! Implementation to register for this word.
177 [ array-type elt-type simd specialize-vector-word ]
! Signature to register it under.
178 [ array-type elt-type input-signature ]
180 tri add-specialization
! Tests value infos pairwise against a signature. A non-f signature entry
! requires the info's class to be a subclass of that entry; an f entry
! requires the info to be a literal.
: specialization-matches? ( value-infos signature -- ? )
    [
        dup
        [ [ class>> ] dip class<= ]
        [ drop literal?>> ]
        if
    ] 2all? ;
! Searches a sequence of { signature implementation } pairs for the first
! one whose signature matches the given infos, yielding its second element,
! or f when none matches.
! NOTE(review): listing line 187 is not visible; presumably it fetches the
! word's specializations -- confirm.
! NOTE(review): the stack effect names the first input "classes", but the
! visible caller, vector-word-custom-inlining, passes value infos.
186 : find-specialization ( classes word -- word/f )
188 [ first specialization-matches? ] with find
! find leaves index and element; the index is f only when nothing matched,
! in which case the element is f as well and is returned unchanged.
189 swap [ second ] when ;
! Custom-inlining hook: gathers the value info of every input of the call
! node and looks for a matching specialization of the called word.
: vector-word-custom-inlining ( #call -- word/f )
    [ in-d>> [ value-info ] map ] keep
    word>> find-specialization ;
! Installs a quotation calling vector-word-custom-inlining as the
! "custom-inlining" word property.
! NOTE(review): listing lines 195 and 198 are not visible; presumably they
! iterate over the words that receive the property -- confirm.
196 [ vector-word-custom-inlining ]
197 "custom-inlining" set-word-prop