diff --git a/tpl/collections/apply_test.go b/tpl/collections/apply_test.go index 98cb78b51..1afb66808 100644 --- a/tpl/collections/apply_test.go +++ b/tpl/collections/apply_test.go @@ -14,6 +14,7 @@ package collections import ( + "context" "fmt" "io" "reflect" @@ -51,6 +52,10 @@ func (templateFinder) Execute(t tpl.Template, wr io.Writer, data interface{}) er return nil } +func (templateFinder) ExecuteWithContext(ctx context.Context, t tpl.Template, wr io.Writer, data interface{}) error { + return nil +} + func (templateFinder) GetFunc(name string) (reflect.Value, bool) { if name == "dobedobedo" { return reflect.Value{}, false diff --git a/tpl/internal/go_templates/texttemplate/hugo_template.go b/tpl/internal/go_templates/texttemplate/hugo_template.go index eed546e61..b59a98219 100644 --- a/tpl/internal/go_templates/texttemplate/hugo_template.go +++ b/tpl/internal/go_templates/texttemplate/hugo_template.go @@ -1,4 +1,4 @@ -// Copyright 2019 The Hugo Authors. All rights reserved. +// Copyright 2022 The Hugo Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ package template import ( + "context" "io" "reflect" @@ -39,14 +40,15 @@ type Preparer interface { // ExecHelper allows some custom eval hooks. type ExecHelper interface { - GetFunc(tmpl Preparer, name string) (reflect.Value, bool) - GetMethod(tmpl Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) - GetMapValue(tmpl Preparer, receiver, key reflect.Value) (reflect.Value, bool) + Init(ctx context.Context, tmpl Preparer) + GetFunc(ctx context.Context, tmpl Preparer, name string) (reflect.Value, reflect.Value, bool) + GetMethod(ctx context.Context, tmpl Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) + GetMapValue(ctx context.Context, tmpl Preparer, receiver, key reflect.Value) (reflect.Value, bool) } // Executer executes a given template. type Executer interface { - Execute(p Preparer, wr io.Writer, data interface{}) error + ExecuteWithContext(ctx context.Context, p Preparer, wr io.Writer, data interface{}) error } type executer struct { @@ -57,6 +59,48 @@ func NewExecuter(helper ExecHelper) Executer { return &executer{helper: helper} } +type ( + dataContextKeyType string + hasLockContextKeyType string +) + +const ( + // The data object passed to Execute or ExecuteWithContext gets stored with this key if not already set. + DataContextKey = dataContextKeyType("data") + // Used in partialCached to signal to nested templates that a lock is already taken. + HasLockContextKey = hasLockContextKeyType("hasLock") +) + +// Note: The context is currently not fully implemeted in Hugo. This is a work in progress. +func (t *executer) ExecuteWithContext(ctx context.Context, p Preparer, wr io.Writer, data interface{}) error { + tmpl, err := p.Prepare() + if err != nil { + return err + } + + if v := ctx.Value(DataContextKey); v == nil { + ctx = context.WithValue(ctx, DataContextKey, data) + } + + value, ok := data.(reflect.Value) + if !ok { + value = reflect.ValueOf(data) + } + + state := &state{ + ctx: ctx, + helper: t.helper, + prep: p, + tmpl: tmpl, + wr: wr, + vars: []variable{{"$", value}}, + } + + t.helper.Init(ctx, p) + + return tmpl.executeWithState(state, value) +} + func (t *executer) Execute(p Preparer, wr io.Writer, data interface{}) error { tmpl, err := p.Prepare() if err != nil { @@ -77,7 +121,6 @@ func (t *executer) Execute(p Preparer, wr io.Writer, data interface{}) error { } return tmpl.executeWithState(state, value) - } // Prepare returns a template ready for execution. @@ -101,8 +144,9 @@ func (t *Template) executeWithState(state *state, value reflect.Value) (err erro // can execute in parallel. type state struct { tmpl *Template - prep Preparer // Added for Hugo. - helper ExecHelper // Added for Hugo. + ctx context.Context // Added for Hugo. The orignal data context. + prep Preparer // Added for Hugo. + helper ExecHelper // Added for Hugo. wr io.Writer node parse.Node // current node, for errors vars []variable // push-down stack of variable values. @@ -114,10 +158,11 @@ func (s *state) evalFunction(dot reflect.Value, node *parse.IdentifierNode, cmd name := node.Ident var function reflect.Value + // Added for Hugo. + var first reflect.Value var ok bool if s.helper != nil { - // Added for Hugo. - function, ok = s.helper.GetFunc(s.prep, name) + function, first, ok = s.helper.GetFunc(s.ctx, s.prep, name) } if !ok { @@ -127,6 +172,9 @@ func (s *state) evalFunction(dot reflect.Value, node *parse.IdentifierNode, cmd if !ok { s.errorf("%q is not a defined function", name) } + if first != zero { + return s.evalCall(dot, function, cmd, name, args, final, first) + } return s.evalCall(dot, function, cmd, name, args, final) } @@ -159,7 +207,7 @@ func (s *state) evalField(dot reflect.Value, fieldName string, node parse.Node, var first reflect.Value var method reflect.Value if s.helper != nil { - method, first = s.helper.GetMethod(s.prep, ptr, fieldName) + method, first = s.helper.GetMethod(s.ctx, s.prep, ptr, fieldName) } else { method = ptr.MethodByName(fieldName) } @@ -198,7 +246,7 @@ func (s *state) evalField(dot reflect.Value, fieldName string, node parse.Node, var result reflect.Value if s.helper != nil { // Added for Hugo. - result, _ = s.helper.GetMapValue(s.prep, receiver, nameVal) + result, _ = s.helper.GetMapValue(s.ctx, s.prep, receiver, nameVal) } else { result = receiver.MapIndex(nameVal) } diff --git a/tpl/internal/go_templates/texttemplate/hugo_template_test.go b/tpl/internal/go_templates/texttemplate/hugo_template_test.go index 98a2575eb..150802bf4 100644 --- a/tpl/internal/go_templates/texttemplate/hugo_template_test.go +++ b/tpl/internal/go_templates/texttemplate/hugo_template_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 The Hugo Authors. All rights reserved. +// Copyright 2022 The Hugo Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ package template import ( "bytes" + "context" "reflect" "strings" "testing" @@ -35,24 +36,26 @@ func (t TestStruct) Hello2(arg1, arg2 string) string { return arg1 + " " + arg2 } -type execHelper struct { +type execHelper struct{} + +func (e *execHelper) Init(ctx context.Context, tmpl Preparer) { } -func (e *execHelper) GetFunc(tmpl Preparer, name string) (reflect.Value, bool) { +func (e *execHelper) GetFunc(ctx context.Context, tmpl Preparer, name string) (reflect.Value, reflect.Value, bool) { if name == "print" { - return zero, false + return zero, zero, false } return reflect.ValueOf(func(s string) string { return "hello " + s - }), true + }), zero, true } -func (e *execHelper) GetMapValue(tmpl Preparer, m, key reflect.Value) (reflect.Value, bool) { +func (e *execHelper) GetMapValue(ctx context.Context, tmpl Preparer, m, key reflect.Value) (reflect.Value, bool) { key = reflect.ValueOf(strings.ToLower(key.String())) return m.MapIndex(key), true } -func (e *execHelper) GetMethod(tmpl Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) { +func (e *execHelper) GetMethod(ctx context.Context, tmpl Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) { if name != "Hello1" { return zero, zero } @@ -78,12 +81,11 @@ Method: {{ .Hello1 "v1" }} var b bytes.Buffer data := TestStruct{S: "sv", M: map[string]string{"a": "av"}} - c.Assert(ex.Execute(templ, &b, data), qt.IsNil) + c.Assert(ex.ExecuteWithContext(context.Background(), templ, &b, data), qt.IsNil) got := b.String() c.Assert(got, qt.Contains, "foo") c.Assert(got, qt.Contains, "hello hugo") c.Assert(got, qt.Contains, "Map: av") c.Assert(got, qt.Contains, "Method: v2 v1") - } diff --git a/tpl/partials/integration_test.go b/tpl/partials/integration_test.go index 5b6c18598..446e47118 100644 --- a/tpl/partials/integration_test.go +++ b/tpl/partials/integration_test.go @@ -75,6 +75,34 @@ partialCached: foo `) } +// Issue 9519 +func TestIncludeCachedRecursion(t *testing.T) { + t.Parallel() + + files := ` +-- config.toml -- +baseURL = 'http://example.com/' +-- layouts/index.html -- +{{ partials.IncludeCached "p1.html" . }} +-- layouts/partials/p1.html -- +{{ partials.IncludeCached "p2.html" . }} +-- layouts/partials/p2.html -- +P2 + + ` + + b := hugolib.NewIntegrationTestBuilder( + hugolib.IntegrationTestConfig{ + T: t, + TxtarString: files, + }, + ).Build() + + b.AssertFileContent("public/index.html", ` +P2 +`) +} + func TestIncludeCacheHints(t *testing.T) { t.Parallel() diff --git a/tpl/partials/partials.go b/tpl/partials/partials.go index 787b49ed3..500f5d1a3 100644 --- a/tpl/partials/partials.go +++ b/tpl/partials/partials.go @@ -16,6 +16,7 @@ package partials import ( + "context" "errors" "fmt" "html/template" @@ -100,8 +101,9 @@ func (c *contextWrapper) Set(in interface{}) string { // If the partial contains a return statement, that value will be returned. // Else, the rendered output will be returned: // A string if the partial is a text/template, or template.HTML when html/template. -func (ns *Namespace) Include(name string, contextList ...interface{}) (interface{}, error) { - name, result, err := ns.include(name, contextList...) +// Note that ctx is provided by Hugo, not the end user. +func (ns *Namespace) Include(ctx context.Context, name string, contextList ...interface{}) (interface{}, error) { + name, result, err := ns.include(ctx, name, contextList...) if err != nil { return result, err } @@ -115,10 +117,10 @@ func (ns *Namespace) Include(name string, contextList ...interface{}) (interface // include is a helper function that lookups and executes the named partial. // Returns the final template name and the rendered output. -func (ns *Namespace) include(name string, contextList ...interface{}) (string, interface{}, error) { - var context interface{} - if len(contextList) > 0 { - context = contextList[0] +func (ns *Namespace) include(ctx context.Context, name string, dataList ...interface{}) (string, interface{}, error) { + var data interface{} + if len(dataList) > 0 { + data = dataList[0] } var n string @@ -149,8 +151,8 @@ func (ns *Namespace) include(name string, contextList ...interface{}) (string, i // Wrap the context sent to the template to capture the return value. // Note that the template is rewritten to make sure that the dot (".") // and the $ variable points to Arg. - context = &contextWrapper{ - Arg: context, + data = &contextWrapper{ + Arg: data, } // We don't care about any template output. @@ -161,13 +163,13 @@ func (ns *Namespace) include(name string, contextList ...interface{}) (string, i w = b } - if err := ns.deps.Tmpl().Execute(templ, w, context); err != nil { - return "", "", err + if err := ns.deps.Tmpl().ExecuteWithContext(ctx, templ, w, data); err != nil { + return "", nil, err } var result interface{} - if ctx, ok := context.(*contextWrapper); ok { + if ctx, ok := data.(*contextWrapper); ok { result = ctx.Result } else if _, ok := templ.(*texttemplate.Template); ok { result = w.(fmt.Stringer).String() @@ -179,17 +181,18 @@ func (ns *Namespace) include(name string, contextList ...interface{}) (string, i } // IncludeCached executes and caches partial templates. The cache is created with name+variants as the key. -func (ns *Namespace) IncludeCached(name string, context interface{}, variants ...interface{}) (interface{}, error) { +// Note that ctx is provided by Hugo, not the end user. +func (ns *Namespace) IncludeCached(ctx context.Context, name string, context interface{}, variants ...interface{}) (interface{}, error) { key, err := createKey(name, variants...) if err != nil { return nil, err } - result, err := ns.getOrCreate(key, context) + result, err := ns.getOrCreate(ctx, key, context) if err == errUnHashable { // Try one more key.variant = helpers.HashString(key.variant) - result, err = ns.getOrCreate(key, context) + result, err = ns.getOrCreate(ctx, key, context) } return result, err @@ -218,7 +221,7 @@ func createKey(name string, variants ...interface{}) (partialCacheKey, error) { var errUnHashable = errors.New("unhashable") -func (ns *Namespace) getOrCreate(key partialCacheKey, context interface{}) (result interface{}, err error) { +func (ns *Namespace) getOrCreate(ctx context.Context, key partialCacheKey, context interface{}) (result interface{}, err error) { start := time.Now() defer func() { if r := recover(); r != nil { @@ -230,9 +233,16 @@ func (ns *Namespace) getOrCreate(key partialCacheKey, context interface{}) (resu } }() - ns.cachedPartials.RLock() + // We may already have a write lock. + hasLock := tpl.GetHasLockFromContext(ctx) + + if !hasLock { + ns.cachedPartials.RLock() + } p, ok := ns.cachedPartials.p[key] - ns.cachedPartials.RUnlock() + if !hasLock { + ns.cachedPartials.RUnlock() + } if ok { if ns.deps.Metrics != nil { @@ -246,11 +256,14 @@ func (ns *Namespace) getOrCreate(key partialCacheKey, context interface{}) (resu return p, nil } - ns.cachedPartials.Lock() - defer ns.cachedPartials.Unlock() + if !hasLock { + ns.cachedPartials.Lock() + defer ns.cachedPartials.Unlock() + ctx = tpl.SetHasLockInContext(ctx, true) + } var name string - name, p, err = ns.include(key.name, context) + name, p, err = ns.include(ctx, key.name, context) if err != nil { return nil, err } diff --git a/tpl/template.go b/tpl/template.go index c5a6a44c0..1d8c98ded 100644 --- a/tpl/template.go +++ b/tpl/template.go @@ -14,6 +14,7 @@ package tpl import ( + "context" "io" "reflect" "regexp" @@ -53,6 +54,7 @@ type UnusedTemplatesProvider interface { type TemplateHandler interface { TemplateFinder Execute(t Template, wr io.Writer, data interface{}) error + ExecuteWithContext(ctx context.Context, t Template, wr io.Writer, data interface{}) error LookupLayout(d output.LayoutDescriptor, f output.Format) (Template, bool, error) HasTemplate(name string) bool } @@ -144,3 +146,20 @@ func extractBaseOf(err string) string { type TemplateFuncGetter interface { GetFunc(name string) (reflect.Value, bool) } + +// GetDataFromContext returns the template data context (usually .Page) from ctx if set. +// NOte: This is not fully implemented yet. +func GetDataFromContext(ctx context.Context) interface{} { + return ctx.Value(texttemplate.DataContextKey) +} + +func GetHasLockFromContext(ctx context.Context) bool { + if v := ctx.Value(texttemplate.HasLockContextKey); v != nil { + return v.(bool) + } + return false +} + +func SetHasLockInContext(ctx context.Context, hasLock bool) context.Context { + return context.WithValue(ctx, texttemplate.HasLockContextKey, hasLock) +} diff --git a/tpl/tplimpl/template.go b/tpl/tplimpl/template.go index 80e350f11..44b486404 100644 --- a/tpl/tplimpl/template.go +++ b/tpl/tplimpl/template.go @@ -15,6 +15,7 @@ package tplimpl import ( "bytes" + "context" "embed" "io" "io/fs" @@ -225,6 +226,10 @@ func (t templateExec) Clone(d *deps.Deps) *templateExec { } func (t *templateExec) Execute(templ tpl.Template, wr io.Writer, data interface{}) error { + return t.ExecuteWithContext(context.Background(), templ, wr, data) +} + +func (t *templateExec) ExecuteWithContext(ctx context.Context, templ tpl.Template, wr io.Writer, data interface{}) error { if rlocker, ok := templ.(types.RLocker); ok { rlocker.RLock() defer rlocker.RUnlock() @@ -249,11 +254,10 @@ func (t *templateExec) Execute(templ tpl.Template, wr io.Writer, data interface{ } } - execErr := t.executor.Execute(templ, wr, data) + execErr := t.executor.ExecuteWithContext(ctx, templ, wr, data) if execErr != nil { execErr = t.addFileContext(templ, execErr) } - return execErr } diff --git a/tpl/tplimpl/template_funcs.go b/tpl/tplimpl/template_funcs.go index 4b3abaada..831b846d0 100644 --- a/tpl/tplimpl/template_funcs.go +++ b/tpl/tplimpl/template_funcs.go @@ -16,6 +16,7 @@ package tplimpl import ( + "context" "reflect" "strings" @@ -61,8 +62,9 @@ import ( ) var ( - _ texttemplate.ExecHelper = (*templateExecHelper)(nil) - zero reflect.Value + _ texttemplate.ExecHelper = (*templateExecHelper)(nil) + zero reflect.Value + contextInterface = reflect.TypeOf((*context.Context)(nil)).Elem() ) type templateExecHelper struct { @@ -70,14 +72,27 @@ type templateExecHelper struct { funcs map[string]reflect.Value } -func (t *templateExecHelper) GetFunc(tmpl texttemplate.Preparer, name string) (reflect.Value, bool) { +func (t *templateExecHelper) GetFunc(ctx context.Context, tmpl texttemplate.Preparer, name string) (fn reflect.Value, firstArg reflect.Value, found bool) { if fn, found := t.funcs[name]; found { - return fn, true + if fn.Type().NumIn() > 0 { + first := fn.Type().In(0) + if first.Implements(contextInterface) { + // TODO(bep) check if we can void this conversion every time -- and if that matters. + // The first argument may be context.Context. This is never provided by the end user, but it's used to pass down + // contextual information, e.g. the top level data context (e.g. Page). + return fn, reflect.ValueOf(ctx), true + } + } + + return fn, zero, true } - return zero, false + return zero, zero, false } -func (t *templateExecHelper) GetMapValue(tmpl texttemplate.Preparer, receiver, key reflect.Value) (reflect.Value, bool) { +func (t *templateExecHelper) Init(ctx context.Context, tmpl texttemplate.Preparer) { +} + +func (t *templateExecHelper) GetMapValue(ctx context.Context, tmpl texttemplate.Preparer, receiver, key reflect.Value) (reflect.Value, bool) { if params, ok := receiver.Interface().(maps.Params); ok { // Case insensitive. keystr := strings.ToLower(key.String()) @@ -93,10 +108,11 @@ func (t *templateExecHelper) GetMapValue(tmpl texttemplate.Preparer, receiver, k return v, v.IsValid() } -func (t *templateExecHelper) GetMethod(tmpl texttemplate.Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) { +func (t *templateExecHelper) GetMethod(ctx context.Context, tmpl texttemplate.Preparer, receiver reflect.Value, name string) (method reflect.Value, firstArg reflect.Value) { if t.running { // This is a hot path and receiver.MethodByName really shows up in the benchmarks, // so we maintain a list of method names with that signature. + // TODO(bep) I have a branch that makes this construct superflous. switch name { case "GetPage", "Render": if info, ok := tmpl.(tpl.Info); ok { @@ -107,7 +123,21 @@ func (t *templateExecHelper) GetMethod(tmpl texttemplate.Preparer, receiver refl } } - return receiver.MethodByName(name), zero + fn := receiver.MethodByName(name) + if !fn.IsValid() { + return zero, zero + } + + if fn.Type().NumIn() > 0 { + first := fn.Type().In(0) + if first.Implements(contextInterface) { + // The first argument may be context.Context. This is never provided by the end user, but it's used to pass down + // contextual information, e.g. the top level data context (e.g. Page). + return fn, reflect.ValueOf(ctx) + } + } + + return fn, zero } func newTemplateExecuter(d *deps.Deps) (texttemplate.Executer, map[string]reflect.Value) { diff --git a/tpl/tplimpl/template_funcs_test.go b/tpl/tplimpl/template_funcs_test.go index 711d1350d..6d2587bf7 100644 --- a/tpl/tplimpl/template_funcs_test.go +++ b/tpl/tplimpl/template_funcs_test.go @@ -15,6 +15,7 @@ package tplimpl import ( "bytes" + "context" "fmt" "path/filepath" "reflect" @@ -145,8 +146,7 @@ func TestPartialCached(t *testing.T) { partial := `Now: {{ now.UnixNano }}` name := "testing" - var data struct { - } + var data struct{} v := newTestConfig() @@ -168,19 +168,19 @@ func TestPartialCached(t *testing.T) { ns := partials.New(de) - res1, err := ns.IncludeCached(name, &data) + res1, err := ns.IncludeCached(context.Background(), name, &data) c.Assert(err, qt.IsNil) for j := 0; j < 10; j++ { time.Sleep(2 * time.Nanosecond) - res2, err := ns.IncludeCached(name, &data) + res2, err := ns.IncludeCached(context.Background(), name, &data) c.Assert(err, qt.IsNil) if !reflect.DeepEqual(res1, res2) { t.Fatalf("cache mismatch") } - res3, err := ns.IncludeCached(name, &data, fmt.Sprintf("variant%d", j)) + res3, err := ns.IncludeCached(context.Background(), name, &data, fmt.Sprintf("variant%d", j)) c.Assert(err, qt.IsNil) if reflect.DeepEqual(res1, res3) { @@ -191,14 +191,14 @@ func TestPartialCached(t *testing.T) { func BenchmarkPartial(b *testing.B) { doBenchmarkPartial(b, func(ns *partials.Namespace) error { - _, err := ns.Include("bench1") + _, err := ns.Include(context.Background(), "bench1") return err }) } func BenchmarkPartialCached(b *testing.B) { doBenchmarkPartial(b, func(ns *partials.Namespace) error { - _, err := ns.IncludeCached("bench1", nil) + _, err := ns.IncludeCached(context.Background(), "bench1", nil) return err }) }