diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go index 0843fb7bc..081515ae5 100644 --- a/tpl/collections/collections.go +++ b/tpl/collections/collections.go @@ -298,21 +298,49 @@ func (ns *Namespace) Intersect(l1, l2 interface{}) (interface{}, error) { l2vv := l2v.Index(j) switch l1vv.Kind() { case reflect.String: - if l1vv.Type() == l2vv.Type() && l1vv.String() == l2vv.String() && !ns.In(r.Interface(), l2vv.Interface()) { - r = reflect.Append(r, l2vv) + l2t, err := toString(l2vv) + if err == nil && l1vv.String() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { + r = reflect.Append(r, l1vv) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch l2vv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if l1vv.Int() == l2vv.Int() && !ns.In(r.Interface(), l2vv.Interface()) { - r = reflect.Append(r, l2vv) - } + l2t, err := toInt(l2vv) + if err == nil && l1vv.Int() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { + r = reflect.Append(r, l1vv) } case reflect.Float32, reflect.Float64: - switch l2vv.Kind() { - case reflect.Float32, reflect.Float64: - if l1vv.Float() == l2vv.Float() && !ns.In(r.Interface(), l2vv.Interface()) { - r = reflect.Append(r, l2vv) + l2t, err := toFloat(l2vv) + if err == nil && l1vv.Float() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { + r = reflect.Append(r, l1vv) + } + case reflect.Interface: + switch l1vvActual := l1vv.Interface().(type) { + case string: + switch l2vvActual := l2vv.Interface().(type) { + case string: + if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { + r = reflect.Append(r, l1vv) + } + } + case int, int8, int16, int32, int64: + switch l2vvActual := l2vv.Interface().(type) { + case int, int8, int16, int32, int64: + if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { + r = reflect.Append(r, l1vv) + } + } + case uint, uint8, uint16, uint32, uint64: + switch l2vvActual := l2vv.Interface().(type) { + case uint, uint8, uint16, uint32, uint64: + if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { + r = reflect.Append(r, l1vv) + } + } + case float32, float64: + switch l2vvActual := l2vv.Interface().(type) { + case float32, float64: + if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { + r = reflect.Append(r, l1vv) + } } } } diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go index eefbcef6c..07055de86 100644 --- a/tpl/collections/collections_test.go +++ b/tpl/collections/collections_test.go @@ -260,7 +260,9 @@ func TestIntersect(t *testing.T) { {[]string{"a", "b"}, []string{"a", "b", "c"}, []string{"a", "b"}}, {[]string{"a", "b", "c"}, []string{"d", "e"}, []string{}}, {[]string{}, []string{}, []string{}}, - {nil, nil, make([]interface{}, 0)}, + {[]string{"a", "b"}, nil, []interface{}{}}, + {nil, []string{"a", "b"}, []interface{}{}}, + {nil, nil, []interface{}{}}, {[]string{"1", "2"}, []int{1, 2}, []string{}}, {[]int{1, 2}, []string{"1", "2"}, []int{}}, {[]int{1, 2, 4}, []int{2, 4}, []int{2, 4}}, @@ -270,6 +272,36 @@ func TestIntersect(t *testing.T) { // errors {"not array or slice", []string{"a"}, false}, {[]string{"a"}, "not array or slice", false}, + + // []interface{} ∩ []interface{} + {[]interface{}{"a", "b", "c"}, []interface{}{"a", "b", "b"}, []interface{}{"a", "b"}}, + {[]interface{}{1, 2, 3}, []interface{}{1, 2, 2}, []interface{}{1, 2}}, + {[]interface{}{int8(1), int8(2), int8(3)}, []interface{}{int8(1), int8(2), int8(2)}, []interface{}{int8(1), int8(2)}}, + {[]interface{}{int16(1), int16(2), int16(3)}, []interface{}{int16(1), int16(2), int16(2)}, []interface{}{int16(1), int16(2)}}, + {[]interface{}{int32(1), int32(2), int32(3)}, []interface{}{int32(1), int32(2), int32(2)}, []interface{}{int32(1), int32(2)}}, + {[]interface{}{int64(1), int64(2), int64(3)}, []interface{}{int64(1), int64(2), int64(2)}, []interface{}{int64(1), int64(2)}}, + {[]interface{}{float32(1), float32(2), float32(3)}, []interface{}{float32(1), float32(2), float32(2)}, []interface{}{float32(1), float32(2)}}, + {[]interface{}{float64(1), float64(2), float64(3)}, []interface{}{float64(1), float64(2), float64(2)}, []interface{}{float64(1), float64(2)}}, + + // []interface{} ∩ []T + {[]interface{}{"a", "b", "c"}, []string{"a", "b", "b"}, []interface{}{"a", "b"}}, + {[]interface{}{1, 2, 3}, []int{1, 2, 2}, []interface{}{1, 2}}, + {[]interface{}{int8(1), int8(2), int8(3)}, []int8{1, 2, 2}, []interface{}{int8(1), int8(2)}}, + {[]interface{}{int16(1), int16(2), int16(3)}, []int16{1, 2, 2}, []interface{}{int16(1), int16(2)}}, + {[]interface{}{int32(1), int32(2), int32(3)}, []int32{1, 2, 2}, []interface{}{int32(1), int32(2)}}, + {[]interface{}{int64(1), int64(2), int64(3)}, []int64{1, 2, 2}, []interface{}{int64(1), int64(2)}}, + {[]interface{}{float32(1), float32(2), float32(3)}, []float32{1, 2, 2}, []interface{}{float32(1), float32(2)}}, + {[]interface{}{float64(1), float64(2), float64(3)}, []float64{1, 2, 2}, []interface{}{float64(1), float64(2)}}, + + // []T ∩ []interface{} + {[]string{"a", "b", "c"}, []interface{}{"a", "b", "b"}, []string{"a", "b"}}, + {[]int{1, 2, 3}, []interface{}{1, 2, 2}, []int{1, 2}}, + {[]int8{1, 2, 3}, []interface{}{int8(1), int8(2), int8(2)}, []int8{1, 2}}, + {[]int16{1, 2, 3}, []interface{}{int16(1), int16(2), int16(2)}, []int16{1, 2}}, + {[]int32{1, 2, 3}, []interface{}{int32(1), int32(2), int32(2)}, []int32{1, 2}}, + {[]int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(2)}, []int64{1, 2}}, + {[]float32{1, 2, 3}, []interface{}{float32(1), float32(2), float32(2)}, []float32{1, 2}}, + {[]float64{1, 2, 3}, []interface{}{float64(1), float64(2), float64(2)}, []float64{1, 2}}, } { errMsg := fmt.Sprintf("[%d] %v", i, test) diff --git a/tpl/collections/where.go b/tpl/collections/where.go index f34494eb3..e9528fb86 100644 --- a/tpl/collections/where.go +++ b/tpl/collections/where.go @@ -124,16 +124,15 @@ func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error iv := v.Int() ivp = &iv for i := 0; i < mv.Len(); i++ { - if anInt := toInt(mv.Index(i)); anInt != -1 { + if anInt, err := toInt(mv.Index(i)); err == nil { ima = append(ima, anInt) } - } case reflect.String: sv := v.String() svp = &sv for i := 0; i < mv.Len(); i++ { - if aString := toString(mv.Index(i)); aString != "" { + if aString, err := toString(mv.Index(i)); err == nil { sma = append(sma, aString) } } @@ -382,26 +381,37 @@ func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op return rv.Interface(), nil } +// toFloat returns the int value if possible. +func toFloat(v reflect.Value) (float64, error) { + switch v.Kind() { + case reflect.Float32, reflect.Float64: + return v.Float(), nil + case reflect.Interface: + return toFloat(v.Elem()) + } + return -1, errors.New("unable to convert value to float") +} + // toInt returns the int value if possible, -1 if not. -func toInt(v reflect.Value) int64 { +func toInt(v reflect.Value) (int64, error) { switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() + return v.Int(), nil case reflect.Interface: return toInt(v.Elem()) } - return -1 + return -1, errors.New("unable to convert value to int") } // toString returns the string value if possible, "" if not. -func toString(v reflect.Value) string { +func toString(v reflect.Value) (string, error) { switch v.Kind() { case reflect.String: - return v.String() + return v.String(), nil case reflect.Interface: return toString(v.Elem()) } - return "" + return "", errors.New("unable to convert value to string") } var (