diff --git a/hugolib/template.go b/hugolib/template.go index 1bf3fe110..23cb7a680 100644 --- a/hugolib/template.go +++ b/hugolib/template.go @@ -109,6 +109,71 @@ func First(limit int, seq interface{}) (interface{}, error) { return seqv.Slice(0, limit).Interface(), nil } +func Where(seq, key, match interface{}) (interface{}, error) { + seqv := reflect.ValueOf(seq) + kv := reflect.ValueOf(key) + mv := reflect.ValueOf(match) + + // this is better than my first pass; ripped from text/template/exec.go indirect(): + for ; seqv.Kind() == reflect.Ptr || seqv.Kind() == reflect.Interface; seqv = seqv.Elem() { + if seqv.IsNil() { + return nil, errors.New("can't iterate over a nil value") + } + if seqv.Kind() == reflect.Interface && seqv.NumMethod() > 0 { + break + } + } + + switch seqv.Kind() { + case reflect.Array, reflect.Slice: + r := reflect.MakeSlice(seqv.Type(), 0, 0) + for i := 0; i < seqv.Len(); i++ { + var vvv reflect.Value + vv := seqv.Index(i) + switch vv.Kind() { + case reflect.Map: + if kv.Type() == vv.Type().Key() && vv.MapIndex(kv).IsValid() { + vvv = vv.MapIndex(kv) + } + case reflect.Struct: + if kv.Kind() == reflect.String && vv.FieldByName(kv.String()).IsValid() { + vvv = vv.FieldByName(kv.String()) + } + case reflect.Ptr: + if !vv.IsNil() { + ev := vv.Elem() + switch ev.Kind() { + case reflect.Map: + if kv.Type() == ev.Type().Key() && ev.MapIndex(kv).IsValid() { + vvv = ev.MapIndex(kv) + } + case reflect.Struct: + if kv.Kind() == reflect.String && ev.FieldByName(kv.String()).IsValid() { + vvv = ev.FieldByName(kv.String()) + } + } + } + } + + if vvv.IsValid() && mv.Type() == vvv.Type() { + switch mv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if mv.Int() == vvv.Int() { + r = reflect.Append(r, vv) + } + case reflect.String: + if mv.String() == vvv.String() { + r = reflect.Append(r, vv) + } + } + } + } + return r.Interface(), nil + default: + return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String()) + } +} + func IsSet(a interface{}, key interface{}) bool { av := reflect.ValueOf(a) kv := reflect.ValueOf(key) @@ -211,6 +276,7 @@ func NewTemplate() Template { "echoParam": ReturnWhenSet, "safeHtml": SafeHtml, "first": First, + "where": Where, "highlight": Highlight, "add": func(a, b int) int { return a + b }, "sub": func(a, b int) int { return a - b }, diff --git a/hugolib/template_test.go b/hugolib/template_test.go index 029e2a49f..9a34e99de 100644 --- a/hugolib/template_test.go +++ b/hugolib/template_test.go @@ -55,3 +55,30 @@ func TestFirst(t *testing.T) { } } } + +func TestWhere(t *testing.T) { + type X struct { + A, B string + } + for i, this := range []struct { + sequence interface{} + key interface{} + match interface{} + expect interface{} + }{ + {[]map[int]string{{1: "a", 2: "m"}, {1: "c", 2: "d"}, {1: "e", 3: "m"}}, 2, "m", []map[int]string{{1: "a", 2: "m"}}}, + {[]map[string]int{{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "x": 4}}, "b", 4, []map[string]int{{"a": 3, "b": 4}}}, + {[]X{{"a", "b"}, {"c", "d"}, {"e", "f"}}, "B", "f", []X{{"e", "f"}}}, + {[]*map[int]string{&map[int]string{1: "a", 2: "m"}, &map[int]string{1: "c", 2: "d"}, &map[int]string{1: "e", 3: "m"}}, 2, "m", []*map[int]string{&map[int]string{1: "a", 2: "m"}}}, + {[]*X{&X{"a", "b"}, &X{"c", "d"}, &X{"e", "f"}}, "B", "f", []*X{&X{"e", "f"}}}, + } { + results, err := Where(this.sequence, this.key, this.match) + if err != nil { + t.Errorf("[%d] failed: %s", i, err) + continue + } + if !reflect.DeepEqual(results, this.expect) { + t.Errorf("[%d] Where clause matching %v with %v, got %v but expected %v", i, this.key, this.match, results, this.expect) + } + } +}