tpl/collections: Make Pages etc. work with the in func

Fixes #5875
This commit is contained in:
Bjørn Erik Pedersen 2019-04-18 17:06:54 +02:00
parent d7a67dcb51
commit 06f56fc983
2 changed files with 23 additions and 17 deletions

View file

@ -250,27 +250,26 @@ func (ns *Namespace) In(l interface{}, v interface{}) bool {
lv := reflect.ValueOf(l) lv := reflect.ValueOf(l)
vv := reflect.ValueOf(v) vv := reflect.ValueOf(v)
if !vv.Type().Comparable() {
// TODO(bep) consider adding error to the signature.
return false
}
// Normalize numeric types to float64 etc.
vvk := normalize(vv)
switch lv.Kind() { switch lv.Kind() {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
for i := 0; i < lv.Len(); i++ { for i := 0; i < lv.Len(); i++ {
lvv := lv.Index(i) lvv, isNil := indirectInterface(lv.Index(i))
lvv, isNil := indirect(lvv) if isNil || !lvv.Type().Comparable() {
if isNil {
continue continue
} }
switch lvv.Kind() {
case reflect.String: lvvk := normalize(lvv)
if vv.Type() == lvv.Type() && vv.String() == lvv.String() {
return true if lvvk == vvk {
} return true
default:
if isNumber(vv.Kind()) && isNumber(lvv.Kind()) {
f1, err1 := numberToFloat(vv)
f2, err2 := numberToFloat(lvv)
if err1 == nil && err2 == nil && f1 == f2 {
return true
}
}
} }
} }
case reflect.String: case reflect.String:

View file

@ -276,6 +276,7 @@ func TestFirst(t *testing.T) {
func TestIn(t *testing.T) { func TestIn(t *testing.T) {
t.Parallel() t.Parallel()
assert := require.New(t)
ns := New(&deps.Deps{}) ns := New(&deps.Deps{})
@ -302,12 +303,18 @@ func TestIn(t *testing.T) {
{"this substring should be found", "substring", true}, {"this substring should be found", "substring", true},
{"this substring should not be found", "subseastring", false}, {"this substring should not be found", "subseastring", false},
{nil, "foo", false}, {nil, "foo", false},
// Pointers
{pagesPtr{p1, p2, p3, p2}, p2, true},
{pagesPtr{p1, p2, p3, p2}, p4, false},
// Structs
{pagesVals{p3v, p2v, p3v, p2v}, p2v, true},
{pagesVals{p3v, p2v, p3v, p2v}, p4v, false},
} { } {
errMsg := fmt.Sprintf("[%d] %v", i, test) errMsg := fmt.Sprintf("[%d] %v", i, test)
result := ns.In(test.l1, test.l2) result := ns.In(test.l1, test.l2)
assert.Equal(t, test.expect, result, errMsg) assert.Equal(test.expect, result, errMsg)
} }
} }