Protect against concurrent Scratch read and write

Fixes #2005
This commit is contained in:
Bjørn Erik Pedersen 2016-03-21 20:42:27 +01:00
parent f52b040ee1
commit 02effd9dc4
2 changed files with 53 additions and 0 deletions

View file

@ -17,11 +17,13 @@ import (
"github.com/spf13/hugo/helpers" "github.com/spf13/hugo/helpers"
"reflect" "reflect"
"sort" "sort"
"sync"
) )
// Scratch is a writable context used for stateful operations in Page/Node rendering. // Scratch is a writable context used for stateful operations in Page/Node rendering.
type Scratch struct { type Scratch struct {
values map[string]interface{} values map[string]interface{}
mu sync.RWMutex
} }
// For single values, Add will add (using the + operator) the addend to the existing addend (if found). // For single values, Add will add (using the + operator) the addend to the existing addend (if found).
@ -29,6 +31,9 @@ type Scratch struct {
// //
// If the first add for a key is an array or slice, then the next value(s) will be appended. // If the first add for a key is an array or slice, then the next value(s) will be appended.
func (c *Scratch) Add(key string, newAddend interface{}) (string, error) { func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
var newVal interface{} var newVal interface{}
existingAddend, found := c.values[key] existingAddend, found := c.values[key]
if found { if found {
@ -59,18 +64,27 @@ func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
// Set stores a value with the given key in the Node context. // Set stores a value with the given key in the Node context.
// This value can later be retrieved with Get. // This value can later be retrieved with Get.
func (c *Scratch) Set(key string, value interface{}) string { func (c *Scratch) Set(key string, value interface{}) string {
c.mu.Lock()
defer c.mu.Unlock()
c.values[key] = value c.values[key] = value
return "" return ""
} }
// Get returns a value previously set by Add or Set // Get returns a value previously set by Add or Set
func (c *Scratch) Get(key string) interface{} { func (c *Scratch) Get(key string) interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return c.values[key] return c.values[key]
} }
// SetInMap stores a value to a map with the given key in the Node context. // SetInMap stores a value to a map with the given key in the Node context.
// This map can later be retrieved with GetSortedMapValues. // This map can later be retrieved with GetSortedMapValues.
func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string { func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string {
c.mu.Lock()
defer c.mu.Unlock()
_, found := c.values[key] _, found := c.values[key]
if !found { if !found {
c.values[key] = make(map[string]interface{}) c.values[key] = make(map[string]interface{})
@ -82,6 +96,9 @@ func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string
// GetSortedMapValues returns a sorted map previously filled with SetInMap // GetSortedMapValues returns a sorted map previously filled with SetInMap
func (c *Scratch) GetSortedMapValues(key string) interface{} { func (c *Scratch) GetSortedMapValues(key string) interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
if c.values[key] == nil { if c.values[key] == nil {
return nil return nil
} }

View file

@ -16,6 +16,7 @@ package hugolib
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"reflect" "reflect"
"sync"
"testing" "testing"
) )
@ -80,6 +81,41 @@ func TestScratchSet(t *testing.T) {
assert.Equal(t, "val", scratch.Get("key")) assert.Equal(t, "val", scratch.Get("key"))
} }
// Issue #2005
func TestScratchInParallel(t *testing.T) {
var wg sync.WaitGroup
scratch := newScratch()
key := "counter"
scratch.Set(key, 1)
for i := 1; i <= 10; i++ {
wg.Add(1)
go func(j int) {
for k := 0; k < 10; k++ {
newVal := k + j
_, err := scratch.Add(key, newVal)
if err != nil {
t.Errorf("Got err %s", err)
}
scratch.Set(key, newVal)
val := scratch.Get(key)
if counter, ok := val.(int); ok {
if counter < 1 {
t.Errorf("Got %d", counter)
}
} else {
t.Errorf("Got %T", val)
}
}
wg.Done()
}(i)
}
wg.Wait()
}
func TestScratchGet(t *testing.T) { func TestScratchGet(t *testing.T) {
scratch := newScratch() scratch := newScratch()
nothing := scratch.Get("nothing") nothing := scratch.Get("nothing")