Compare commits
2 commits
d249964e26
...
296cc5d133
Author | SHA1 | Date | |
---|---|---|---|
Felix Niederwanger | 296cc5d133 | ||
d9d8b73b50 |
|
@ -2,10 +2,18 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -36,6 +44,12 @@ func TestWebserver(t *testing.T) {
|
|||
return
|
||||
}
|
||||
go server.Serve(listener)
|
||||
defer func() {
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
t.Fatalf("error while server shutdown: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
assertStatusCode := func(url string, statusCode int) {
|
||||
resp, err := http.Get(url)
|
||||
|
@ -59,9 +73,203 @@ func TestWebserver(t *testing.T) {
|
|||
// Test registered hooks
|
||||
assertStatusCode(fmt.Sprintf("http://%s/test1", cf.Settings.BindAddress), http.StatusOK)
|
||||
assertStatusCode(fmt.Sprintf("http://%s/test2", cf.Settings.BindAddress), http.StatusOK)
|
||||
}
|
||||
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
t.Fatalf("error while server shutdown: %s", err)
|
||||
// Tests the TLS functions of the webserver
|
||||
func TestTLSWebserver(t *testing.T) {
|
||||
var cf Config
|
||||
|
||||
// Test keypairs. testkey1 belongs to the "localhost" host and testkey2 belongs to the "localhost" and "example.com" hosts
|
||||
keypairs := make([]TLSKeypairs, 0)
|
||||
keypairs = append(keypairs, TLSKeypairs{Keyfile: "testkey1.pem", Certificate: "testcert1.pem"})
|
||||
keypairs = append(keypairs, TLSKeypairs{Keyfile: "testkey2.pem", Certificate: "testcert2.pem"})
|
||||
|
||||
// Generate test certificates
|
||||
for i, keypair := range keypairs {
|
||||
if fileExists(keypair.Keyfile) {
|
||||
t.Fatalf("test key '%s' already exists", keypair.Keyfile)
|
||||
return
|
||||
}
|
||||
if fileExists(keypair.Certificate) {
|
||||
t.Fatalf("test certificate '%s' already exists", keypair.Certificate)
|
||||
return
|
||||
}
|
||||
|
||||
hostnames := []string{"localhost"}
|
||||
if i == 1 {
|
||||
hostnames = append(hostnames, "example.com")
|
||||
}
|
||||
if err := generateKeypair(keypair.Keyfile, keypair.Certificate, hostnames); err != nil {
|
||||
t.Fatalf("keypair generation failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func(keypair TLSKeypairs) {
|
||||
os.Remove(keypair.Keyfile)
|
||||
os.Remove(keypair.Certificate)
|
||||
}(keypair)
|
||||
}
|
||||
|
||||
cf.Settings.BindAddress = "localhost:2089"
|
||||
cf.Settings.TLS.Enabled = true
|
||||
cf.Settings.TLS.MinVersion = "1.3"
|
||||
cf.Settings.TLS.MaxVersion = "1.3"
|
||||
cf.Settings.TLS.Keypairs = keypairs
|
||||
cf.Hooks = make([]Hook, 0)
|
||||
cf.Hooks = append(cf.Hooks, Hook{Route: "/test1", Name: "test1", Command: "", Hosts: []string{"localhost"}})
|
||||
cf.Hooks = append(cf.Hooks, Hook{Route: "/test2", Name: "test2", Command: "", Hosts: []string{"localhost", "example.com"}})
|
||||
|
||||
// Setup TLS webserver
|
||||
listener, err := CreateTLSListener(cf)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating tls listener: %s", err)
|
||||
return
|
||||
}
|
||||
server := CreateWebserver(cf)
|
||||
mux := http.NewServeMux()
|
||||
server.Handler = mux
|
||||
if err := RegisterHandlers(cf, mux); err != nil {
|
||||
t.Fatalf("error registering handlers: %s", err)
|
||||
return
|
||||
}
|
||||
go server.Serve(listener)
|
||||
defer func() {
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
t.Fatalf("error while server shutdown: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// Default page without https should return a 400 error
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/", cf.Settings.BindAddress))
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("GET / returns status code %d for default page (400 expected)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check default page with tls certificates
|
||||
certs := make([]tls.Certificate, 0)
|
||||
rootCAs, _ := x509.SystemCertPool()
|
||||
for i, keypair := range keypairs {
|
||||
x509cert, err := readCertificate(keypair.Certificate)
|
||||
if err != nil {
|
||||
t.Fatalf("error loading certificate %d: %s", i, err)
|
||||
return
|
||||
}
|
||||
raw := make([][]byte, 0)
|
||||
raw = append(raw, x509cert.Raw)
|
||||
certs = append(certs, tls.Certificate{Certificate: raw})
|
||||
rootCAs.AddCert(x509cert)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
Certificates: certs,
|
||||
RootCAs: rootCAs,
|
||||
},
|
||||
}
|
||||
client := http.Client{Transport: transport, Timeout: 15 * time.Second}
|
||||
resp, err = client.Get(fmt.Sprintf("https://%s/", cf.Settings.BindAddress))
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("GET / returns status code %d for default page (200 expected)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func generateKeypair(keyfile string, certfile string, hostnames []string) error {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write key to file
|
||||
var buffer []byte = x509.MarshalPKCS1PrivateKey(key)
|
||||
block := &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: buffer,
|
||||
}
|
||||
file, err := os.Create(keyfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
if err := file.Chmod(os.FileMode(0400)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = pem.Encode(file, block); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate certificate
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * 10 * time.Hour)
|
||||
|
||||
//Create certificate templet
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(0),
|
||||
Subject: pkix.Name{CommonName: hostnames[0]},
|
||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||
DNSNames: hostnames,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
//Create certificate using templet
|
||||
cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block = &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert,
|
||||
}
|
||||
|
||||
file, err = os.Create(certfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
if err := file.Chmod(os.FileMode(0644)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = pem.Encode(file, block); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readCertificate(certfile string) (*x509.Certificate, error) {
|
||||
buffer, err := os.ReadFile(certfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, _ := pem.Decode(buffer)
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("invalid pem file")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(p.Bytes)
|
||||
return cert, err
|
||||
|
||||
}
|
||||
|
||||
func fileExists(filename string) bool {
|
||||
st, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return !st.IsDir()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue