You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1107 lines
28 KiB
1107 lines
28 KiB
// Package decimal implements an arbitrary precision fixed-point decimal. |
|
// |
|
// To use as part of a struct: |
|
// |
|
// type Struct struct { |
|
// Number Decimal |
|
// } |
|
// |
|
// The zero-value of a Decimal is 0, as you would expect. |
|
// |
|
// The best way to create a new Decimal is to use decimal.NewFromString, ex: |
|
// |
|
// n, err := decimal.NewFromString("-123.4567") |
|
// n.String() // output: "-123.4567" |
|
// |
|
// NOTE: This can "only" represent numbers with a maximum of 2^31 digits |
|
// after the decimal point. |
|
package decimal |
|
|
|
import ( |
|
"database/sql/driver" |
|
"encoding/binary" |
|
"fmt" |
|
"math" |
|
"math/big" |
|
"strconv" |
|
"strings" |
|
) |
|
|
|
// DivisionPrecision is the number of decimal places in the result when it |
|
// doesn't divide exactly. |
|
// |
|
// Example: |
|
// |
|
// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) |
|
// d1.String() // output: "0.6666666666666667" |
|
// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000) |
|
// d2.String() // output: "0.0000666666666667" |
|
// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3) |
|
// d3.String() // output: "6666.6666666666666667" |
|
// decimal.DivisionPrecision = 3 |
|
// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) |
|
// d4.String() // output: "0.667" |
|
// |
|
var DivisionPrecision = 16 |
|
|
|
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to |
|
// be JSON marshaled as a number, instead of as a string. |
|
// WARNING: this is dangerous for decimals with many digits, since many JSON |
|
// unmarshallers (ex: Javascript's) will unmarshal JSON numbers to IEEE 754 |
|
// double-precision floating point numbers, which means you can potentially |
|
// silently lose precision. |
|
var MarshalJSONWithoutQuotes = false |
|
|
|
// Zero constant, to make computations faster. |
|
var Zero = New(0, 1) |
|
|
|
// fiveDec used in Cash Rounding |
|
var fiveDec = New(5, 0) |
|
|
|
var zeroInt = big.NewInt(0) |
|
var oneInt = big.NewInt(1) |
|
var twoInt = big.NewInt(2) |
|
var fourInt = big.NewInt(4) |
|
var fiveInt = big.NewInt(5) |
|
var tenInt = big.NewInt(10) |
|
var twentyInt = big.NewInt(20) |
|
|
|
// Decimal is a fixed-point decimal number. Methods return new values rather
// than modifying the receiver.
//
// The represented number is value * 10^exp.
type Decimal struct {
	// value is the unscaled coefficient. It is nil in the zero-value Decimal;
	// methods call ensureInitialized before dereferencing it.
	value *big.Int

	// exp is the power-of-ten scale applied to value.
	//
	// NOTE(vadim): this must be an int32, because we cast it to float64 during
	// calculations. If exp is 64 bit, we might lose precision.
	// If we cared about being able to represent every possible decimal, we
	// could make exp a *big.Int but it would hurt performance and numbers
	// like that are unrealistic.
	exp int32
}
|
|
|
// New returns a new fixed-point decimal, value * 10 ^ exp. |
|
func New(value int64, exp int32) Decimal { |
|
return Decimal{ |
|
value: big.NewInt(value), |
|
exp: exp, |
|
} |
|
} |
|
|
|
// NewFromBigInt returns a new Decimal from a big.Int, value * 10 ^ exp |
|
func NewFromBigInt(value *big.Int, exp int32) Decimal { |
|
return Decimal{ |
|
value: big.NewInt(0).Set(value), |
|
exp: exp, |
|
} |
|
} |
|
|
|
// NewFromString returns a new Decimal from a string representation. |
|
// |
|
// Example: |
|
// |
|
// d, err := NewFromString("-123.45") |
|
// d2, err := NewFromString(".0001") |
|
// |
|
func NewFromString(value string) (Decimal, error) { |
|
originalInput := value |
|
var intString string |
|
var exp int64 |
|
|
|
// Check if number is using scientific notation |
|
eIndex := strings.IndexAny(value, "Ee") |
|
if eIndex != -1 { |
|
expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32) |
|
if err != nil { |
|
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { |
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value) |
|
} |
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", value) |
|
} |
|
value = value[:eIndex] |
|
exp = expInt |
|
} |
|
|
|
parts := strings.Split(value, ".") |
|
if len(parts) == 1 { |
|
// There is no decimal point, we can just parse the original string as |
|
// an int |
|
intString = value |
|
} else if len(parts) == 2 { |
|
// strip the insignificant digits for more accurate comparisons. |
|
decimalPart := strings.TrimRight(parts[1], "0") |
|
intString = parts[0] + decimalPart |
|
expInt := -len(decimalPart) |
|
exp += int64(expInt) |
|
} else { |
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) |
|
} |
|
|
|
dValue := new(big.Int) |
|
_, ok := dValue.SetString(intString, 10) |
|
if !ok { |
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal", value) |
|
} |
|
|
|
if exp < math.MinInt32 || exp > math.MaxInt32 { |
|
// NOTE(vadim): I doubt a string could realistically be this long |
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput) |
|
} |
|
|
|
return Decimal{ |
|
value: dValue, |
|
exp: int32(exp), |
|
}, nil |
|
} |
|
|
|
// RequireFromString returns a new Decimal from a string representation |
|
// or panics if NewFromString would have returned an error. |
|
// |
|
// Example: |
|
// |
|
// d := RequireFromString("-123.45") |
|
// d2 := RequireFromString(".0001") |
|
// |
|
func RequireFromString(value string) Decimal { |
|
dec, err := NewFromString(value) |
|
if err != nil { |
|
panic(err) |
|
} |
|
return dec |
|
} |
|
|
|
// NewFromFloat converts a float64 to Decimal. |
|
// |
|
// Example: |
|
// |
|
// NewFromFloat(123.45678901234567).String() // output: "123.4567890123456" |
|
// NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001" |
|
// |
|
// NOTE: some float64 numbers can take up about 300 bytes of memory in decimal representation. |
|
// Consider using NewFromFloatWithExponent if space is more important than precision. |
|
// |
|
// NOTE: this will panic on NaN, +/-inf |
|
func NewFromFloat(value float64) Decimal { |
|
return NewFromFloatWithExponent(value, math.MinInt32) |
|
} |
|
|
|
// NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary |
|
// number of fractional digits. |
|
// |
|
// Example: |
|
// |
|
// NewFromFloatWithExponent(123.456, -2).String() // output: "123.46" |
|
// |
|
func NewFromFloatWithExponent(value float64, exp int32) Decimal { |
|
if math.IsNaN(value) || math.IsInf(value, 0) { |
|
panic(fmt.Sprintf("Cannot create a Decimal from %v", value)) |
|
} |
|
|
|
bits := math.Float64bits(value) |
|
mant := bits & (1<<52 - 1) |
|
exp2 := int32((bits >> 52) & (1<<11 - 1)) |
|
sign := bits >> 63 |
|
|
|
if exp2 == 0 { |
|
// specials |
|
if mant == 0 { |
|
return Decimal{} |
|
} else { |
|
// subnormal |
|
exp2++ |
|
} |
|
} else { |
|
// normal |
|
mant |= 1 << 52 |
|
} |
|
|
|
exp2 -= 1023 + 52 |
|
|
|
// normalizing base-2 values |
|
for mant&1 == 0 { |
|
mant = mant >> 1 |
|
exp2++ |
|
} |
|
|
|
// maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0 |
|
if exp < 0 && exp < exp2 { |
|
if exp2 < 0 { |
|
exp = exp2 |
|
} else { |
|
exp = 0 |
|
} |
|
} |
|
|
|
// representing 10^M * 2^N as 5^M * 2^(M+N) |
|
exp2 -= exp |
|
|
|
temp := big.NewInt(1) |
|
dMant := big.NewInt(int64(mant)) |
|
|
|
// applying 5^M |
|
if exp > 0 { |
|
temp = temp.SetInt64(int64(exp)) |
|
temp = temp.Exp(fiveInt, temp, nil) |
|
} else if exp < 0 { |
|
temp = temp.SetInt64(-int64(exp)) |
|
temp = temp.Exp(fiveInt, temp, nil) |
|
dMant = dMant.Mul(dMant, temp) |
|
temp = temp.SetUint64(1) |
|
} |
|
|
|
// applying 2^(M+N) |
|
if exp2 > 0 { |
|
dMant = dMant.Lsh(dMant, uint(exp2)) |
|
} else if exp2 < 0 { |
|
temp = temp.Lsh(temp, uint(-exp2)) |
|
} |
|
|
|
// rounding and downscaling |
|
if exp > 0 || exp2 < 0 { |
|
halfDown := new(big.Int).Rsh(temp, 1) |
|
dMant = dMant.Add(dMant, halfDown) |
|
dMant = dMant.Quo(dMant, temp) |
|
} |
|
|
|
if sign == 1 { |
|
dMant = dMant.Neg(dMant) |
|
} |
|
|
|
return Decimal{ |
|
value: dMant, |
|
exp: exp, |
|
} |
|
} |
|
|
|
// rescale returns a rescaled version of the decimal. Returned |
|
// decimal may be less precise if the given exponent is bigger |
|
// than the initial exponent of the Decimal. |
|
// NOTE: this will truncate, NOT round |
|
// |
|
// Example: |
|
// |
|
// d := New(12345, -4) |
|
// d2 := d.rescale(-1) |
|
// d3 := d2.rescale(-4) |
|
// println(d1) |
|
// println(d2) |
|
// println(d3) |
|
// |
|
// Output: |
|
// |
|
// 1.2345 |
|
// 1.2 |
|
// 1.2000 |
|
// |
|
func (d Decimal) rescale(exp int32) Decimal { |
|
d.ensureInitialized() |
|
// NOTE(vadim): must convert exps to float64 before - to prevent overflow |
|
diff := math.Abs(float64(exp) - float64(d.exp)) |
|
value := new(big.Int).Set(d.value) |
|
|
|
expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil) |
|
if exp > d.exp { |
|
value = value.Quo(value, expScale) |
|
} else if exp < d.exp { |
|
value = value.Mul(value, expScale) |
|
} |
|
|
|
return Decimal{ |
|
value: value, |
|
exp: exp, |
|
} |
|
} |
|
|
|
// Abs returns the absolute value of the decimal. |
|
func (d Decimal) Abs() Decimal { |
|
d.ensureInitialized() |
|
d2Value := new(big.Int).Abs(d.value) |
|
return Decimal{ |
|
value: d2Value, |
|
exp: d.exp, |
|
} |
|
} |
|
|
|
// Add returns d + d2. |
|
func (d Decimal) Add(d2 Decimal) Decimal { |
|
baseScale := min(d.exp, d2.exp) |
|
rd := d.rescale(baseScale) |
|
rd2 := d2.rescale(baseScale) |
|
|
|
d3Value := new(big.Int).Add(rd.value, rd2.value) |
|
return Decimal{ |
|
value: d3Value, |
|
exp: baseScale, |
|
} |
|
} |
|
|
|
// Sub returns d - d2. |
|
func (d Decimal) Sub(d2 Decimal) Decimal { |
|
baseScale := min(d.exp, d2.exp) |
|
rd := d.rescale(baseScale) |
|
rd2 := d2.rescale(baseScale) |
|
|
|
d3Value := new(big.Int).Sub(rd.value, rd2.value) |
|
return Decimal{ |
|
value: d3Value, |
|
exp: baseScale, |
|
} |
|
} |
|
|
|
// Neg returns -d. |
|
func (d Decimal) Neg() Decimal { |
|
d.ensureInitialized() |
|
val := new(big.Int).Neg(d.value) |
|
return Decimal{ |
|
value: val, |
|
exp: d.exp, |
|
} |
|
} |
|
|
|
// Mul returns d * d2. |
|
func (d Decimal) Mul(d2 Decimal) Decimal { |
|
d.ensureInitialized() |
|
d2.ensureInitialized() |
|
|
|
expInt64 := int64(d.exp) + int64(d2.exp) |
|
if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 { |
|
// NOTE(vadim): better to panic than give incorrect results, as |
|
// Decimals are usually used for money |
|
panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64)) |
|
} |
|
|
|
d3Value := new(big.Int).Mul(d.value, d2.value) |
|
return Decimal{ |
|
value: d3Value, |
|
exp: int32(expInt64), |
|
} |
|
} |
|
|
|
// Div returns d / d2. If it doesn't divide exactly, the result will have |
|
// DivisionPrecision digits after the decimal point. |
|
func (d Decimal) Div(d2 Decimal) Decimal { |
|
return d.DivRound(d2, int32(DivisionPrecision)) |
|
} |
|
|
|
// QuoRem does divsion with remainder |
|
// d.QuoRem(d2,precision) returns quotient q and remainder r such that |
|
// d = d2 * q + r, q an integer multiple of 10^(-precision) |
|
// 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 |
|
// 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 |
|
// Note that precision<0 is allowed as input. |
|
func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { |
|
d.ensureInitialized() |
|
d2.ensureInitialized() |
|
if d2.value.Sign() == 0 { |
|
panic("decimal division by 0") |
|
} |
|
scale := -precision |
|
e := int64(d.exp - d2.exp - scale) |
|
if e > math.MaxInt32 || e < math.MinInt32 { |
|
panic("overflow in decimal QuoRem") |
|
} |
|
var aa, bb, expo big.Int |
|
var scalerest int32 |
|
// d = a 10^ea |
|
// d2 = b 10^eb |
|
if e < 0 { |
|
aa = *d.value |
|
expo.SetInt64(-e) |
|
bb.Exp(tenInt, &expo, nil) |
|
bb.Mul(d2.value, &bb) |
|
scalerest = d.exp |
|
// now aa = a |
|
// bb = b 10^(scale + eb - ea) |
|
} else { |
|
expo.SetInt64(e) |
|
aa.Exp(tenInt, &expo, nil) |
|
aa.Mul(d.value, &aa) |
|
bb = *d2.value |
|
scalerest = scale + d2.exp |
|
// now aa = a ^ (ea - eb - scale) |
|
// bb = b |
|
} |
|
var q, r big.Int |
|
q.QuoRem(&aa, &bb, &r) |
|
dq := Decimal{value: &q, exp: scale} |
|
dr := Decimal{value: &r, exp: scalerest} |
|
return dq, dr |
|
} |
|
|
|
// DivRound divides and rounds to a given precision |
|
// i.e. to an integer multiple of 10^(-precision) |
|
// for a positive quotient digit 5 is rounded up, away from 0 |
|
// if the quotient is negative then digit 5 is rounded down, away from 0 |
|
// Note that precision<0 is allowed as input. |
|
func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal { |
|
// QuoRem already checks initialization |
|
q, r := d.QuoRem(d2, precision) |
|
// the actual rounding decision is based on comparing r*10^precision and d2/2 |
|
// instead compare 2 r 10 ^precision and d2 |
|
var rv2 big.Int |
|
rv2.Abs(r.value) |
|
rv2.Lsh(&rv2, 1) |
|
// now rv2 = abs(r.value) * 2 |
|
r2 := Decimal{value: &rv2, exp: r.exp + precision} |
|
// r2 is now 2 * r * 10 ^ precision |
|
var c = r2.Cmp(d2.Abs()) |
|
|
|
if c < 0 { |
|
return q |
|
} |
|
|
|
if d.value.Sign()*d2.value.Sign() < 0 { |
|
return q.Sub(New(1, -precision)) |
|
} |
|
|
|
return q.Add(New(1, -precision)) |
|
} |
|
|
|
// Mod returns d % d2. |
|
func (d Decimal) Mod(d2 Decimal) Decimal { |
|
quo := d.Div(d2).Truncate(0) |
|
return d.Sub(d2.Mul(quo)) |
|
} |
|
|
|
// Pow returns d to the power d2 |
|
func (d Decimal) Pow(d2 Decimal) Decimal { |
|
var temp Decimal |
|
if d2.IntPart() == 0 { |
|
return NewFromFloat(1) |
|
} |
|
temp = d.Pow(d2.Div(NewFromFloat(2))) |
|
if d2.IntPart()%2 == 0 { |
|
return temp.Mul(temp) |
|
} |
|
if d2.IntPart() > 0 { |
|
return temp.Mul(temp).Mul(d) |
|
} |
|
return temp.Mul(temp).Div(d) |
|
} |
|
|
|
// Cmp compares the numbers represented by d and d2 and returns: |
|
// |
|
// -1 if d < d2 |
|
// 0 if d == d2 |
|
// +1 if d > d2 |
|
// |
|
func (d Decimal) Cmp(d2 Decimal) int { |
|
d.ensureInitialized() |
|
d2.ensureInitialized() |
|
|
|
if d.exp == d2.exp { |
|
return d.value.Cmp(d2.value) |
|
} |
|
|
|
baseExp := min(d.exp, d2.exp) |
|
rd := d.rescale(baseExp) |
|
rd2 := d2.rescale(baseExp) |
|
|
|
return rd.value.Cmp(rd2.value) |
|
} |
|
|
|
// Equal returns whether the numbers represented by d and d2 are equal. |
|
func (d Decimal) Equal(d2 Decimal) bool { |
|
return d.Cmp(d2) == 0 |
|
} |
|
|
|
// Equals is deprecated, please use Equal method instead |
|
func (d Decimal) Equals(d2 Decimal) bool { |
|
return d.Equal(d2) |
|
} |
|
|
|
// GreaterThan (GT) returns true when d is greater than d2. |
|
func (d Decimal) GreaterThan(d2 Decimal) bool { |
|
return d.Cmp(d2) == 1 |
|
} |
|
|
|
// GreaterThanOrEqual (GTE) returns true when d is greater than or equal to d2. |
|
func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool { |
|
cmp := d.Cmp(d2) |
|
return cmp == 1 || cmp == 0 |
|
} |
|
|
|
// LessThan (LT) returns true when d is less than d2. |
|
func (d Decimal) LessThan(d2 Decimal) bool { |
|
return d.Cmp(d2) == -1 |
|
} |
|
|
|
// LessThanOrEqual (LTE) returns true when d is less than or equal to d2. |
|
func (d Decimal) LessThanOrEqual(d2 Decimal) bool { |
|
cmp := d.Cmp(d2) |
|
return cmp == -1 || cmp == 0 |
|
} |
|
|
|
// Sign returns: |
|
// |
|
// -1 if d < 0 |
|
// 0 if d == 0 |
|
// +1 if d > 0 |
|
// |
|
func (d Decimal) Sign() int { |
|
if d.value == nil { |
|
return 0 |
|
} |
|
return d.value.Sign() |
|
} |
|
|
|
// Exponent returns the exponent, or scale component of the decimal. |
|
func (d Decimal) Exponent() int32 { |
|
return d.exp |
|
} |
|
|
|
// Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent() |
|
func (d Decimal) Coefficient() *big.Int { |
|
// we copy the coefficient so that mutating the result does not mutate the |
|
// Decimal. |
|
return big.NewInt(0).Set(d.value) |
|
} |
|
|
|
// IntPart returns the integer component of the decimal. |
|
func (d Decimal) IntPart() int64 { |
|
scaledD := d.rescale(0) |
|
return scaledD.value.Int64() |
|
} |
|
|
|
// Rat returns a rational number representation of the decimal. |
|
func (d Decimal) Rat() *big.Rat { |
|
d.ensureInitialized() |
|
if d.exp <= 0 { |
|
// NOTE(vadim): must negate after casting to prevent int32 overflow |
|
denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil) |
|
return new(big.Rat).SetFrac(d.value, denom) |
|
} |
|
|
|
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil) |
|
num := new(big.Int).Mul(d.value, mul) |
|
return new(big.Rat).SetFrac(num, oneInt) |
|
} |
|
|
|
// Float64 returns the nearest float64 value for d and a bool indicating |
|
// whether f represents d exactly. |
|
// For more details, see the documentation for big.Rat.Float64 |
|
func (d Decimal) Float64() (f float64, exact bool) { |
|
return d.Rat().Float64() |
|
} |
|
|
|
// String returns the string representation of the decimal |
|
// with the fixed point. |
|
// |
|
// Example: |
|
// |
|
// d := New(-12345, -3) |
|
// println(d.String()) |
|
// |
|
// Output: |
|
// |
|
// -12.345 |
|
// |
|
func (d Decimal) String() string { |
|
return d.string(true) |
|
} |
|
|
|
// StringFixed returns a rounded fixed-point string with places digits after |
|
// the decimal point. |
|
// |
|
// Example: |
|
// |
|
// NewFromFloat(0).StringFixed(2) // output: "0.00" |
|
// NewFromFloat(0).StringFixed(0) // output: "0" |
|
// NewFromFloat(5.45).StringFixed(0) // output: "5" |
|
// NewFromFloat(5.45).StringFixed(1) // output: "5.5" |
|
// NewFromFloat(5.45).StringFixed(2) // output: "5.45" |
|
// NewFromFloat(5.45).StringFixed(3) // output: "5.450" |
|
// NewFromFloat(545).StringFixed(-1) // output: "550" |
|
// |
|
func (d Decimal) StringFixed(places int32) string { |
|
rounded := d.Round(places) |
|
return rounded.string(false) |
|
} |
|
|
|
// StringFixedBank returns a banker rounded fixed-point string with places digits |
|
// after the decimal point. |
|
// |
|
// Example: |
|
// |
|
// NewFromFloat(0).StringFixed(2) // output: "0.00" |
|
// NewFromFloat(0).StringFixed(0) // output: "0" |
|
// NewFromFloat(5.45).StringFixed(0) // output: "5" |
|
// NewFromFloat(5.45).StringFixed(1) // output: "5.4" |
|
// NewFromFloat(5.45).StringFixed(2) // output: "5.45" |
|
// NewFromFloat(5.45).StringFixed(3) // output: "5.450" |
|
// NewFromFloat(545).StringFixed(-1) // output: "550" |
|
// |
|
func (d Decimal) StringFixedBank(places int32) string { |
|
rounded := d.RoundBank(places) |
|
return rounded.string(false) |
|
} |
|
|
|
// StringFixedCash returns a Swedish/Cash rounded fixed-point string. For |
|
// more details see the documentation at function RoundCash. |
|
func (d Decimal) StringFixedCash(interval uint8) string { |
|
rounded := d.RoundCash(interval) |
|
return rounded.string(false) |
|
} |
|
|
|
// Round rounds the decimal to places decimal places. |
|
// If places < 0, it will round the integer part to the nearest 10^(-places). |
|
// |
|
// Example: |
|
// |
|
// NewFromFloat(5.45).Round(1).String() // output: "5.5" |
|
// NewFromFloat(545).Round(-1).String() // output: "550" |
|
// |
|
func (d Decimal) Round(places int32) Decimal { |
|
// truncate to places + 1 |
|
ret := d.rescale(-places - 1) |
|
|
|
// add sign(d) * 0.5 |
|
if ret.value.Sign() < 0 { |
|
ret.value.Sub(ret.value, fiveInt) |
|
} else { |
|
ret.value.Add(ret.value, fiveInt) |
|
} |
|
|
|
// floor for positive numbers, ceil for negative numbers |
|
_, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) |
|
ret.exp++ |
|
if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { |
|
ret.value.Add(ret.value, oneInt) |
|
} |
|
|
|
return ret |
|
} |
|
|
|
// RoundBank rounds the decimal to places decimal places. |
|
// If the final digit to round is equidistant from the nearest two integers the |
|
// rounded value is taken as the even number |
|
// |
|
// If places < 0, it will round the integer part to the nearest 10^(-places). |
|
// |
|
// Examples: |
|
// |
|
// NewFromFloat(5.45).Round(1).String() // output: "5.4" |
|
// NewFromFloat(545).Round(-1).String() // output: "540" |
|
// NewFromFloat(5.46).Round(1).String() // output: "5.5" |
|
// NewFromFloat(546).Round(-1).String() // output: "550" |
|
// NewFromFloat(5.55).Round(1).String() // output: "5.6" |
|
// NewFromFloat(555).Round(-1).String() // output: "560" |
|
// |
|
func (d Decimal) RoundBank(places int32) Decimal { |
|
|
|
round := d.Round(places) |
|
remainder := d.Sub(round).Abs() |
|
|
|
half := New(5, -places-1) |
|
if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 { |
|
if round.value.Sign() < 0 { |
|
round.value.Add(round.value, oneInt) |
|
} else { |
|
round.value.Sub(round.value, oneInt) |
|
} |
|
} |
|
|
|
return round |
|
} |
|
|
|
// RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific |
|
// interval. The amount payable for a cash transaction is rounded to the nearest |
|
// multiple of the minimum currency unit available. The following intervals are |
|
// available: 5, 10, 15, 25, 50 and 100; any other number throws a panic. |
|
// 5: 5 cent rounding 3.43 => 3.45 |
|
// 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up) |
|
// 15: 10 cent rounding 3.45 => 3.40 (5 gets rounded down) |
|
// 25: 25 cent rounding 3.41 => 3.50 |
|
// 50: 50 cent rounding 3.75 => 4.00 |
|
// 100: 100 cent rounding 3.50 => 4.00 |
|
// For more details: https://en.wikipedia.org/wiki/Cash_rounding |
|
func (d Decimal) RoundCash(interval uint8) Decimal { |
|
var iVal *big.Int |
|
switch interval { |
|
case 5: |
|
iVal = twentyInt |
|
case 10: |
|
iVal = tenInt |
|
case 15: |
|
if d.exp < 0 { |
|
// TODO: optimize and reduce allocations |
|
orgExp := d.exp |
|
dOne := New(10^-int64(orgExp), orgExp) |
|
d2 := d |
|
d2.exp = 0 |
|
if d2.Mod(fiveDec).Equal(Zero) { |
|
d2.exp = orgExp |
|
d2 = d2.Sub(dOne) |
|
d = d2 |
|
} |
|
} |
|
iVal = tenInt |
|
case 25: |
|
iVal = fourInt |
|
case 50: |
|
iVal = twoInt |
|
case 100: |
|
iVal = oneInt |
|
default: |
|
panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 15, 25, 50, 100", interval)) |
|
} |
|
dVal := Decimal{ |
|
value: iVal, |
|
} |
|
// TODO: optimize those calculations to reduce the high allocations (~29 allocs). |
|
return d.Mul(dVal).Round(0).Div(dVal).Truncate(2) |
|
} |
|
|
|
// Floor returns the nearest integer value less than or equal to d. |
|
func (d Decimal) Floor() Decimal { |
|
d.ensureInitialized() |
|
|
|
if d.exp >= 0 { |
|
return d |
|
} |
|
|
|
exp := big.NewInt(10) |
|
|
|
// NOTE(vadim): must negate after casting to prevent int32 overflow |
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) |
|
|
|
z := new(big.Int).Div(d.value, exp) |
|
return Decimal{value: z, exp: 0} |
|
} |
|
|
|
// Ceil returns the nearest integer value greater than or equal to d. |
|
func (d Decimal) Ceil() Decimal { |
|
d.ensureInitialized() |
|
|
|
if d.exp >= 0 { |
|
return d |
|
} |
|
|
|
exp := big.NewInt(10) |
|
|
|
// NOTE(vadim): must negate after casting to prevent int32 overflow |
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) |
|
|
|
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int)) |
|
if m.Cmp(zeroInt) != 0 { |
|
z.Add(z, oneInt) |
|
} |
|
return Decimal{value: z, exp: 0} |
|
} |
|
|
|
// Truncate truncates off digits from the number, without rounding. |
|
// |
|
// NOTE: precision is the last digit that will not be truncated (must be >= 0). |
|
// |
|
// Example: |
|
// |
|
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45" |
|
// |
|
func (d Decimal) Truncate(precision int32) Decimal { |
|
d.ensureInitialized() |
|
if precision >= 0 && -precision > d.exp { |
|
return d.rescale(-precision) |
|
} |
|
return d |
|
} |
|
|
|
// UnmarshalJSON implements the json.Unmarshaler interface. |
|
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { |
|
if string(decimalBytes) == "null" { |
|
return nil |
|
} |
|
|
|
str, err := unquoteIfQuoted(decimalBytes) |
|
if err != nil { |
|
return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err) |
|
} |
|
|
|
decimal, err := NewFromString(str) |
|
*d = decimal |
|
if err != nil { |
|
return fmt.Errorf("Error decoding string '%s': %s", str, err) |
|
} |
|
return nil |
|
} |
|
|
|
// MarshalJSON implements the json.Marshaler interface. |
|
func (d Decimal) MarshalJSON() ([]byte, error) { |
|
var str string |
|
if MarshalJSONWithoutQuotes { |
|
str = d.String() |
|
} else { |
|
str = "\"" + d.String() + "\"" |
|
} |
|
return []byte(str), nil |
|
} |
|
|
|
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation |
|
// is already used when encoding to text, this method stores that string as []byte |
|
func (d *Decimal) UnmarshalBinary(data []byte) error { |
|
// Extract the exponent |
|
d.exp = int32(binary.BigEndian.Uint32(data[:4])) |
|
|
|
// Extract the value |
|
d.value = new(big.Int) |
|
return d.value.GobDecode(data[4:]) |
|
} |
|
|
|
// MarshalBinary implements the encoding.BinaryMarshaler interface. |
|
func (d Decimal) MarshalBinary() (data []byte, err error) { |
|
// Write the exponent first since it's a fixed size |
|
v1 := make([]byte, 4) |
|
binary.BigEndian.PutUint32(v1, uint32(d.exp)) |
|
|
|
// Add the value |
|
var v2 []byte |
|
if v2, err = d.value.GobEncode(); err != nil { |
|
return |
|
} |
|
|
|
// Return the byte array |
|
data = append(v1, v2...) |
|
return |
|
} |
|
|
|
// Scan implements the sql.Scanner interface for database deserialization. |
|
func (d *Decimal) Scan(value interface{}) error { |
|
// first try to see if the data is stored in database as a Numeric datatype |
|
switch v := value.(type) { |
|
|
|
case float32: |
|
*d = NewFromFloat(float64(v)) |
|
return nil |
|
|
|
case float64: |
|
// numeric in sqlite3 sends us float64 |
|
*d = NewFromFloat(v) |
|
return nil |
|
|
|
case int64: |
|
// at least in sqlite3 when the value is 0 in db, the data is sent |
|
// to us as an int64 instead of a float64 ... |
|
*d = New(v, 0) |
|
return nil |
|
|
|
default: |
|
// default is trying to interpret value stored as string |
|
str, err := unquoteIfQuoted(v) |
|
if err != nil { |
|
return err |
|
} |
|
*d, err = NewFromString(str) |
|
return err |
|
} |
|
} |
|
|
|
// Value implements the driver.Valuer interface for database serialization. |
|
func (d Decimal) Value() (driver.Value, error) { |
|
return d.String(), nil |
|
} |
|
|
|
// UnmarshalText implements the encoding.TextUnmarshaler interface for XML |
|
// deserialization. |
|
func (d *Decimal) UnmarshalText(text []byte) error { |
|
str := string(text) |
|
|
|
dec, err := NewFromString(str) |
|
*d = dec |
|
if err != nil { |
|
return fmt.Errorf("Error decoding string '%s': %s", str, err) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// MarshalText implements the encoding.TextMarshaler interface for XML |
|
// serialization. |
|
func (d Decimal) MarshalText() (text []byte, err error) { |
|
return []byte(d.String()), nil |
|
} |
|
|
|
// GobEncode implements the gob.GobEncoder interface for gob serialization. |
|
func (d Decimal) GobEncode() ([]byte, error) { |
|
return d.MarshalBinary() |
|
} |
|
|
|
// GobDecode implements the gob.GobDecoder interface for gob serialization. |
|
func (d *Decimal) GobDecode(data []byte) error { |
|
return d.UnmarshalBinary(data) |
|
} |
|
|
|
// StringScaled first scales the decimal then calls .String() on it. |
|
// NOTE: buggy, unintuitive, and DEPRECATED! Use StringFixed instead. |
|
func (d Decimal) StringScaled(exp int32) string { |
|
return d.rescale(exp).String() |
|
} |
|
|
|
func (d Decimal) string(trimTrailingZeros bool) string { |
|
if d.exp >= 0 { |
|
return d.rescale(0).value.String() |
|
} |
|
|
|
abs := new(big.Int).Abs(d.value) |
|
str := abs.String() |
|
|
|
var intPart, fractionalPart string |
|
|
|
// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN |
|
// and you are on a 32-bit machine. Won't fix this super-edge case. |
|
dExpInt := int(d.exp) |
|
if len(str) > -dExpInt { |
|
intPart = str[:len(str)+dExpInt] |
|
fractionalPart = str[len(str)+dExpInt:] |
|
} else { |
|
intPart = "0" |
|
|
|
num0s := -dExpInt - len(str) |
|
fractionalPart = strings.Repeat("0", num0s) + str |
|
} |
|
|
|
if trimTrailingZeros { |
|
i := len(fractionalPart) - 1 |
|
for ; i >= 0; i-- { |
|
if fractionalPart[i] != '0' { |
|
break |
|
} |
|
} |
|
fractionalPart = fractionalPart[:i+1] |
|
} |
|
|
|
number := intPart |
|
if len(fractionalPart) > 0 { |
|
number += "." + fractionalPart |
|
} |
|
|
|
if d.value.Sign() < 0 { |
|
return "-" + number |
|
} |
|
|
|
return number |
|
} |
|
|
|
func (d *Decimal) ensureInitialized() { |
|
if d.value == nil { |
|
d.value = new(big.Int) |
|
} |
|
} |
|
|
|
// Min returns the smallest Decimal that was passed in the arguments. |
|
// |
|
// To call this function with an array, you must do: |
|
// |
|
// Min(arr[0], arr[1:]...) |
|
// |
|
// This makes it harder to accidentally call Min with 0 arguments. |
|
func Min(first Decimal, rest ...Decimal) Decimal { |
|
ans := first |
|
for _, item := range rest { |
|
if item.Cmp(ans) < 0 { |
|
ans = item |
|
} |
|
} |
|
return ans |
|
} |
|
|
|
// Max returns the largest Decimal that was passed in the arguments. |
|
// |
|
// To call this function with an array, you must do: |
|
// |
|
// Max(arr[0], arr[1:]...) |
|
// |
|
// This makes it harder to accidentally call Max with 0 arguments. |
|
func Max(first Decimal, rest ...Decimal) Decimal { |
|
ans := first |
|
for _, item := range rest { |
|
if item.Cmp(ans) > 0 { |
|
ans = item |
|
} |
|
} |
|
return ans |
|
} |
|
|
|
// Sum returns the combined total of the provided first and rest Decimals |
|
func Sum(first Decimal, rest ...Decimal) Decimal { |
|
total := first |
|
for _, item := range rest { |
|
total = total.Add(item) |
|
} |
|
|
|
return total |
|
} |
|
|
|
// Avg returns the average value of the provided first and rest Decimals |
|
func Avg(first Decimal, rest ...Decimal) Decimal { |
|
count := New(int64(len(rest)+1), 0) |
|
sum := Sum(first, rest...) |
|
return sum.Div(count) |
|
} |
|
|
|
// min returns the smaller of two int32 values.
func min(x, y int32) int32 {
	if x < y {
		return x
	}
	return y
}
|
|
|
// unquoteIfQuoted converts value, which must be a string or a []byte, to a
// string and removes one pair of enclosing double quotes if both are present.
// This includes the two-byte quoted empty string `""`, which becomes "". No
// escape sequences are interpreted. Any other dynamic type yields an error.
func unquoteIfQuoted(value interface{}) (string, error) {
	var raw []byte

	switch v := value.(type) {
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		return "", fmt.Errorf("Could not convert value '%+v' to byte array of type '%T'",
			value, value)
	}

	// A quoted value needs at least the two quote characters themselves; a
	// lone `"` is left alone.
	if n := len(raw); n >= 2 && raw[0] == '"' && raw[n-1] == '"' {
		raw = raw[1 : n-1]
	}
	return string(raw), nil
}
|
|
|
// NullDecimal represents a nullable decimal with compatibility for |
|
// scanning null values from the database. |
|
type NullDecimal struct { |
|
Decimal Decimal |
|
Valid bool |
|
} |
|
|
|
// Scan implements the sql.Scanner interface for database deserialization. |
|
func (d *NullDecimal) Scan(value interface{}) error { |
|
if value == nil { |
|
d.Valid = false |
|
return nil |
|
} |
|
d.Valid = true |
|
return d.Decimal.Scan(value) |
|
} |
|
|
|
// Value implements the driver.Valuer interface for database serialization. |
|
func (d NullDecimal) Value() (driver.Value, error) { |
|
if !d.Valid { |
|
return nil, nil |
|
} |
|
return d.Decimal.Value() |
|
} |
|
|
|
// UnmarshalJSON implements the json.Unmarshaler interface. |
|
func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error { |
|
if string(decimalBytes) == "null" { |
|
d.Valid = false |
|
return nil |
|
} |
|
d.Valid = true |
|
return d.Decimal.UnmarshalJSON(decimalBytes) |
|
} |
|
|
|
// MarshalJSON implements the json.Marshaler interface. |
|
func (d NullDecimal) MarshalJSON() ([]byte, error) { |
|
if !d.Valid { |
|
return []byte("null"), nil |
|
} |
|
return d.Decimal.MarshalJSON() |
|
}
|
|
|