diff --git a/helpers/general.go b/helpers/general.go index f2ac253be..32666defa 100644 --- a/helpers/general.go +++ b/helpers/general.go @@ -17,10 +17,12 @@ import ( "bytes" "crypto/md5" "encoding/hex" + "errors" "fmt" "io" "net" "path/filepath" + "reflect" "strings" bp "github.com/spf13/hugo/bufferpool" @@ -118,3 +120,124 @@ func Md5String(f string) string { h.Write([]byte(f)) return hex.EncodeToString(h.Sum([]byte{})) } + +// DoArithmetic performs arithmetic operations (+,-,*,/) using reflection to +// determine the type of the two terms. +func DoArithmetic(a, b interface{}, op rune) (interface{}, error) { + av := reflect.ValueOf(a) + bv := reflect.ValueOf(b) + var ai, bi int64 + var af, bf float64 + var au, bu uint64 + switch av.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ai = av.Int() + switch bv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + bi = bv.Int() + case reflect.Float32, reflect.Float64: + af = float64(ai) // may overflow + ai = 0 + bf = bv.Float() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + bu = bv.Uint() + if ai >= 0 { + au = uint64(ai) + ai = 0 + } else { + bi = int64(bu) // may overflow + bu = 0 + } + default: + return nil, errors.New("Can't apply the operator to the values") + } + case reflect.Float32, reflect.Float64: + af = av.Float() + switch bv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + bf = float64(bv.Int()) // may overflow + case reflect.Float32, reflect.Float64: + bf = bv.Float() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + bf = float64(bv.Uint()) // may overflow + default: + return nil, errors.New("Can't apply the operator to the values") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + au = av.Uint() + switch bv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + bi = bv.Int() + if bi >= 0 { + bu = uint64(bi) + bi = 0 + } else { + ai = int64(au) // may overflow + au = 0 + } + case reflect.Float32, reflect.Float64: + af = float64(au) // may overflow + au = 0 + bf = bv.Float() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + bu = bv.Uint() + default: + return nil, errors.New("Can't apply the operator to the values") + } + case reflect.String: + as := av.String() + if bv.Kind() == reflect.String && op == '+' { + bs := bv.String() + return as + bs, nil + } else { + return nil, errors.New("Can't apply the operator to the values") + } + default: + return nil, errors.New("Can't apply the operator to the values") + } + + switch op { + case '+': + if ai != 0 || bi != 0 { + return ai + bi, nil + } else if af != 0 || bf != 0 { + return af + bf, nil + } else if au != 0 || bu != 0 { + return au + bu, nil + } else { + return 0, nil + } + case '-': + if ai != 0 || bi != 0 { + return ai - bi, nil + } else if af != 0 || bf != 0 { + return af - bf, nil + } else if au != 0 || bu != 0 { + return au - bu, nil + } else { + return 0, nil + } + case '*': + if ai != 0 || bi != 0 { + return ai * bi, nil + } else if af != 0 || bf != 0 { + return af * bf, nil + } else if au != 0 || bu != 0 { + return au * bu, nil + } else { + return 0, nil + } + case '/': + if bi != 0 { + return ai / bi, nil + } else if bf != 0 { + return af / bf, nil + } else if bu != 0 { + return au / bu, nil + } else { + return nil, errors.New("Can't divide the value by 0") + } + default: + return nil, errors.New("There is no such an operation") + } +} diff --git a/helpers/general_test.go b/helpers/general_test.go index fef073f05..527ba6fac 100644 --- a/helpers/general_test.go +++ b/helpers/general_test.go @@ -2,6 +2,7 @@ package helpers import ( "github.com/stretchr/testify/assert" + "reflect" "strings" "testing" ) @@ -128,3 +129,91 @@ func TestMd5StringEmpty(t *testing.T) { Md5String(in) } } + +func TestDoArithmetic(t *testing.T) { + for i, this := range []struct { + a interface{} + b interface{} + op rune + expect interface{} + }{ + {3, 2, '+', int64(5)}, + {3, 2, '-', int64(1)}, + {3, 2, '*', int64(6)}, + {3, 2, '/', int64(1)}, + {3.0, 2, '+', float64(5)}, + {3.0, 2, '-', float64(1)}, + {3.0, 2, '*', float64(6)}, + {3.0, 2, '/', float64(1.5)}, + {3, 2.0, '+', float64(5)}, + {3, 2.0, '-', float64(1)}, + {3, 2.0, '*', float64(6)}, + {3, 2.0, '/', float64(1.5)}, + {3.0, 2.0, '+', float64(5)}, + {3.0, 2.0, '-', float64(1)}, + {3.0, 2.0, '*', float64(6)}, + {3.0, 2.0, '/', float64(1.5)}, + {uint(3), uint(2), '+', uint64(5)}, + {uint(3), uint(2), '-', uint64(1)}, + {uint(3), uint(2), '*', uint64(6)}, + {uint(3), uint(2), '/', uint64(1)}, + {uint(3), 2, '+', uint64(5)}, + {uint(3), 2, '-', uint64(1)}, + {uint(3), 2, '*', uint64(6)}, + {uint(3), 2, '/', uint64(1)}, + {3, uint(2), '+', uint64(5)}, + {3, uint(2), '-', uint64(1)}, + {3, uint(2), '*', uint64(6)}, + {3, uint(2), '/', uint64(1)}, + {uint(3), -2, '+', int64(1)}, + {uint(3), -2, '-', int64(5)}, + {uint(3), -2, '*', int64(-6)}, + {uint(3), -2, '/', int64(-1)}, + {-3, uint(2), '+', int64(-1)}, + {-3, uint(2), '-', int64(-5)}, + {-3, uint(2), '*', int64(-6)}, + {-3, uint(2), '/', int64(-1)}, + {uint(3), 2.0, '+', float64(5)}, + {uint(3), 2.0, '-', float64(1)}, + {uint(3), 2.0, '*', float64(6)}, + {uint(3), 2.0, '/', float64(1.5)}, + {3.0, uint(2), '+', float64(5)}, + {3.0, uint(2), '-', float64(1)}, + {3.0, uint(2), '*', float64(6)}, + {3.0, uint(2), '/', float64(1.5)}, + {0, 0, '+', 0}, + {0, 0, '-', 0}, + {0, 0, '*', 0}, + {"foo", "bar", '+', "foobar"}, + {3, 0, '/', false}, + {3.0, 0, '/', false}, + {3, 0.0, '/', false}, + {uint(3), uint(0), '/', false}, + {3, uint(0), '/', false}, + {-3, uint(0), '/', false}, + {uint(3), 0, '/', false}, + {3.0, uint(0), '/', false}, + {uint(3), 0.0, '/', false}, + {3, "foo", '+', false}, + {3.0, "foo", '+', false}, + {uint(3), "foo", '+', false}, + {"foo", 3, '+', false}, + {"foo", "bar", '-', false}, + {3, 2, '%', false}, + } { + result, err := DoArithmetic(this.a, this.b, this.op) + if b, ok := this.expect.(bool); ok && !b { + if err == nil { + t.Errorf("[%d] doArithmetic didn't return an expected error") + } + } else { + if err != nil { + t.Errorf("[%d] failed: %s", i, err) + continue + } + if !reflect.DeepEqual(result, this.expect) { + t.Errorf("[%d] doArithmetic got %v but expected %v", i, result, this.expect) + } + } + } +} diff --git a/hugolib/node.go b/hugolib/node.go index 1916e8b03..604b5475a 100644 --- a/hugolib/node.go +++ b/hugolib/node.go @@ -33,6 +33,7 @@ type Node struct { UrlPath paginator *pager paginatorInit sync.Once + scratch *Scratch } func (n *Node) Now() time.Time { @@ -124,3 +125,11 @@ type UrlPath struct { Slug string Section string } + +// Scratch returns the writable context associated with this Node. +func (n *Node) Scratch() *Scratch { + if n.scratch == nil { + n.scratch = newScratch() + } + return n.scratch +} diff --git a/hugolib/scratch.go b/hugolib/scratch.go new file mode 100644 index 000000000..0f5c4b484 --- /dev/null +++ b/hugolib/scratch.go @@ -0,0 +1,57 @@ +// Copyright © 2013-14 Steve Francia . +// +// Licensed under the Simple Public License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://opensource.org/licenses/Simple-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hugolib + +import ( + "github.com/spf13/hugo/helpers" +) + +// Scratch is a writable context used for stateful operations in Page/Node rendering. +type Scratch struct { + values map[string]interface{} +} + +// Add will add (using the + operator) the addend to the existing addend (if found). +// Supports numeric values and strings. +func (c *Scratch) Add(key string, newAddend interface{}) (string, error) { + var newVal interface{} + existingAddend, found := c.values[key] + if found { + var err error + newVal, err = helpers.DoArithmetic(existingAddend, newAddend, '+') + if err != nil { + return "", err + } + } else { + newVal = newAddend + } + c.values[key] = newVal + return "", nil // have to return something to make it work with the Go templates +} + +// Set stores a value with the given key in the Node context. +// This value can later be retrieved with Get. +func (c *Scratch) Set(key string, value interface{}) string { + c.values[key] = value + return "" +} + +// Get returns a value previously set by Add or Set +func (c *Scratch) Get(key string) interface{} { + return c.values[key] +} + +func newScratch() *Scratch { + return &Scratch{values: make(map[string]interface{})} +} diff --git a/hugolib/scratch_test.go b/hugolib/scratch_test.go new file mode 100644 index 000000000..adff2c8a8 --- /dev/null +++ b/hugolib/scratch_test.go @@ -0,0 +1,49 @@ +package hugolib + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestScratchAdd(t *testing.T) { + scratch := newScratch() + scratch.Add("int1", 10) + scratch.Add("int1", 20) + scratch.Add("int2", 20) + + assert.Equal(t, 30, scratch.Get("int1")) + assert.Equal(t, 20, scratch.Get("int2")) + + scratch.Add("float1", float64(10.5)) + scratch.Add("float1", float64(20.1)) + + assert.Equal(t, float64(30.6), scratch.Get("float1")) + + scratch.Add("string1", "Hello ") + scratch.Add("string1", "big ") + scratch.Add("string1", "World!") + + assert.Equal(t, "Hello big World!", scratch.Get("string1")) + + scratch.Add("scratch", scratch) + _, err := scratch.Add("scratch", scratch) + + if err == nil { + t.Errorf("Expected error from invalid arithmetic") + } + +} + +func TestScratchSet(t *testing.T) { + scratch := newScratch() + scratch.Set("key", "val") + assert.Equal(t, "val", scratch.Get("key")) +} + +func TestScratchGet(t *testing.T) { + scratch := newScratch() + nothing := scratch.Get("nothing") + if nothing != nil { + t.Errorf("Should not return anything, but got %v", nothing) + } +} diff --git a/tpl/template_test.go b/tpl/template_test.go index 4477d0d26..6fe1c9328 100644 --- a/tpl/template_test.go +++ b/tpl/template_test.go @@ -101,94 +101,6 @@ func doTestCompare(t *testing.T, tp tstCompareType, funcUnderTest func(a, b inte } } -func TestDoArithmetic(t *testing.T) { - for i, this := range []struct { - a interface{} - b interface{} - op rune - expect interface{} - }{ - {3, 2, '+', int64(5)}, - {3, 2, '-', int64(1)}, - {3, 2, '*', int64(6)}, - {3, 2, '/', int64(1)}, - {3.0, 2, '+', float64(5)}, - {3.0, 2, '-', float64(1)}, - {3.0, 2, '*', float64(6)}, - {3.0, 2, '/', float64(1.5)}, - {3, 2.0, '+', float64(5)}, - {3, 2.0, '-', float64(1)}, - {3, 2.0, '*', float64(6)}, - {3, 2.0, '/', float64(1.5)}, - {3.0, 2.0, '+', float64(5)}, - {3.0, 2.0, '-', float64(1)}, - {3.0, 2.0, '*', float64(6)}, - {3.0, 2.0, '/', float64(1.5)}, - {uint(3), uint(2), '+', uint64(5)}, - {uint(3), uint(2), '-', uint64(1)}, - {uint(3), uint(2), '*', uint64(6)}, - {uint(3), uint(2), '/', uint64(1)}, - {uint(3), 2, '+', uint64(5)}, - {uint(3), 2, '-', uint64(1)}, - {uint(3), 2, '*', uint64(6)}, - {uint(3), 2, '/', uint64(1)}, - {3, uint(2), '+', uint64(5)}, - {3, uint(2), '-', uint64(1)}, - {3, uint(2), '*', uint64(6)}, - {3, uint(2), '/', uint64(1)}, - {uint(3), -2, '+', int64(1)}, - {uint(3), -2, '-', int64(5)}, - {uint(3), -2, '*', int64(-6)}, - {uint(3), -2, '/', int64(-1)}, - {-3, uint(2), '+', int64(-1)}, - {-3, uint(2), '-', int64(-5)}, - {-3, uint(2), '*', int64(-6)}, - {-3, uint(2), '/', int64(-1)}, - {uint(3), 2.0, '+', float64(5)}, - {uint(3), 2.0, '-', float64(1)}, - {uint(3), 2.0, '*', float64(6)}, - {uint(3), 2.0, '/', float64(1.5)}, - {3.0, uint(2), '+', float64(5)}, - {3.0, uint(2), '-', float64(1)}, - {3.0, uint(2), '*', float64(6)}, - {3.0, uint(2), '/', float64(1.5)}, - {0, 0, '+', 0}, - {0, 0, '-', 0}, - {0, 0, '*', 0}, - {"foo", "bar", '+', "foobar"}, - {3, 0, '/', false}, - {3.0, 0, '/', false}, - {3, 0.0, '/', false}, - {uint(3), uint(0), '/', false}, - {3, uint(0), '/', false}, - {-3, uint(0), '/', false}, - {uint(3), 0, '/', false}, - {3.0, uint(0), '/', false}, - {uint(3), 0.0, '/', false}, - {3, "foo", '+', false}, - {3.0, "foo", '+', false}, - {uint(3), "foo", '+', false}, - {"foo", 3, '+', false}, - {"foo", "bar", '-', false}, - {3, 2, '%', false}, - } { - result, err := doArithmetic(this.a, this.b, this.op) - if b, ok := this.expect.(bool); ok && !b { - if err == nil { - t.Errorf("[%d] doArithmetic didn't return an expected error", i) - } - } else { - if err != nil { - t.Errorf("[%d] failed: %s", i, err) - continue - } - if !reflect.DeepEqual(result, this.expect) { - t.Errorf("[%d] doArithmetic got %v but expected %v", i, result, this.expect) - } - } - } -} - func TestMod(t *testing.T) { for i, this := range []struct { a interface{}