diff --git a/tpl/math/math.go b/tpl/math/math.go index b79cb6188..534f7f284 100644 --- a/tpl/math/math.go +++ b/tpl/math/math.go @@ -72,21 +72,10 @@ func (ns *Namespace) Log(a interface{}) (float64, error) { // Mod returns a % b. func (ns *Namespace) Mod(a, b interface{}) (int64, error) { - av := reflect.ValueOf(a) - bv := reflect.ValueOf(b) - var ai, bi int64 + ai, erra := cast.ToInt64E(a) + bi, errb := cast.ToInt64E(b) - switch av.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - ai = av.Int() - default: - return 0, errors.New("Modulo operator can't be used with non integer value") - } - - switch bv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - bi = bv.Int() - default: + if erra != nil || errb != nil { return 0, errors.New("Modulo operator can't be used with non integer value") } diff --git a/tpl/math/math_test.go b/tpl/math/math_test.go index 4d14a58cc..97acfaeba 100644 --- a/tpl/math/math_test.go +++ b/tpl/math/math_test.go @@ -259,13 +259,17 @@ func TestMod(t *testing.T) { {3, 1, int64(0)}, {3, 0, false}, {0, 3, int64(0)}, - {3.1, 2, false}, - {3, 2.1, false}, - {3.1, 2.1, false}, + {3.1, 2, int64(1)}, + {3, 2.1, int64(1)}, + {3.1, 2.1, int64(1)}, {int8(3), int8(2), int64(1)}, {int16(3), int16(2), int64(1)}, {int32(3), int32(2), int64(1)}, {int64(3), int64(2), int64(1)}, + {"3", "2", int64(1)}, + {"3.1", "2", false}, + {"aaa", "0", false}, + {"3", "aaa", false}, } { errMsg := fmt.Sprintf("[%d] %v", i, test) @@ -296,9 +300,9 @@ func TestModBool(t *testing.T) { {3, 1, true}, {3, 0, nil}, {0, 3, true}, - {3.1, 2, nil}, - {3, 2.1, nil}, - {3.1, 2.1, nil}, + {3.1, 2, false}, + {3, 2.1, false}, + {3.1, 2.1, false}, {int8(3), int8(3), true}, {int8(3), int8(2), false}, {int16(3), int16(3), true}, @@ -307,6 +311,11 @@ func TestModBool(t *testing.T) { {int32(3), int32(2), false}, {int64(3), int64(3), true}, {int64(3), int64(2), false}, + {"3", "3", true}, + {"3", "2", false}, + {"3.1", "2", nil}, + {"aaa", "0", nil}, + {"3", "aaa", nil}, } { errMsg := fmt.Sprintf("[%d] %v", i, test)