Support unComparable args of uniq/complement/in

Fixes #6105
This commit is contained in:
satotake 2020-03-09 21:32:38 +09:00 committed by GitHub
parent c4fa2f0799
commit 8279d2e227
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 30 additions and 24 deletions

View file

@ -271,18 +271,13 @@ func (ns *Namespace) In(l interface{}, v interface{}) (bool, error) {
lv := reflect.ValueOf(l) lv := reflect.ValueOf(l)
vv := reflect.ValueOf(v) vv := reflect.ValueOf(v)
if !vv.Type().Comparable() {
return false, errors.Errorf("value to check must be comparable: %T", v)
}
// Normalize numeric types to float64 etc.
vvk := normalize(vv) vvk := normalize(vv)
switch lv.Kind() { switch lv.Kind() {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
for i := 0; i < lv.Len(); i++ { for i := 0; i < lv.Len(); i++ {
lvv, isNil := indirectInterface(lv.Index(i)) lvv, isNil := indirectInterface(lv.Index(i))
if isNil || !lvv.Type().Comparable() { if isNil {
continue continue
} }
@ -713,6 +708,7 @@ func (ns *Namespace) Uniq(seq interface{}) (interface{}, error) {
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
slice = reflect.MakeSlice(v.Type(), 0, 0) slice = reflect.MakeSlice(v.Type(), 0, 0)
case reflect.Array: case reflect.Array:
slice = reflect.MakeSlice(reflect.SliceOf(v.Type().Elem()), 0, 0) slice = reflect.MakeSlice(reflect.SliceOf(v.Type().Elem()), 0, 0)
default: default:
@ -720,12 +716,12 @@ func (ns *Namespace) Uniq(seq interface{}) (interface{}, error) {
} }
seen := make(map[interface{}]bool) seen := make(map[interface{}]bool)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
ev, _ := indirectInterface(v.Index(i)) ev, _ := indirectInterface(v.Index(i))
if !ev.Type().Comparable() {
return nil, errors.New("elements must be comparable")
}
key := normalize(ev) key := normalize(ev)
if _, found := seen[key]; !found { if _, found := seen[key]; !found {
slice = reflect.Append(slice, ev) slice = reflect.Append(slice, ev)
seen[key] = true seen[key] = true

View file

@ -348,6 +348,9 @@ func TestIn(t *testing.T) {
// template.HTML // template.HTML
{template.HTML("this substring should be found"), "substring", true}, {template.HTML("this substring should be found"), "substring", true},
{template.HTML("this substring should not be found"), "subseastring", false}, {template.HTML("this substring should not be found"), "subseastring", false},
// Uncomparable, use hashstructure
{[]string{"a", "b"}, []string{"a", "b"}, false},
{[][]string{{"a", "b"}}, []string{"a", "b"}, true},
} { } {
errMsg := qt.Commentf("[%d] %v", i, test) errMsg := qt.Commentf("[%d] %v", i, test)
@ -356,10 +359,6 @@ func TestIn(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Assert(result, qt.Equals, test.expect, errMsg) c.Assert(result, qt.Equals, test.expect, errMsg)
} }
// Slices are not comparable
_, err := ns.In([]string{"a", "b"}, []string{"a", "b"})
c.Assert(err, qt.Not(qt.IsNil))
} }
type testPage struct { type testPage struct {
@ -835,9 +834,14 @@ func TestUniq(t *testing.T) {
// Structs // Structs
{pagesVals{p3v, p2v, p3v, p2v}, pagesVals{p3v, p2v}, false}, {pagesVals{p3v, p2v, p3v, p2v}, pagesVals{p3v, p2v}, false},
// not Comparable(), use hashstruscture
{[]map[string]int{
{"K1": 1}, {"K2": 2}, {"K1": 1}, {"K2": 1},
}, []map[string]int{
{"K1": 1}, {"K2": 2}, {"K2": 1},
}, false},
// should fail // should fail
// uncomparable types
{[]map[string]int{{"K1": 1}}, []map[string]int{{"K2": 2}, {"K2": 2}}, true},
{1, 1, true}, {1, 1, true},
{"foo", "fo", true}, {"foo", "fo", true},
} { } {

View file

@ -44,9 +44,6 @@ func (ns *Namespace) Complement(seqs ...interface{}) (interface{}, error) {
sl := reflect.MakeSlice(v.Type(), 0, 0) sl := reflect.MakeSlice(v.Type(), 0, 0)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
ev, _ := indirectInterface(v.Index(i)) ev, _ := indirectInterface(v.Index(i))
if !ev.Type().Comparable() {
return nil, errors.New("elements in complement must be comparable")
}
if _, found := aset[normalize(ev)]; !found { if _, found := aset[normalize(ev)]; !found {
sl = reflect.Append(sl, ev) sl = reflect.Append(sl, ev)
} }

View file

@ -65,7 +65,10 @@ func TestComplement(t *testing.T) {
{[]string{"a", "b", "c"}, []interface{}{"error"}, false}, {[]string{"a", "b", "c"}, []interface{}{"error"}, false},
{"error", []interface{}{[]string{"c", "d"}, []string{"a", "b"}}, false}, {"error", []interface{}{[]string{"c", "d"}, []string{"a", "b"}}, false},
{[]string{"a", "b", "c"}, []interface{}{[][]string{{"c", "d"}}}, false}, {[]string{"a", "b", "c"}, []interface{}{[][]string{{"c", "d"}}}, false},
{[]interface{}{[][]string{{"c", "d"}}}, []interface{}{[]string{"c", "d"}, []string{"a", "b"}}, false}, {
[]interface{}{[][]string{{"c", "d"}}}, []interface{}{[]string{"c", "d"}, []string{"a", "b"}},
[]interface{}{[][]string{{"c", "d"}}},
},
} { } {
errMsg := qt.Commentf("[%d]", i) errMsg := qt.Commentf("[%d]", i)

View file

@ -18,6 +18,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/mitchellh/hashstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -42,18 +43,25 @@ func numberToFloat(v reflect.Value) (float64, error) {
} }
} }
// normalizes different numeric types to make them comparable. // normalizes different numeric types if isNumber
// or get the hash values if not Comparable (such as map or struct)
// to make them comparable
func normalize(v reflect.Value) interface{} { func normalize(v reflect.Value) interface{} {
k := v.Kind() k := v.Kind()
switch { switch {
case !v.Type().Comparable():
h, err := hashstructure.Hash(v.Interface(), nil)
if err != nil {
panic(err)
}
return h
case isNumber(k): case isNumber(k):
f, err := numberToFloat(v) f, err := numberToFloat(v)
if err == nil { if err == nil {
return f return f
} }
} }
return v.Interface() return v.Interface()
} }

View file

@ -48,10 +48,8 @@ func (ns *Namespace) SymDiff(s2, s1 interface{}) (interface{}, error) {
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
ev, _ := indirectInterface(v.Index(i)) ev, _ := indirectInterface(v.Index(i))
if !ev.Type().Comparable() {
return nil, errors.New("symdiff: elements must be comparable")
}
key := normalize(ev) key := normalize(ev)
// Append if the key is not in their intersection. // Append if the key is not in their intersection.
if ids1[key] != ids2[key] { if ids1[key] != ids2[key] {
v, err := convertValue(ev, sliceElemType) v, err := convertValue(ev, sliceElemType)