From 0141a02160ee7b91642d3c3475959d6742ccd693 Mon Sep 17 00:00:00 2001 From: Cameron Moore Date: Tue, 12 Apr 2016 20:31:14 -0500 Subject: [PATCH] tpl: Extend where to iterate over maps Refactor and extend where to iterate over maps. --- tpl/template_funcs.go | 137 ++++++++++++++++++++++++++----------- tpl/template_funcs_test.go | 22 ++++-- 2 files changed, 115 insertions(+), 44 deletions(-) diff --git a/tpl/template_funcs.go b/tpl/template_funcs.go index 9131cc1ec..f6913510e 100644 --- a/tpl/template_funcs.go +++ b/tpl/template_funcs.go @@ -37,13 +37,11 @@ import ( "time" "unicode/utf8" - "github.com/spf13/afero" - "github.com/spf13/hugo/hugofs" - "github.com/bep/inflect" - + "github.com/spf13/afero" "github.com/spf13/cast" "github.com/spf13/hugo/helpers" + "github.com/spf13/hugo/hugofs" jww "github.com/spf13/jwalterweatherman" ) @@ -771,64 +769,125 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) { return false, nil } -// where returns a filtered subset of a given data type. -func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) { - seqv := reflect.ValueOf(seq) - kv := reflect.ValueOf(key) - - var mv reflect.Value - var op string +// parseWhereArgs parses the end arguments to the where function. Return a +// match value and an operator, if one is defined. +func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) { switch len(args) { case 1: mv = reflect.ValueOf(args[0]) case 2: var ok bool if op, ok = args[0].(string); !ok { - return nil, errors.New("operator argument must be string type") + err = errors.New("operator argument must be string type") + return } op = strings.TrimSpace(strings.ToLower(op)) mv = reflect.ValueOf(args[1]) default: - return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments") + err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments") } + return +} - seqv, isNil := indirect(seqv) +// checkWhereArray handles the where-matching logic when the seqv value is an +// Array or Slice. +func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) { + rv := reflect.MakeSlice(seqv.Type(), 0, 0) + for i := 0; i < seqv.Len(); i++ { + var vvv reflect.Value + rvv := seqv.Index(i) + if kv.Kind() == reflect.String { + vvv = rvv + for _, elemName := range path { + var err error + vvv, err = evaluateSubElem(vvv, elemName) + if err != nil { + return nil, err + } + } + } else { + vv, _ := indirect(rvv) + if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) { + vvv = vv.MapIndex(kv) + } + } + + if ok, err := checkCondition(vvv, mv, op); ok { + rv = reflect.Append(rv, rvv) + } else if err != nil { + return nil, err + } + } + return rv.Interface(), nil +} + +// checkWhereMap handles the where-matching logic when the seqv value is a Map. +func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) { + rv := reflect.MakeMap(seqv.Type()) + keys := seqv.MapKeys() + for _, k := range keys { + elemv := seqv.MapIndex(k) + switch elemv.Kind() { + case reflect.Array, reflect.Slice: + r, err := checkWhereArray(elemv, kv, mv, path, op) + if err != nil { + return nil, err + } + + switch rr := reflect.ValueOf(r); rr.Kind() { + case reflect.Slice: + if rr.Len() > 0 { + rv.SetMapIndex(k, elemv) + } + } + case reflect.Interface: + elemvv, isNil := indirect(elemv) + if isNil { + continue + } + + switch elemvv.Kind() { + case reflect.Array, reflect.Slice: + r, err := checkWhereArray(elemvv, kv, mv, path, op) + if err != nil { + return nil, err + } + + switch rr := reflect.ValueOf(r); rr.Kind() { + case reflect.Slice: + if rr.Len() > 0 { + rv.SetMapIndex(k, elemv) + } + } + } + } + } + return rv, nil +} + +// where returns a filtered subset of a given data type. +func where(seq, key interface{}, args ...interface{}) (interface{}, error) { + seqv, isNil := indirect(reflect.ValueOf(seq)) if isNil { return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String()) } + mv, op, err := parseWhereArgs(args...) + if err != nil { + return nil, err + } + var path []string + kv := reflect.ValueOf(key) if kv.Kind() == reflect.String { path = strings.Split(strings.Trim(kv.String(), "."), ".") } switch seqv.Kind() { case reflect.Array, reflect.Slice: - rv := reflect.MakeSlice(seqv.Type(), 0, 0) - for i := 0; i < seqv.Len(); i++ { - var vvv reflect.Value - rvv := seqv.Index(i) - if kv.Kind() == reflect.String { - vvv = rvv - for _, elemName := range path { - vvv, err = evaluateSubElem(vvv, elemName) - if err != nil { - return nil, err - } - } - } else { - vv, _ := indirect(rvv) - if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) { - vvv = vv.MapIndex(kv) - } - } - if ok, err := checkCondition(vvv, mv, op); ok { - rv = reflect.Append(rv, rvv) - } else if err != nil { - return nil, err - } - } - return rv.Interface(), nil + return checkWhereArray(seqv, kv, mv, path, op) + case reflect.Map: + return checkWhereMap(seqv, kv, mv, path, op) default: return nil, fmt.Errorf("can't iterate over %v", seq) } diff --git a/tpl/template_funcs_test.go b/tpl/template_funcs_test.go index 8d604e817..5dbcf6cf6 100644 --- a/tpl/template_funcs_test.go +++ b/tpl/template_funcs_test.go @@ -18,11 +18,6 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/spf13/afero" - "github.com/spf13/cast" - "github.com/spf13/hugo/hugofs" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" "html/template" "math/rand" "path" @@ -32,6 +27,12 @@ import ( "strings" "testing" "time" + + "github.com/spf13/afero" + "github.com/spf13/cast" + "github.com/spf13/hugo/hugofs" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" ) type tstNoStringer struct { @@ -1298,6 +1299,17 @@ func TestWhere(t *testing.T) { key: "B", op: "op", match: "f", expect: false, }, + { + sequence: map[string]interface{}{ + "foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}}, + "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}}, + "zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}}, + }, + key: "b", op: "in", match: slice(3, 4, 5), + expect: map[string]interface{}{ + "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}}, + }, + }, } { var results interface{} var err error