446 lines
13 KiB
Go
446 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestMain runs all tests of the package and exits the process with the
// status code reported by the test runner.
func TestMain(m *testing.M) {
	os.Exit(m.Run())
}
|
|
|
|
// Test the general webserver functionalities
|
|
func TestWebserver(t *testing.T) {
|
|
var cf Config
|
|
|
|
cf.Settings.BindAddress = "127.0.0.1:2088"
|
|
cf.Hooks = make([]Hook, 0)
|
|
cf.Hooks = append(cf.Hooks, Hook{Route: "/test1", Name: "test1", Command: ""})
|
|
cf.Hooks = append(cf.Hooks, Hook{Route: "/test2", Name: "test2", Command: ""})
|
|
|
|
listener, err := CreateListener(cf)
|
|
if err != nil {
|
|
t.Fatalf("error creating listener: %s", err)
|
|
return
|
|
}
|
|
server := CreateWebserver(cf)
|
|
mux := http.NewServeMux()
|
|
server.Handler = mux
|
|
if err := RegisterHandlers(cf, mux); err != nil {
|
|
t.Fatalf("error registering handlers: %s", err)
|
|
return
|
|
}
|
|
go server.Serve(listener)
|
|
defer func() {
|
|
if err := server.Shutdown(context.Background()); err != nil {
|
|
t.Fatalf("error while server shutdown: %s", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
assertStatusCode := func(url string, statusCode int) {
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
}
|
|
if resp.StatusCode != statusCode {
|
|
t.Fatalf("GET / returns status code %d != %d", resp.StatusCode, statusCode)
|
|
}
|
|
}
|
|
|
|
// Check default sites
|
|
assertStatusCode(fmt.Sprintf("http://%s/", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("http://%s/health", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("http://%s/health.json", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("http://%s/robots.txt", cf.Settings.BindAddress), http.StatusOK)
|
|
// Check for a 404 page
|
|
assertStatusCode(fmt.Sprintf("http://%s/404", cf.Settings.BindAddress), http.StatusNotFound)
|
|
assertStatusCode(fmt.Sprintf("http://%s/test3", cf.Settings.BindAddress), http.StatusNotFound)
|
|
// Test registered hooks
|
|
assertStatusCode(fmt.Sprintf("http://%s/test1", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("http://%s/test2", cf.Settings.BindAddress), http.StatusOK)
|
|
}
|
|
|
|
// Tests the TLS functions of the webserver
|
|
func TestTLSWebserver(t *testing.T) {
|
|
var cf Config
|
|
|
|
const TESTPORT = 2089
|
|
|
|
// Test keypairs. testkey1 belongs to the "localhost" host and testkey2 belongs to the "localhost" and "example.com" hosts
|
|
keypairs := make([]TLSKeypairs, 0)
|
|
keypairs = append(keypairs, TLSKeypairs{Keyfile: "testkey1.pem", Certificate: "testcert1.pem"})
|
|
keypairs = append(keypairs, TLSKeypairs{Keyfile: "testkey2.pem", Certificate: "testcert2.pem"})
|
|
|
|
// Generate test certificates
|
|
for i, keypair := range keypairs {
|
|
if fileExists(keypair.Keyfile) {
|
|
t.Fatalf("test key '%s' already exists", keypair.Keyfile)
|
|
return
|
|
}
|
|
if fileExists(keypair.Certificate) {
|
|
t.Fatalf("test certificate '%s' already exists", keypair.Certificate)
|
|
return
|
|
}
|
|
|
|
hostnames := []string{"localhost"}
|
|
if i == 1 {
|
|
hostnames = append(hostnames, "example.com")
|
|
}
|
|
if err := generateKeypair(keypair.Keyfile, keypair.Certificate, hostnames); err != nil {
|
|
t.Fatalf("keypair generation failed: %s\n", err)
|
|
return
|
|
}
|
|
|
|
defer func(keypair TLSKeypairs) {
|
|
os.Remove(keypair.Keyfile)
|
|
os.Remove(keypair.Certificate)
|
|
}(keypair)
|
|
}
|
|
|
|
cf.Settings.BindAddress = fmt.Sprintf("localhost:%d", TESTPORT)
|
|
cf.Settings.TLS.Enabled = true
|
|
cf.Settings.TLS.MinVersion = "1.3"
|
|
cf.Settings.TLS.MaxVersion = "1.3"
|
|
cf.Settings.TLS.Keypairs = keypairs
|
|
cf.Hooks = make([]Hook, 0)
|
|
cf.Hooks = append(cf.Hooks, Hook{Route: "/test1", Name: "test1", Command: "", Hosts: []string{"localhost"}})
|
|
cf.Hooks = append(cf.Hooks, Hook{Route: "/test2", Name: "test2", Command: "", Hosts: []string{"localhost", "example.com"}})
|
|
|
|
// Setup TLS webserver
|
|
listener, err := CreateTLSListener(cf)
|
|
if err != nil {
|
|
t.Fatalf("error creating tls listener: %s", err)
|
|
return
|
|
}
|
|
server := CreateWebserver(cf)
|
|
mux := http.NewServeMux()
|
|
server.Handler = mux
|
|
if err := RegisterHandlers(cf, mux); err != nil {
|
|
t.Fatalf("error registering handlers: %s", err)
|
|
return
|
|
}
|
|
go server.Serve(listener)
|
|
defer func() {
|
|
if err := server.Shutdown(context.Background()); err != nil {
|
|
t.Fatalf("error while server shutdown: %s", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
// Default page without https should return a 400 error
|
|
resp, err := http.Get(fmt.Sprintf("http://%s/", cf.Settings.BindAddress))
|
|
if err != nil {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
}
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Fatalf("GET / returns status code %d for default page (400 expected)", resp.StatusCode)
|
|
}
|
|
|
|
// Check default page with tls certificates
|
|
certs := make([]tls.Certificate, 0)
|
|
rootCAs, _ := x509.SystemCertPool()
|
|
for i, keypair := range keypairs {
|
|
x509cert, err := readCertificate(keypair.Certificate)
|
|
if err != nil {
|
|
t.Fatalf("error loading certificate %d: %s", i, err)
|
|
return
|
|
}
|
|
raw := make([][]byte, 0)
|
|
raw = append(raw, x509cert.Raw)
|
|
certs = append(certs, tls.Certificate{Certificate: raw})
|
|
rootCAs.AddCert(x509cert)
|
|
}
|
|
|
|
dialer := &net.Dialer{}
|
|
transport := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
Certificates: certs,
|
|
RootCAs: rootCAs,
|
|
},
|
|
// Mock connections to example.com -> 127.0.0.1
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
if strings.Contains(addr, "example.com") {
|
|
addr = strings.ReplaceAll(addr, "example.com", "127.0.0.1")
|
|
}
|
|
return dialer.DialContext(ctx, network, addr)
|
|
},
|
|
}
|
|
client := http.Client{Transport: transport, Timeout: 15 * time.Second}
|
|
|
|
assertStatusCode := func(url string, statusCode int) {
|
|
resp, err = client.Get(url)
|
|
if err != nil {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
}
|
|
if resp.StatusCode != statusCode {
|
|
t.Fatalf("GET / returns status code %d != %d", resp.StatusCode, statusCode)
|
|
}
|
|
}
|
|
fetchBody := func(url string) (string, error) {
|
|
resp, err = client.Get(url)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
body, err := io.ReadAll(resp.Body)
|
|
return string(body), err
|
|
}
|
|
|
|
// Check default page and test hooks
|
|
assertStatusCode(fmt.Sprintf("https://%s/", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("https://%s/test1", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("https://%s/test2", cf.Settings.BindAddress), http.StatusOK)
|
|
assertStatusCode(fmt.Sprintf("https://%s/test404", cf.Settings.BindAddress), http.StatusNotFound)
|
|
|
|
// Check if connection via TLS 1.2 is not accepted (we're enforcing TLS >= 1.3)
|
|
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
|
|
transport.TLSClientConfig.MaxVersion = tls.VersionTLS12
|
|
resp, err = client.Get(fmt.Sprintf("https://%s/", cf.Settings.BindAddress))
|
|
if err == nil {
|
|
t.Fatal("tls 1.2 connection possible where it should be unsupported", err)
|
|
return
|
|
} else {
|
|
// TODO: Matching by string might be flanky.
|
|
if !strings.Contains(err.Error(), "tls: protocol version not supported") {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
}
|
|
}
|
|
transport.TLSClientConfig.MaxVersion = tls.VersionTLS13
|
|
|
|
// Check if example.com resolves (second certificate)
|
|
assertStatusCode(fmt.Sprintf("https://example.com:%d/", TESTPORT), http.StatusOK)
|
|
// Only /test2 should be reachable via example.com
|
|
assertStatusCode(fmt.Sprintf("https://example.com:%d/test1", TESTPORT), http.StatusNotFound)
|
|
assertStatusCode(fmt.Sprintf("https://example.com:%d/test2", TESTPORT), http.StatusOK)
|
|
|
|
// Assert, that the host 404 page is the same as the 404 page for a route that doesn't exist.
|
|
// This check is needed, because we pretend a path to not exist, if `hosts` is configured and
|
|
// we don't want to give attackers the possibility to distinguish between the two 404 errors
|
|
if body1, err := fetchBody(fmt.Sprintf("https://%s/test404", cf.Settings.BindAddress)); err != nil {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
} else if body2, err := fetchBody(fmt.Sprintf("https://example.com:%d/test1", TESTPORT)); err != nil {
|
|
t.Fatalf("%s", err)
|
|
return
|
|
} else {
|
|
if body1 != body2 {
|
|
t.Fatal("404 bodies differ between default 404 page and host-not-matched route", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Tests the run hook commands
|
|
func TestRunHook(t *testing.T) {
|
|
testText := "hello Test"
|
|
hook := Hook{Name: "hook", Command: "cat"}
|
|
|
|
buffer, err := hook.Run([]byte(testText))
|
|
if err != nil {
|
|
t.Fatalf("running test hook failed: %s", err)
|
|
}
|
|
ret := string(buffer)
|
|
if ret != testText {
|
|
t.Error("returned string mismatch")
|
|
}
|
|
}
|
|
|
|
// Tests passing the request header and body
|
|
func TestHeaderAndBody(t *testing.T) {
|
|
// Create temp file
|
|
tempFile, err := os.CreateTemp("", "test_header_body_*")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer func() {
|
|
os.Remove(tempFile.Name())
|
|
}()
|
|
|
|
// Create test webserver with receive hook, that passes all headers and the body to the temp file
|
|
|
|
var cf Config
|
|
|
|
cf.Settings.BindAddress = "127.0.0.1:2088"
|
|
cf.Settings.MaxBodySize = 4096
|
|
cf.Hooks = make([]Hook, 0)
|
|
cf.Hooks = append(cf.Hooks, Hook{Name: "hook", Command: fmt.Sprintf("tee %s", tempFile.Name()), Route: "/header_and_body"})
|
|
|
|
listener, err := CreateListener(cf)
|
|
if err != nil {
|
|
t.Fatalf("error creating listener: %s", err)
|
|
return
|
|
}
|
|
server := CreateWebserver(cf)
|
|
mux := http.NewServeMux()
|
|
server.Handler = mux
|
|
if err := RegisterHandlers(cf, mux); err != nil {
|
|
t.Fatalf("error registering handlers: %s", err)
|
|
return
|
|
}
|
|
go server.Serve(listener)
|
|
defer func() {
|
|
if err := server.Shutdown(context.Background()); err != nil {
|
|
t.Fatalf("error while server shutdown: %s", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
// Create http request with custom headers and a message body
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/header_and_body", cf.Settings.BindAddress), nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
headers := make(map[string]string, 0)
|
|
headers["Header1"] = "value1"
|
|
headers["Header2"] = "value2"
|
|
headers["Header3"] = "value3"
|
|
headers["Content-Type"] = "this is the content type"
|
|
for k, v := range headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.Body = io.NopCloser(bytes.NewReader([]byte("this is the request body\nit is awesome")))
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("http request error: %s", err)
|
|
}
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("http request failed: %d != %d", res.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
// Assert that the headers and the body is in the test file
|
|
buf, err := os.ReadFile(tempFile.Name())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
contents := string(buf)
|
|
assertHeader := func(key string, value string) {
|
|
if !strings.Contains(contents, key) {
|
|
t.Fatalf("Header %s is not present", key)
|
|
}
|
|
if !strings.Contains(contents, fmt.Sprintf("%s:%s\n", key, value)) {
|
|
t.Fatalf("Header %s has not the right value", key)
|
|
}
|
|
}
|
|
for k, v := range headers {
|
|
assertHeader(k, v)
|
|
}
|
|
|
|
// Assert the message body got passed as well
|
|
|
|
// Assert the messaeg body got cropped
|
|
}
|
|
|
|
// generateKeypair creates a self-signed RSA-2048 certificate for hostnames.
// It writes the PEM-encoded PKCS#1 private key to keyfile with mode 0400 and
// the PEM-encoded certificate to certfile with mode 0644.
//
// hostnames must not be empty. The first entry becomes the certificate's
// common name, and all entries are added as DNS subject alternative names. The
// certificate is valid from now for 3650 days. Key and certificate are fully
// built in memory before anything is written. If writing the certificate
// fails, the already written key file is removed again.
func generateKeypair(keyfile string, certfile string, hostnames []string) error {
	if len(hostnames) == 0 {
		return fmt.Errorf("generate keypair: no hostnames given")
	}

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}

	// RFC 5280 requires a positive serial number, unique per issuer. A random
	// 128-bit value also keeps the self-signed test certificates (which share
	// the issuer name) from colliding.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return err
	}
	serial.Add(serial, big.NewInt(1))

	notBefore := time.Now()
	notAfter := notBefore.Add(365 * 24 * 10 * time.Hour)
	template := x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: hostnames[0]},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		DNSNames:              hostnames,
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
	}
	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		return err
	}

	keyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	if err := writeKeypairPEMFile(keyfile, keyBlock, 0400); err != nil {
		return err
	}
	certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: der}
	if err := writeKeypairPEMFile(certfile, certBlock, 0644); err != nil {
		os.Remove(keyfile)
		return err
	}
	return nil
}

// writeKeypairPEMFile writes block PEM-encoded to path with exactly the
// permissions perm. The file is removed again if anything after creating it
// fails.
func writeKeypairPEMFile(path string, block *pem.Block, perm os.FileMode) error {
	// Creating with perm directly (instead of os.Create + Chmod) means a
	// private key is never readable with the default 0666 mode.
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	// OpenFile's permission bits are filtered through the umask; enforce them.
	err = file.Chmod(perm)
	if err == nil {
		err = pem.Encode(file, block)
	}
	closeErr := file.Close()
	if err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(path)
	}
	return err
}
|
|
|
|
// readCertificate loads certfile, decodes its first PEM block and parses that
// block as an X.509 certificate. Read errors and parse errors are returned
// unchanged. A file without any PEM block yields the error "invalid pem file".
func readCertificate(certfile string) (*x509.Certificate, error) {
	data, err := os.ReadFile(certfile)
	if err != nil {
		return nil, err
	}
	// Only the first PEM block is considered; trailing data is ignored.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("invalid pem file")
	}
	return x509.ParseCertificate(block.Bytes)
}
|
|
|
|
// fileExists reports whether os.Stat succeeds for filename and the result is
// not a directory. Any Stat error (missing path, permission problem, ...) is
// reported as "does not exist".
func fileExists(filename string) bool {
	if info, err := os.Stat(filename); err == nil {
		return !info.IsDir()
	}
	return false
}
|