diff --git a/tpl/template_funcs.go b/tpl/template_funcs.go index f6913510e..9131cc1ec 100644 --- a/tpl/template_funcs.go +++ b/tpl/template_funcs.go @@ -37,11 +37,13 @@ import ( "time" "unicode/utf8" - "github.com/bep/inflect" "github.com/spf13/afero" + "github.com/spf13/hugo/hugofs" + + "github.com/bep/inflect" + "github.com/spf13/cast" "github.com/spf13/hugo/helpers" - "github.com/spf13/hugo/hugofs" jww "github.com/spf13/jwalterweatherman" ) @@ -769,125 +771,64 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) { return false, nil } -// parseWhereArgs parses the end arguments to the where function. Return a -// match value and an operator, if one is defined. -func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) { +// where returns a filtered subset of a given data type. +func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) { + seqv := reflect.ValueOf(seq) + kv := reflect.ValueOf(key) + + var mv reflect.Value + var op string switch len(args) { case 1: mv = reflect.ValueOf(args[0]) case 2: var ok bool if op, ok = args[0].(string); !ok { - err = errors.New("operator argument must be string type") - return + return nil, errors.New("operator argument must be string type") } op = strings.TrimSpace(strings.ToLower(op)) mv = reflect.ValueOf(args[1]) default: - err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments") + return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments") } - return -} -// checkWhereArray handles the where-matching logic when the seqv value is an -// Array or Slice. -func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) { - rv := reflect.MakeSlice(seqv.Type(), 0, 0) - for i := 0; i < seqv.Len(); i++ { - var vvv reflect.Value - rvv := seqv.Index(i) - if kv.Kind() == reflect.String { - vvv = rvv - for _, elemName := range path { - var err error - vvv, err = evaluateSubElem(vvv, elemName) - if err != nil { - return nil, err - } - } - } else { - vv, _ := indirect(rvv) - if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) { - vvv = vv.MapIndex(kv) - } - } - - if ok, err := checkCondition(vvv, mv, op); ok { - rv = reflect.Append(rv, rvv) - } else if err != nil { - return nil, err - } - } - return rv.Interface(), nil -} - -// checkWhereMap handles the where-matching logic when the seqv value is a Map. -func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) { - rv := reflect.MakeMap(seqv.Type()) - keys := seqv.MapKeys() - for _, k := range keys { - elemv := seqv.MapIndex(k) - switch elemv.Kind() { - case reflect.Array, reflect.Slice: - r, err := checkWhereArray(elemv, kv, mv, path, op) - if err != nil { - return nil, err - } - - switch rr := reflect.ValueOf(r); rr.Kind() { - case reflect.Slice: - if rr.Len() > 0 { - rv.SetMapIndex(k, elemv) - } - } - case reflect.Interface: - elemvv, isNil := indirect(elemv) - if isNil { - continue - } - - switch elemvv.Kind() { - case reflect.Array, reflect.Slice: - r, err := checkWhereArray(elemvv, kv, mv, path, op) - if err != nil { - return nil, err - } - - switch rr := reflect.ValueOf(r); rr.Kind() { - case reflect.Slice: - if rr.Len() > 0 { - rv.SetMapIndex(k, elemv) - } - } - } - } - } - return rv, nil -} - -// where returns a filtered subset of a given data type. -func where(seq, key interface{}, args ...interface{}) (interface{}, error) { - seqv, isNil := indirect(reflect.ValueOf(seq)) + seqv, isNil := indirect(seqv) if isNil { return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String()) } - mv, op, err := parseWhereArgs(args...) - if err != nil { - return nil, err - } - var path []string - kv := reflect.ValueOf(key) if kv.Kind() == reflect.String { path = strings.Split(strings.Trim(kv.String(), "."), ".") } switch seqv.Kind() { case reflect.Array, reflect.Slice: - return checkWhereArray(seqv, kv, mv, path, op) - case reflect.Map: - return checkWhereMap(seqv, kv, mv, path, op) + rv := reflect.MakeSlice(seqv.Type(), 0, 0) + for i := 0; i < seqv.Len(); i++ { + var vvv reflect.Value + rvv := seqv.Index(i) + if kv.Kind() == reflect.String { + vvv = rvv + for _, elemName := range path { + vvv, err = evaluateSubElem(vvv, elemName) + if err != nil { + return nil, err + } + } + } else { + vv, _ := indirect(rvv) + if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) { + vvv = vv.MapIndex(kv) + } + } + if ok, err := checkCondition(vvv, mv, op); ok { + rv = reflect.Append(rv, rvv) + } else if err != nil { + return nil, err + } + } + return rv.Interface(), nil default: return nil, fmt.Errorf("can't iterate over %v", seq) } diff --git a/tpl/template_funcs_test.go b/tpl/template_funcs_test.go index 5dbcf6cf6..8d604e817 100644 --- a/tpl/template_funcs_test.go +++ b/tpl/template_funcs_test.go @@ -18,6 +18,11 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/spf13/afero" + "github.com/spf13/cast" + "github.com/spf13/hugo/hugofs" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" "html/template" "math/rand" "path" @@ -27,12 +32,6 @@ import ( "strings" "testing" "time" - - "github.com/spf13/afero" - "github.com/spf13/cast" - "github.com/spf13/hugo/hugofs" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" ) type tstNoStringer struct { @@ -1299,17 +1298,6 @@ func TestWhere(t *testing.T) { key: "B", op: "op", match: "f", expect: false, }, - { - sequence: map[string]interface{}{ - "foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}}, - "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}}, - "zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}}, - }, - key: "b", op: "in", match: slice(3, 4, 5), - expect: map[string]interface{}{ - "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}}, - }, - }, } { var results interface{} var err error