1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package math
6
7import "math/bits"
8
// zero converts the predicate x == 0 into an integer: it returns 1 when
// x is zero and 0 for every other value.
//
// An equivalent branchless formulation is:
//
//	return ((x>>1 | x&1) - 1) >> 63
func zero(x uint64) uint64 {
	if x != 0 {
		return 0
	}
	return 1
}
17
// nonzero converts the predicate x != 0 into an integer: it returns 1 when
// any bit of x is set and 0 when x is zero. It is used to fold discarded
// low-order bits into a single sticky bit.
//
// An equivalent branchless formulation is:
//
//	return 1 - ((x>>1|x&1)-1)>>63
func nonzero(x uint64) uint64 {
	if x == 0 {
		return 0
	}
	return 1
}
26
// shl shifts the 128-bit value u1:u2 (u1 is the high word) left by n bits
// and returns the shifted value as r1:r2.
//
// No branch on n is needed. Go defines a shift by 64 or more to yield 0,
// and for n < 64 the unsigned expression n-64 wraps to a huge count. So
// each term below contributes only in the range where it is meaningful:
// u2>>(64-n) for 0 < n <= 64, and u2<<(n-64) for n >= 64.
func shl(u1, u2 uint64, n uint) (r1, r2 uint64) {
	hi := u1 << n
	hi |= u2 >> (64 - n)
	hi |= u2 << (n - 64)
	return hi, u2 << n
}
32
// shr shifts the 128-bit value u1:u2 (u1 is the high word) right by n bits,
// discarding the bits shifted out, and returns the result as r1:r2.
//
// As in shl, Go's shift semantics make a branch unnecessary. A shift count
// of 64 or more yields 0, and n-64 wraps to a huge count when n < 64.
// u1<<(64-n) therefore only contributes for 0 < n <= 64, and u1>>(n-64)
// only for n >= 64.
func shr(u1, u2 uint64, n uint) (r1, r2 uint64) {
	lo := u2 >> n
	lo |= u1 << (64 - n)
	lo |= u1 >> (n - 64)
	return u1 >> n, lo
}
38
39// shrcompress compresses the bottom n+1 bits of the two-word
40// value into a single bit. the result is equal to the value
41// shifted to the right by n, except the result's 0th bit is
42// set to the bitwise OR of the bottom n+1 bits.
43func shrcompress(u1, u2 uint64, n uint) (r1, r2 uint64) {
44	// TODO: Performance here is really sensitive to the
45	// order/placement of these branches. n == 0 is common
46	// enough to be in the fast path. Perhaps more measurement
47	// needs to be done to find the optimal order/placement?
48	switch {
49	case n == 0:
50		return u1, u2
51	case n == 64:
52		return 0, u1 | nonzero(u2)
53	case n >= 128:
54		return 0, nonzero(u1 | u2)
55	case n < 64:
56		r1, r2 = shr(u1, u2, n)
57		r2 |= nonzero(u2 & (1<<n - 1))
58	case n < 128:
59		r1, r2 = shr(u1, u2, n)
60		r2 |= nonzero(u1&(1<<(n-64)-1) | u2)
61	}
62	return
63}
64
// lz reports the number of leading zero bits in the 128-bit value u1:u2
// (u1 is the high word). The result is 128 when both words are zero.
func lz(u1, u2 uint64) (l int32) {
	if u1 != 0 {
		return int32(bits.LeadingZeros64(u1))
	}
	// The whole high word is zero, so count its 64 bits plus the low word's.
	return 64 + int32(bits.LeadingZeros64(u2))
}
72
73// split splits b into sign, biased exponent, and mantissa.
74// It adds the implicit 1 bit to the mantissa for normal values,
75// and normalizes subnormal values.
76func split(b uint64) (sign uint32, exp int32, mantissa uint64) {
77	sign = uint32(b >> 63)
78	exp = int32(b>>52) & mask
79	mantissa = b & fracMask
80
81	if exp == 0 {
82		// Normalize value if subnormal.
83		shift := uint(bits.LeadingZeros64(mantissa) - 11)
84		mantissa <<= shift
85		exp = 1 - int32(shift)
86	} else {
87		// Add implicit 1 bit
88		mantissa |= 1 << 52
89	}
90	return
91}
92
// FMA returns x * y + z, computed with only one rounding.
// (That is, FMA returns the fused multiply-add of x, y, and z.)
//
// Operands that are zero, infinite, or NaN are handled up front. For the
// remaining finite, nonzero operands:
//   - the exact 128-bit product of the significands is formed;
//   - z is aligned to it, with any bits shifted out folded into a sticky bit;
//   - the two are added or subtracted;
//   - the result is rounded once, to nearest with ties to even.
func FMA(x, y, z float64) float64 {
	bx, by, bz := Float64bits(x), Float64bits(y), Float64bits(z)

	// Inf or NaN or zero involved. At most one rounding will occur.
	if x == 0.0 || y == 0.0 || z == 0.0 || bx&uvinf == uvinf || by&uvinf == uvinf {
		return x*y + z
	}
	// Handle non-finite z separately. Evaluating x*y+z where
	// x and y are finite, but z is infinite, should always result in z.
	// (A NaN z also passes this test and is returned unchanged.)
	if bz&uvinf == uvinf {
		return z
	}

	// Inputs are (sub)normal.
	// Split x, y, z into sign, exponent, mantissa.
	// After split, each mantissa has its leading 1 at bit 52.
	xs, xe, xm := split(bx)
	ys, ye, ym := split(by)
	zs, ze, zm := split(bz)

	// Compute product p = x*y as sign, exponent, two-word mantissa.
	// Start with exponent. "is normal" bit isn't subtracted yet.
	pe := xe + ye - bias + 1

	// pm1:pm2 is the double-word mantissa for the product p.
	// Shift left to leave top bit in product. Effectively
	// shifts the 106-bit product to the left by 21.
	// The operands' leading bits sit at 62 and 63, so the product's leading
	// bit lands at bit 125 or 126 of the 128-bit result, i.e. bit 61 or 62 of pm1.
	pm1, pm2 := bits.Mul64(xm<<10, ym<<11)
	// z uses the same two-word layout: shifting by 10 moves its leading
	// bit from 52 to 62 of the high word. Its low word starts out empty.
	zm1, zm2 := zm<<10, uint64(0)
	ps := xs ^ ys // product sign

	// normalize to 62nd bit
	// If the product's leading bit is at 61, shift it up by one and lower
	// the exponent to compensate, so p and z share the same layout.
	is62zero := uint((^pm1 >> 62) & 1)
	pm1, pm2 = shl(pm1, pm2, is62zero)
	pe -= int32(is62zero)

	// Swap addition operands so |p| >= |z|
	// Comparing exponent, then high word, suffices: z's low word is still 0,
	// so equal high words already imply |p| >= |z|.
	if pe < ze || pe == ze && pm1 < zm1 {
		ps, pe, pm1, pm2, zs, ze, zm1, zm2 = zs, ze, zm1, zm2, ps, pe, pm1, pm2
	}

	// Special case: if p == -z the result is always +0 since neither operand is zero.
	if ps != zs && pe == ze && pm1 == zm1 && pm2 == zm2 {
		return 0
	}

	// Align significands
	// Shift the smaller operand right by the exponent difference (pe >= ze
	// after the swap). Bits shifted out are ORed into bit 0 so rounding
	// still sees them.
	zm1, zm2 = shrcompress(zm1, zm2, uint(pe-ze))

	// Compute resulting significands, normalizing if necessary.
	// In both branches m ends up as a single word with its leading bit
	// at bit 62 and everything below the kept bits folded into bit 0.
	var m, c uint64
	if ps == zs {
		// Adding (pm1:pm2) + (zm1:zm2)
		pm2, c = bits.Add64(pm2, zm2, 0)
		pm1, _ = bits.Add64(pm1, zm1, c)
		// The sum's leading bit is at 63 if the addition carried, else at 62.
		// Lower pe by one when there was no carry, then shift right by 64
		// (no carry) or 65 (carry) so m's leading bit lands at 62.
		pe -= int32(^pm1 >> 63)
		pm1, m = shrcompress(pm1, pm2, uint(64+pm1>>63))
	} else {
		// Subtracting (pm1:pm2) - (zm1:zm2)
		// TODO: should we special-case cancellation?
		pm2, c = bits.Sub64(pm2, zm2, 0)
		pm1, _ = bits.Sub64(pm1, zm1, c)
		// The difference is nonzero (the equal case returned above) and
		// bit 63 is clear (|p| >= |z| and p's leading bit was at 62), so
		// nz >= 1. Shifting left by nz-1 puts the leading bit at bit 62 of
		// m; whatever remains in the low word becomes the sticky bit.
		nz := lz(pm1, pm2)
		pe -= nz
		m, pm2 = shl(pm1, pm2, uint(nz-1))
		m |= nonzero(pm2)
	}

	// Round and break ties to even
	// m's leading bit is at 62. Bits 10..62 are the 53 significand bits to
	// keep; bits 0..9 decide rounding, and 1<<9 is exactly half an ulp.
	// Overflow occurs if pe is already past the limit, or if it is at the
	// limit and adding the half ulp carries into bit 63.
	if pe > 1022+bias || pe == 1022+bias && (m+1<<9)>>63 == 1 {
		// rounded value overflows exponent range
		return Float64frombits(uint64(ps)<<63 | uvinf)
	}
	if pe < 0 {
		// The result is below the normal range. Shift m right by -pe so it
		// is expressed with exponent field 0, ORing the shifted-out bits
		// into bit 0. For n >= 64, Go shifts give m>>n == 0 and
		// 1<<n-1 == all ones, so m collapses to nonzero(m).
		n := uint(-pe)
		m = m>>n | nonzero(m&(1<<n-1))
		pe = 0
	}
	// Add half an ulp and drop the 10 rounding bits. If those bits were
	// exactly 1<<9 (a tie), zero(...) is 1 and the mask clears bit 0, so
	// the result is even.
	m = ((m + 1<<9) >> 10) & ^zero((m&(1<<10-1))^1<<9)
	// If rounding produced a zero significand, force the exponent field to 0.
	pe &= -int32(nonzero(m))
	// The fields are combined with + rather than |, so a set bit 52 in m
	// increments the exponent field by one. That bit is either the leading
	// significand bit or a carry out of rounding.
	return Float64frombits(uint64(ps)<<63 + uint64(pe)<<52 + m)
}
176