From 302a6ac701bc2599d4ca98c7d263d364505a3980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antti=20J=C3=A4rvinen?= Date: Sun, 6 Dec 2015 12:28:03 +0200 Subject: [PATCH] Add Random function to template functions Adds Random function to pick N random items from sequence. --- tpl/template_funcs.go | 49 ++++++++++++++++++++++++++++++++++++++ tpl/template_funcs_test.go | 37 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/tpl/template_funcs.go b/tpl/template_funcs.go index e7166ea1c..e933c0bb3 100644 --- a/tpl/template_funcs.go +++ b/tpl/template_funcs.go @@ -20,6 +20,7 @@ import ( "fmt" "html" "html/template" + "math/rand" "os" "reflect" "sort" @@ -495,6 +496,53 @@ func After(index interface{}, seq interface{}) (interface{}, error) { return seqv.Slice(indexv, seqv.Len()).Interface(), nil } +// Random is exposed to templates, to iterate over N random items in a +// rangeable list. +func Random(count interface{}, seq interface{}) (interface{}, error) { + + if count == nil || seq == nil { + return nil, errors.New("both count and seq must be provided") + } + + countv, err := cast.ToIntE(count) + + if err != nil { + return nil, err + } + + if countv < 1 { + return nil, errors.New("can't return negative/empty count of items from sequence") + } + + seqv := reflect.ValueOf(seq) + seqv, isNil := indirect(seqv) + if isNil { + return nil, errors.New("can't iterate over a nil value") + } + + switch seqv.Kind() { + case reflect.Array, reflect.Slice, reflect.String: + // okay + default: + return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String()) + } + + if countv >= seqv.Len() { + countv = seqv.Len() + } + + suffled := reflect.MakeSlice(reflect.TypeOf(seq), seqv.Len(), seqv.Len()) + + rand.Seed(time.Now().UTC().UnixNano()) + randomIndices := rand.Perm(seqv.Len()) + + for index, value := range randomIndices { + suffled.Index(value).Set(seqv.Index(index)) + } + + return suffled.Slice(0, countv).Interface(), nil +} + var ( zero reflect.Value errorType = reflect.TypeOf((*error)(nil)).Elem() @@ -1453,6 +1501,7 @@ func init() { "first": First, "last": Last, "after": After, + "random": Random, "where": Where, "delimit": Delimit, "sort": Sort, diff --git a/tpl/template_funcs_test.go b/tpl/template_funcs_test.go index 761459b32..91d3cbd1c 100644 --- a/tpl/template_funcs_test.go +++ b/tpl/template_funcs_test.go @@ -341,6 +341,43 @@ func TestAfter(t *testing.T) { } } +func TestRandom(t *testing.T) { + for i, this := range []struct { + count interface{} + sequence interface{} + expect interface{} + }{ + {int(2), []string{"a", "b", "c", "d"}, 2}, + {int64(2), []int{100, 200, 300}, 2}, + {"1", []int{100, 200, 300}, 1}, + {100, []int{100, 200}, 2}, + {int32(3), []string{"a", "b"}, 2}, + {int64(-1), []int{100, 200, 300}, false}, + {"noint", []int{100, 200, 300}, false}, + {1, nil, false}, + {nil, []int{100}, false}, + {1, t, false}, + } { + results, err := Random(this.count, this.sequence) + if b, ok := this.expect.(bool); ok && !b { + if err == nil { + t.Errorf("[%d] First didn't return an expected error", i) + } + } else { + resultsv := reflect.ValueOf(results) + if err != nil { + t.Errorf("[%d] failed: %s", i, err) + continue + } + + if resultsv.Len() != this.expect { + t.Errorf("[%d] requested %d random items, got %v but expected %v", + i, this.count, resultsv.Len(), this.expect) + } + } + } +} + func TestDictionary(t *testing.T) { for i, this := range []struct { v1 []interface{}