diff --git a/tpl/template.go b/tpl/template.go index e8cdd4050..a051eba0b 100644 --- a/tpl/template.go +++ b/tpl/template.go @@ -785,20 +785,36 @@ func IsSet(a interface{}, key interface{}) bool { return false } -func ReturnWhenSet(a interface{}, index int) interface{} { - av := reflect.ValueOf(a) +func ReturnWhenSet(a, k interface{}) interface{} { + av, isNil := indirect(reflect.ValueOf(a)) + if isNil { + return "" + } + var avv reflect.Value switch av.Kind() { case reflect.Array, reflect.Slice: - if av.Len() > index { + index, ok := k.(int) + if ok && av.Len() > index { + avv = av.Index(index) + } + case reflect.Map: + kv := reflect.ValueOf(k) + if kv.Type().AssignableTo(av.Type().Key()) { + avv = av.MapIndex(kv) + } + } - avv := av.Index(index) - switch avv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return avv.Int() - case reflect.String: - return avv.String() - } + if avv.IsValid() { + switch avv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return avv.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return avv.Uint() + case reflect.Float32, reflect.Float64: + return avv.Float() + case reflect.String: + return avv.String() } } diff --git a/tpl/template_test.go b/tpl/template_test.go index 578d1d884..98cf2d061 100644 --- a/tpl/template_test.go +++ b/tpl/template_test.go @@ -791,6 +791,31 @@ func TestSort(t *testing.T) { } } +func TestReturnWhenSet(t *testing.T) { + for i, this := range []struct { + data interface{} + key interface{} + expect interface{} + }{ + {[]int{1, 2, 3}, 1, int64(2)}, + {[]uint{1, 2, 3}, 1, uint64(2)}, + {[]float64{1.1, 2.2, 3.3}, 1, float64(2.2)}, + {[]string{"foo", "bar", "baz"}, 1, "bar"}, + {[]TstX{TstX{A: "a", B: "b"}, TstX{A: "c", B: "d"}, TstX{A: "e", B: "f"}}, 1, ""}, + {map[string]int{"foo": 1, "bar": 2, "baz": 3}, "bar", int64(2)}, + {map[string]uint{"foo": 1, "bar": 2, "baz": 3}, "bar", uint64(2)}, + {map[string]float64{"foo": 1.1, "bar": 2.2, "baz": 3.3}, "bar", float64(2.2)}, + {map[string]string{"foo": "FOO", "bar": "BAR", "baz": "BAZ"}, "bar", "BAR"}, + {map[string]TstX{"foo": TstX{A: "a", B: "b"}, "bar": TstX{A: "c", B: "d"}, "baz": TstX{A: "e", B: "f"}}, "bar", ""}, + {(*[]string)(nil), "bar", ""}, + } { + result := ReturnWhenSet(this.data, this.key) + if !reflect.DeepEqual(result, this.expect) { + t.Errorf("[%d] ReturnWhenSet got %v (type %v) but expected %v (type %v)", i, result, reflect.TypeOf(result), this.expect, reflect.TypeOf(this.expect)) + } + } +} + func TestMarkdownify(t *testing.T) { result := Markdownify("Hello **World!**")