mirror of
https://github.com/siyuan-note/siyuan.git
synced 2026-02-08 08:14:21 +01:00
* Add use TLS for network serving configuration option * kernel: Implement TLS certificate generation * kernel: server: Use https for fixed port proxy when needed * Allow exporting the CA Certificate file * Implement import and export of CA Certs * kernel: fixedport: Use the same port for HTTP and HTTPS
337 lines
8.9 KiB
Go
337 lines
8.9 KiB
Go
// SiYuan - Refactor your thinking
|
|
// Copyright (c) 2020-present, b3log.org
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package util
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/88250/gulu"
|
|
"github.com/siyuan-note/logging"
|
|
)
|
|
|
|
// File names of the TLS material handled by this file. The functions below
// join each of them with ConfDir to build the full path.
const (
	// TLSCACertFilename is the file name of the local CA certificate.
	TLSCACertFilename = "ca.crt"
	// TLSCAKeyFilename is the file name of the local CA private key.
	TLSCAKeyFilename = "ca.key"
	// TLSCertFilename is the file name of the server certificate.
	TLSCertFilename = "cert.pem"
	// TLSKeyFilename is the file name of the server private key.
	TLSKeyFilename = "key.pem"
)
|
|
|
|
// Returns paths to existing TLS certificates or generates new ones signed by a local CA.
|
|
// Certificates are stored in the conf directory of the workspace.
|
|
func GetOrCreateTLSCert() (certPath, keyPath string, err error) {
|
|
certPath = filepath.Join(ConfDir, TLSCertFilename)
|
|
keyPath = filepath.Join(ConfDir, TLSKeyFilename)
|
|
caCertPath := filepath.Join(ConfDir, TLSCACertFilename)
|
|
caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename)
|
|
|
|
if !gulu.File.IsExist(caCertPath) || !gulu.File.IsExist(caKeyPath) {
|
|
logging.LogInfof("generating local CA for TLS...")
|
|
if err = generateCACert(caCertPath, caKeyPath); err != nil {
|
|
logging.LogErrorf("failed to generate CA certificates: %s", err)
|
|
return "", "", err
|
|
}
|
|
}
|
|
|
|
if gulu.File.IsExist(certPath) && gulu.File.IsExist(keyPath) {
|
|
if validateCert(certPath) {
|
|
logging.LogInfof("using existing TLS certificates from [%s]", ConfDir)
|
|
return certPath, keyPath, nil
|
|
}
|
|
logging.LogInfof("existing TLS certificates are invalid or expired, regenerating...")
|
|
}
|
|
|
|
caCert, caKey, err := loadCA(caCertPath, caKeyPath)
|
|
if err != nil {
|
|
logging.LogErrorf("failed to load CA certificates: %s", err)
|
|
return "", "", err
|
|
}
|
|
|
|
logging.LogInfof("generating TLS server certificates signed by local CA...")
|
|
if err = generateServerCert(certPath, keyPath, caCert, caKey); err != nil {
|
|
logging.LogErrorf("failed to generate TLS certificates: %s", err)
|
|
return "", "", err
|
|
}
|
|
|
|
logging.LogInfof("generated TLS certificates at [%s]", ConfDir)
|
|
return certPath, keyPath, nil
|
|
}
|
|
|
|
// Checks if the certificate file exists, is not expired, and contains all current IP addresses
|
|
func validateCert(certPath string) bool {
|
|
certPEM, err := os.ReadFile(certPath)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
block, _ := pem.Decode(certPEM)
|
|
if block == nil {
|
|
return false
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Check if certificate is still valid, with 7 day buffer
|
|
if !time.Now().Add(7 * 24 * time.Hour).Before(cert.NotAfter) {
|
|
return false
|
|
}
|
|
|
|
// Check if certificate contains all current IP addresses
|
|
currentIPs := GetServerAddrs()
|
|
certIPMap := make(map[string]bool)
|
|
for _, ip := range cert.IPAddresses {
|
|
certIPMap[ip.String()] = true
|
|
}
|
|
|
|
for _, ipStr := range currentIPs {
|
|
ipStr = trimIPv6Brackets(ipStr)
|
|
ip := net.ParseIP(ipStr)
|
|
if ip == nil {
|
|
continue
|
|
}
|
|
|
|
if !certIPMap[ip.String()] {
|
|
logging.LogInfof("certificate missing current IP address [%s], will regenerate", ip.String())
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Creates a new self-signed CA certificate
|
|
func generateCACert(certPath, keyPath string) error {
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
notBefore := time.Now()
|
|
notAfter := notBefore.Add(10 * 365 * 24 * time.Hour)
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{
|
|
Organization: []string{"SiYuan"},
|
|
CommonName: "SiYuan Local CA",
|
|
},
|
|
NotBefore: notBefore,
|
|
NotAfter: notAfter,
|
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return writeCertAndKey(certPath, keyPath, certDER, privateKey)
|
|
}
|
|
|
|
// Creates a new server certificate signed by the CA
|
|
func generateServerCert(certPath, keyPath string, caCert *x509.Certificate, caKey any) error {
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
notBefore := time.Now()
|
|
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
|
|
|
ipAddresses := []net.IP{
|
|
net.ParseIP("127.0.0.1"),
|
|
net.IPv6loopback,
|
|
}
|
|
|
|
localIPs := GetServerAddrs()
|
|
for _, ipStr := range localIPs {
|
|
ipStr = trimIPv6Brackets(ipStr)
|
|
if ip := net.ParseIP(ipStr); ip != nil {
|
|
ipAddresses = append(ipAddresses, ip)
|
|
}
|
|
}
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{
|
|
Organization: []string{"SiYuan"},
|
|
CommonName: "SiYuan Local Server",
|
|
},
|
|
NotBefore: notBefore,
|
|
NotAfter: notAfter,
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: ipAddresses,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privateKey.PublicKey, caKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return writeCertAndKey(certPath, keyPath, certDER, privateKey)
|
|
}
|
|
|
|
// loadCA reads the PEM-encoded CA certificate at certPath and the PEM-encoded
// EC private key at keyPath and returns them parsed. The key is returned as
// the value produced by x509.ParseECPrivateKey. Any read, decode or parse
// failure is returned as the error, with nil certificate and key.
func loadCA(certPath, keyPath string) (*x509.Certificate, any, error) {
	// firstBlock returns the payload of the first PEM block found in the file
	// at path, or an error carrying decodeMsg when the file holds no PEM data.
	firstBlock := func(path, decodeMsg string) ([]byte, error) {
		data, readErr := os.ReadFile(path)
		if readErr != nil {
			return nil, readErr
		}
		decoded, _ := pem.Decode(data)
		if decoded == nil {
			return nil, fmt.Errorf("%s", decodeMsg)
		}
		return decoded.Bytes, nil
	}

	certDER, err := firstBlock(certPath, "failed to decode CA certificate PEM")
	if err != nil {
		return nil, nil, err
	}
	parsedCert, err := x509.ParseCertificate(certDER)
	if err != nil {
		return nil, nil, err
	}

	keyDER, err := firstBlock(keyPath, "failed to decode CA key PEM")
	if err != nil {
		return nil, nil, err
	}
	parsedKey, err := x509.ParseECPrivateKey(keyDER)
	if err != nil {
		return nil, nil, err
	}

	return parsedCert, parsedKey, nil
}
|
|
|
|
// writeCertAndKey PEM-encodes certDER (as a "CERTIFICATE" block) and
// privateKey (as an "EC PRIVATE KEY" block) and writes them to certPath and
// keyPath, replacing any existing content.
//
// The certificate is written with mode 0644. The private key is written with
// mode 0600, and an already existing key file is tightened to 0600 as well,
// so the key is not left readable by other users. Both payloads are encoded
// before any file is touched, so a key marshalling failure leaves the files
// on disk unchanged. Write and close errors are returned.
func writeCertAndKey(certPath, keyPath string, certDER []byte, privateKey *ecdsa.PrivateKey) error {
	keyDER, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return err
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

	if err = os.WriteFile(certPath, certPEM, 0644); err != nil {
		return err
	}

	if err = os.WriteFile(keyPath, keyPEM, 0600); err != nil {
		return err
	}
	// os.WriteFile applies the mode only when it creates the file; fix up a
	// key file that already existed with wider permissions.
	return os.Chmod(keyPath, 0600)
}
|
|
|
|
// Imports a CA certificate and private key from PEM-encoded strings.
|
|
func ImportCABundle(caCertPEM, caKeyPEM string) error {
|
|
certBlock, _ := pem.Decode([]byte(caCertPEM))
|
|
if certBlock == nil {
|
|
return fmt.Errorf("failed to decode CA certificate PEM")
|
|
}
|
|
|
|
caCert, err := x509.ParseCertificate(certBlock.Bytes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse CA certificate: %w", err)
|
|
}
|
|
|
|
if !caCert.IsCA {
|
|
return fmt.Errorf("the provided certificate is not a CA certificate")
|
|
}
|
|
|
|
keyBlock, _ := pem.Decode([]byte(caKeyPEM))
|
|
if keyBlock == nil {
|
|
return fmt.Errorf("failed to decode CA private key PEM")
|
|
}
|
|
|
|
_, err = x509.ParseECPrivateKey(keyBlock.Bytes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse CA private key: %w", err)
|
|
}
|
|
|
|
caCertPath := filepath.Join(ConfDir, TLSCACertFilename)
|
|
caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename)
|
|
|
|
if err := os.WriteFile(caCertPath, []byte(caCertPEM), 0644); err != nil {
|
|
return fmt.Errorf("failed to write CA certificate: %w", err)
|
|
}
|
|
|
|
if err := os.WriteFile(caKeyPath, []byte(caKeyPEM), 0600); err != nil {
|
|
return fmt.Errorf("failed to write CA private key: %w", err)
|
|
}
|
|
|
|
certPath := filepath.Join(ConfDir, TLSCertFilename)
|
|
keyPath := filepath.Join(ConfDir, TLSKeyFilename)
|
|
|
|
if gulu.File.IsExist(certPath) {
|
|
os.Remove(certPath)
|
|
}
|
|
if gulu.File.IsExist(keyPath) {
|
|
os.Remove(keyPath)
|
|
}
|
|
|
|
logging.LogInfof("imported CA bundle, server certificate will be regenerated on next TLS initialization")
|
|
return nil
|
|
}
|
|
|
|
// trimIPv6Brackets strips one pair of enclosing square brackets from ip, so
// "[::1]" becomes "::1". Strings that are not bracketed, or that have nothing
// between the brackets, are returned unchanged.
func trimIPv6Brackets(ip string) string {
	last := len(ip) - 1
	if last < 2 {
		// Too short to hold brackets around at least one character.
		return ip
	}
	if ip[0] != '[' || ip[last] != ']' {
		return ip
	}
	return ip[1:last]
}
|