diff --git a/tpl/data/data.go b/tpl/data/data.go index d383447ac..4cb8b5e78 100644 --- a/tpl/data/data.go +++ b/tpl/data/data.go @@ -58,8 +58,8 @@ type Namespace struct { // The data separator can be a comma, semi-colon, pipe, etc, but only one character. // If you provide multiple parts for the URL they will be joined together to the final URL. // GetCSV returns nil or a slice slice to use in a short code. -func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string, err error) { - url := joinURL(urlParts) +func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err error) { + url := joinURL(args) cache := ns.cacheGetCSV unmarshal := func(b []byte) (bool, error) { @@ -85,6 +85,15 @@ func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string, req.Header.Add("Accept", "text/csv") req.Header.Add("Accept", "text/plain") + // Add custom user headers to the get request + finalArg := args[len(args)-1] + + if userHeaders, ok := finalArg.(map[string]interface{}); ok { + for key, val := range userHeaders { + req.Header.Add(key, val.(string)) + } + } + err = ns.getResource(cache, unmarshal, req) if err != nil { ns.deps.Log.(loggers.IgnorableLogger).Errorsf(constants.ErrRemoteGetCSV, "Failed to get CSV resource %q: %s", url, err) @@ -97,9 +106,9 @@ func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string, // GetJSON expects one or n-parts of a URL to a resource which can either be a local or a remote one. // If you provide multiple parts they will be joined together to the final URL. // GetJSON returns nil or parsed JSON to use in a short code. -func (ns *Namespace) GetJSON(urlParts ...interface{}) (interface{}, error) { +func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) { var v interface{} - url := joinURL(urlParts) + url := joinURL(args) cache := ns.cacheGetJSON req, err := http.NewRequest("GET", url, nil) @@ -118,6 +127,15 @@ func (ns *Namespace) GetJSON(urlParts ...interface{}) (interface{}, error) { req.Header.Add("Accept", "application/json") req.Header.Add("User-Agent", "Hugo Static Site Generator") + // Add custom user headers to the get request + finalArg := args[len(args)-1] + + if userHeaders, ok := finalArg.(map[string]interface{}); ok { + for key, val := range userHeaders { + req.Header.Add(key, val.(string)) + } + } + err = ns.getResource(cache, unmarshal, req) if err != nil { ns.deps.Log.(loggers.IgnorableLogger).Errorsf(constants.ErrRemoteGetJSON, "Failed to get JSON resource %q: %s", url, err) diff --git a/tpl/data/data_test.go b/tpl/data/data_test.go index f9e8621f2..6b62a2b0d 100644 --- a/tpl/data/data_test.go +++ b/tpl/data/data_test.go @@ -119,6 +119,20 @@ func TestGetCSV(t *testing.T) { c.Assert(got, qt.Not(qt.IsNil), msg) c.Assert(got, qt.DeepEquals, test.expect, msg) + // Test user-defined headers as well + gotHeader, _ := ns.GetCSV(test.sep, test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"}) + + if _, ok := test.expect.(bool); ok { + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) + // c.Assert(err, msg, qt.Not(qt.IsNil)) + c.Assert(got, qt.IsNil) + continue + } + + c.Assert(err, qt.IsNil, msg) + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0) + c.Assert(gotHeader, qt.Not(qt.IsNil), msg) + c.Assert(gotHeader, qt.DeepEquals, test.expect, msg) } } @@ -206,6 +220,19 @@ func TestGetJSON(t *testing.T) { c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg) c.Assert(got, qt.Not(qt.IsNil), msg) c.Assert(got, qt.DeepEquals, test.expect) + + // Test user-defined headers as well + gotHeader, _ := ns.GetJSON(test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"}) + + if _, ok := test.expect.(bool); ok { + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) + // c.Assert(err, msg, qt.Not(qt.IsNil)) + continue + } + + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg) + c.Assert(gotHeader, qt.Not(qt.IsNil), msg) + c.Assert(gotHeader, qt.DeepEquals, test.expect) } }