go_sysutils/cryptparse/funcs.go
brent saner db20c70d86
v1.4.0
ADDED:
* cryptparse, which makes managing certs/keys MUCH much easier (and
  better) than the stdlib utilities.
2024-06-21 17:18:19 -04:00

752 lines
17 KiB
Go

package cryptparse
import (
`bytes`
`crypto`
`crypto/ecdh`
`crypto/ecdsa`
`crypto/ed25519`
`crypto/rsa`
`crypto/tls`
`crypto/x509`
`encoding/pem`
`errors`
`net/url`
`os`
`strconv`
`strings`
`r00t2.io/sysutils/paths`
)
/*
FromURL builds a *TlsUri from u without modifying u itself.

The URL is copied (as is its Userinfo, if present, so the copy does not share that pointer with u)
and the copy's scheme is forced to "tls".

If u is nil, t will be nil.
*/
func FromURL(u *url.URL) (t *TlsUri) {

	if u == nil {
		return
	}

	// A plain struct copy would still share the User pointer, so that gets its own copy below.
	dup := *u
	if u.User != nil {
		userDup := *u.User
		dup.User = &userDup
	}
	dup.Scheme = "tls"

	t = &TlsUri{URL: &dup}

	return
}
/*
IsMatchedPair returns true if privKey is the private key paired with cert's public key.

RSA, Ed25519, ECDH, and ECDSA keys are supported.

If privKey or cert is nil, isMatched is false and err is nil.

If privKey and cert's public key are both of a supported type but not the *same* type
(e.g. an RSA private key and an ECDSA certificate), they cannot be a pair; isMatched is
false and err is nil. This lets callers iterate over a mixed set of keys without aborting.

ErrUnknownKey is returned only if privKey or cert's public key is of an unsupported type.
*/
func IsMatchedPair(privKey crypto.PrivateKey, cert *x509.Certificate) (isMatched bool, err error) {

	if cert == nil || privKey == nil {
		return
	}

	switch k := privKey.(type) {
	case *rsa.PrivateKey:
		if p, ok := cert.PublicKey.(*rsa.PublicKey); ok {
			isMatched = k.PublicKey.Equal(p)
			return
		}
	case ed25519.PrivateKey:
		if p, ok := cert.PublicKey.(ed25519.PublicKey); ok {
			// Order is flipped here because unlike the other key types, an ed25519.PrivateKey is just a []byte.
			isMatched = p.Equal(k.Public())
			return
		}
	case *ecdh.PrivateKey:
		if p, ok := cert.PublicKey.(*ecdh.PublicKey); ok {
			isMatched = k.PublicKey().Equal(p)
			return
		}
	case *ecdsa.PrivateKey:
		if p, ok := cert.PublicKey.(*ecdsa.PublicKey); ok {
			isMatched = k.PublicKey.Equal(p)
			return
		}
	default:
		// Unsupported private key type.
		err = ErrUnknownKey
		return
	}

	// The private key type is supported but differs from the cert's public key type.
	// If the public key type is supported too, this is a plain mismatch, not an error.
	switch cert.PublicKey.(type) {
	case *rsa.PublicKey, ed25519.PublicKey, *ecdh.PublicKey, *ecdsa.PublicKey:
		return
	}

	// Unsupported public key type on the certificate.
	err = ErrUnknownKey

	return
}
/*
ParseTlsCipher parses string s and attempts to derive a TLS cipher suite (as a uint16) from it.
Use ParseTlsCipherSuite if you wish for a *tls.CipherSuite instead.

s may be either:

  - the suite's name (as per https://www.iana.org/assignments/tls-parameters/tls-parameters.xml);
    matching is case-insensitive and spaces are treated as underscores, or
  - its numeric ID, either in plain decimal or with a base prefix (e.g. "0x1301").

Leading and trailing whitespace in s is ignored.

If s is numeric but is not the ID of a suite known to crypto/tls, err will be ErrBadTlsCipher.
If s is numeric but does not fit in a uint16, the strconv error is returned.
If s is a name that is not found, cipherSuite will be MaxTlsCipher and err will be nil.
*/
func ParseTlsCipher(s string) (cipherSuite uint16, err error) {

	var n uint64
	var id uint16
	var ok bool
	var nmToId map[string]uint16

	s = strings.TrimSpace(s)

	// Plain decimal is tried first so existing decimal inputs parse exactly as before;
	// only if that is not valid syntax are prefixed forms (0x..., 0o..., 0b...) attempted.
	n, err = strconv.ParseUint(s, 10, 16)
	if errors.Is(err, strconv.ErrSyntax) {
		n, err = strconv.ParseUint(s, 0, 16)
	}

	switch {
	case err == nil:
		// It's a number. tls.CipherSuiteName falls back to a "0x..." string for unknown IDs.
		if strings.HasPrefix(tls.CipherSuiteName(uint16(n)), "0x") {
			err = ErrBadTlsCipher
			return
		}
		cipherSuite = uint16(n)
		return
	case errors.Is(err, strconv.ErrSyntax):
		// It's a name; parse below.
		err = nil
	default:
		// A number, but unusable (e.g. out of range).
		return
	}

	s = strings.ToUpper(s)
	s = strings.ReplaceAll(s, " ", "_")

	// We build a dynamic map of cipher suite names to uint16s (if not already created).
	// These two lists are exactly the suites tls.CipherSuiteName knows names for.
	if tlsCipherNmToUint == nil {
		nmToId = make(map[string]uint16)
		for _, cs := range tls.CipherSuites() {
			nmToId[cs.Name] = cs.ID
		}
		for _, cs := range tls.InsecureCipherSuites() {
			nmToId[cs.Name] = cs.ID
		}
		tlsCipherNmToUint = nmToId
	}

	cipherSuite = MaxTlsCipher
	if id, ok = tlsCipherNmToUint[s]; ok {
		cipherSuite = id
	}

	return
}
/*
ParseTlsCiphers treats s as a comma-separated list of cipher suite names and/or integers
and returns the cipher suite ID parsed from each entry.

Each entry is handed to ParseTlsCipher (see it for the accepted formats); entries for which it
returns an error are dropped from the result.

If the result would otherwise be empty, cipherSuites will contain only
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256.
*/
func ParseTlsCiphers(s string) (cipherSuites []uint16) {

	names := strings.Split(s, ",")
	cipherSuites = make([]uint16, 0, len(names))

	for idx := range names {
		// A bad entry is skipped rather than failing the whole list.
		if id, parseErr := ParseTlsCipher(names[idx]); parseErr == nil {
			cipherSuites = append(cipherSuites, id)
		}
	}

	if len(cipherSuites) != 0 {
		return
	}

	cipherSuites = []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256}

	return
}
/*
ParseTlsCipherSuite is the *tls.CipherSuite counterpart to ParseTlsCipher.

s is resolved to a cipher suite ID via ParseTlsCipher (any error from it is returned as-is),
and that ID is then searched for in tls.CipherSuites followed by tls.InsecureCipherSuites.

If the ID is in neither list, cipherSuite will be nil and err will be nil.
*/
func ParseTlsCipherSuite(s string) (cipherSuite *tls.CipherSuite, err error) {

	var id uint16

	id, err = ParseTlsCipher(s)
	if err != nil {
		return
	}

	// Secure suites take precedence over insecure ones.
	for _, pool := range [][]*tls.CipherSuite{tls.CipherSuites(), tls.InsecureCipherSuites()} {
		for _, candidate := range pool {
			if candidate.ID == id {
				cipherSuite = candidate
				return
			}
		}
	}

	return
}
/*
ParseTlsCipherSuites is the []*tls.CipherSuite counterpart to ParseTlsCiphers.

s is resolved to a list of cipher suite IDs via ParseTlsCiphers, and each ID is then searched for in
tls.CipherSuites followed by tls.InsecureCipherSuites. IDs found in neither list are left out.

err is never set by this function; it is always nil.
*/
func ParseTlsCipherSuites(s string) (cipherSuites []*tls.CipherSuite, err error) {

nextID:
	for _, id := range ParseTlsCiphers(s) {
		// Secure suites take precedence over insecure ones; the first hit wins.
		for _, pool := range [][]*tls.CipherSuite{tls.CipherSuites(), tls.InsecureCipherSuites()} {
			for _, candidate := range pool {
				if candidate.ID == id {
					cipherSuites = append(cipherSuites, candidate)
					continue nextID
				}
			}
		}
	}

	return
}
/*
ParseTlsCurve parses string s and attempts to derive a tls.CurveID from it.

s may be either:

  - the curve's name (as per https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8);
    matching is case-insensitive and the "Curve" prefix used by tls.CurveID.String is optional
    (e.g. "CurveP256", "curvep256", and "p256" are all equivalent), or
  - its numeric ID, either in plain decimal or with a base prefix (e.g. "0x001d").

Leading and trailing whitespace in s is ignored.

If s is numeric but is not the ID of a curve known to crypto/tls, err will be ErrBadTlsCurve.
If s is numeric but does not fit in a uint16, the strconv error is returned.
If s is a name that is not found, curve will be MaxCurveId and err will be nil.
*/
func ParseTlsCurve(s string) (curve tls.CurveID, err error) {

	var i tls.CurveID
	var n uint64
	var ok bool

	s = strings.TrimSpace(s)

	// Plain decimal is tried first so existing decimal inputs parse exactly as before;
	// only if that is not valid syntax are prefixed forms (0x..., 0o..., 0b...) attempted.
	n, err = strconv.ParseUint(s, 10, 16)
	if errors.Is(err, strconv.ErrSyntax) {
		n, err = strconv.ParseUint(s, 0, 16)
	}

	switch {
	case err == nil:
		// It's a number. tls.CurveID.String falls back to "CurveID(<n>)" for unknown IDs.
		if strings.HasPrefix(tls.CurveID(uint16(n)).String(), "CurveID(") {
			err = ErrBadTlsCurve
			return
		}
		curve = tls.CurveID(uint16(n))
		return
	case errors.Is(err, strconv.ErrSyntax):
		// It's a name; parse below.
		err = nil
	default:
		// A number, but unusable (e.g. out of range).
		return
	}

	// It seems to be a name. Normalize...
	s = strings.ToUpper(s)

	// Unfortunately there's no "tls.CurveIDName()" function.
	// They do have a .String() method though.
	if tlsCurveNmToCurve == nil {
		tlsCurveNmToCurve = make(map[string]tls.CurveID)
		for i = 0; i <= MaxCurveId; i++ {
			if strings.HasPrefix(i.String(), "CurveID(") {
				continue
			}
			tlsCurveNmToCurve[i.String()] = i
			// It's normally mixed-case; we want to be able to look it up in a normalized all-caps as well.
			tlsCurveNmToCurve[strings.ToUpper(i.String())] = i
			// The normal name, except for X25519, has "Curve" in the front. We add it without that prefix as well.
			tlsCurveNmToCurve[strings.TrimPrefix(i.String(), "Curve")] = i
		}
	}

	curve = MaxCurveId
	if i, ok = tlsCurveNmToCurve[s]; ok {
		curve = i
	}

	return
}
/*
ParseTlsCurves treats s as a comma-separated list of tls.CurveID names and/or integers
and returns the tls.CurveID parsed from each entry.

Each entry is handed to ParseTlsCurve (see it for the accepted formats); entries for which it
returns an error are dropped from the result.

If the result would otherwise be empty, curves will contain only MaxCurveId.
*/
func ParseTlsCurves(s string) (curves []tls.CurveID) {

	names := strings.Split(s, ",")
	curves = make([]tls.CurveID, 0, len(names))

	for idx := range names {
		// A bad entry is skipped rather than failing the whole list.
		if id, parseErr := ParseTlsCurve(names[idx]); parseErr == nil {
			curves = append(curves, id)
		}
	}

	if len(curves) != 0 {
		return
	}

	curves = []tls.CurveID{MaxCurveId}

	return
}
/*
ParseTlsUri parses a "TLS URI"'s query parameters into a *tls.Config. All certs and keys must be in PEM format.

You probably don't need this and should instead be using TlsUri.ToTlsConfig.
It just wraps this, but is probably more convenient.

The following query parameters are read from tlsUri:

  - TlsUriParamCa, TlsUriParamCert, TlsUriParamKey: filepath(s). Each is resolved via paths.RealPath,
    and the contents of all files given for a parameter are concatenated before parsing.
    Files that do not exist are skipped; any other read error is returned.
    If TlsUriParamCa is not given, the system cert pool is used for RootCAs.
  - TlsUriParamSni: overrides the ServerName (which otherwise is tlsUri's hostname). Only the first value is used.
  - TlsUriParamNoVerify: if its first value (lowercased) is in paramBoolValsTrue, InsecureSkipVerify is enabled.
  - TlsUriParamCipher, TlsUriParamCurve: parsed via ParseTlsCiphers and ParseTlsCurves respectively.
  - TlsUriParamMinTls, TlsUriParamMaxTls: parsed via ParseTlsVersion. Only the first value of each is used.

If tlsUri is nil, an error is returned.
On any error, tlsConf will be nil.
*/
func ParseTlsUri(tlsUri *url.URL) (tlsConf *tls.Config, err error) {

	var rootCAs *x509.CertPool
	var intermediateCAs []*x509.Certificate
	var privKeys []crypto.PrivateKey
	var tlsCerts []tls.Certificate
	var allowInvalid bool
	var ciphers []uint16
	var curves []tls.CurveID
	var params map[string][]string
	var vals []string
	var ok bool
	var val string
	var minVer uint16
	var maxVer uint16
	var srvNm string
	var buf *bytes.Buffer = new(bytes.Buffer)

	if tlsUri == nil {
		err = errors.New("cryptparse: nil TLS URI")
		return
	}

	srvNm = tlsUri.Hostname()
	params = tlsUri.Query()

	if params == nil {
		tlsConf = &tls.Config{
			ServerName: srvNm,
		}
		return
	}

	// These are all filepath(s). Ranging over a missing parameter is a no-op.
	for _, k := range []string{
		TlsUriParamCa,
		TlsUriParamCert,
		TlsUriParamKey,
	} {
		for idx := range params[k] {
			if err = paths.RealPath(&params[k][idx]); err != nil {
				return
			}
		}
	}

	// CA cert(s).
	buf.Reset()
	if vals, ok = params[TlsUriParamCa]; ok {
		if err = appendExistingFiles(buf, vals); err != nil {
			return
		}
		if rootCAs, _, intermediateCAs, err = ParseCA(buf.Bytes()); err != nil {
			return
		}
	} else {
		if rootCAs, err = x509.SystemCertPool(); err != nil {
			return
		}
	}

	// Keys. These are done first so we can match to a client certificate.
	buf.Reset()
	if vals, ok = params[TlsUriParamKey]; ok {
		if err = appendExistingFiles(buf, vals); err != nil {
			return
		}
		if privKeys, err = ParsePrivateKey(buf.Bytes()); err != nil {
			return
		}
	}

	// (Client) Certificate(s).
	buf.Reset()
	if vals, ok = params[TlsUriParamCert]; ok {
		if err = appendExistingFiles(buf, vals); err != nil {
			return
		}
		if tlsCerts, err = ParseLeafCert(buf.Bytes(), privKeys, intermediateCAs...); err != nil {
			return
		}
	}

	// Hostname (Override).
	if vals = params[TlsUriParamSni]; len(vals) != 0 {
		srvNm = vals[0]
	}

	// Disable Verification.
	if vals = params[TlsUriParamNoVerify]; len(vals) != 0 {
		val = strings.ToLower(vals[0])
		for _, i := range paramBoolValsTrue {
			if i == val {
				allowInvalid = true
				break
			}
		}
	}

	// Ciphers.
	if vals, ok = params[TlsUriParamCipher]; ok {
		ciphers = ParseTlsCiphers(strings.Join(vals, ","))
	}

	// Minimum TLS Protocol Version.
	if vals = params[TlsUriParamMinTls]; len(vals) != 0 {
		if minVer, err = ParseTlsVersion(vals[0]); err != nil {
			return
		}
	}

	// Maximum TLS Protocol Version.
	if vals = params[TlsUriParamMaxTls]; len(vals) != 0 {
		if maxVer, err = ParseTlsVersion(vals[0]); err != nil {
			return
		}
	}

	// Curves.
	if vals, ok = params[TlsUriParamCurve]; ok {
		curves = ParseTlsCurves(strings.Join(vals, ","))
	}

	tlsConf = &tls.Config{
		Certificates:       tlsCerts,
		RootCAs:            rootCAs,
		ServerName:         srvNm,
		InsecureSkipVerify: allowInvalid,
		CipherSuites:       ciphers,
		MinVersion:         minVer,
		MaxVersion:         maxVer,
		CurvePreferences:   curves,
	}

	return
}

/*
appendExistingFiles appends the contents of each file in filePaths, in order, to buf.

Files that do not exist are silently skipped. Any other read error is returned immediately.
*/
func appendExistingFiles(buf *bytes.Buffer, filePaths []string) (err error) {

	var b []byte

	for _, fp := range filePaths {
		if b, err = os.ReadFile(fp); err != nil {
			if errors.Is(err, os.ErrNotExist) {
				err = nil
				continue
			}
			return
		}
		buf.Write(b)
		// A file lacking a trailing newline would otherwise put the next file's "-----BEGIN" on the
		// same line as this file's "-----END ...-----", which encoding/pem rejects.
		if len(b) != 0 && b[len(b)-1] != '\n' {
			buf.WriteByte('\n')
		}
	}

	return
}
/*
ParseTlsVersion parses string s and attempts to derive a TLS version from it.

s may be either:

  - a version name. Matching is case-insensitive, underscores and a "v" are treated as separators,
    and the "TLS" prefix is optional; e.g. "TLS 1.2", "TLSv1.2", "tls1.2", "TLS_1.2", "v1.2", and "1.2"
    are all equivalent. SSL versions must carry an "SSL" prefix (e.g. "SSLv3", "SSL 3.0").
  - its numeric ID, either in plain decimal or with a base prefix (e.g. "0x0303").

Leading and trailing whitespace in s is ignored.

If s is numeric but is not a version known to crypto/tls, err will be ErrBadTlsVer.
If s is numeric but does not fit in a uint16, the strconv error is returned.
If s is a name that is not found, tlsVer will be 0 and err will be nil.
*/
func ParseTlsVersion(s string) (tlsVer uint16, err error) {

	var nm string
	var n uint64
	var i uint16
	var ok bool

	s = strings.TrimSpace(s)

	// Plain decimal is tried first so existing decimal inputs parse exactly as before;
	// only if that is not valid syntax are prefixed forms (0x..., 0o..., 0b...) attempted.
	n, err = strconv.ParseUint(s, 10, 16)
	if errors.Is(err, strconv.ErrSyntax) {
		n, err = strconv.ParseUint(s, 0, 16)
	}

	switch {
	case err == nil:
		// It's a number. tls.VersionName falls back to a "0x..." string for unknown versions.
		if strings.HasPrefix(tls.VersionName(uint16(n)), "0x") {
			err = ErrBadTlsVer
			return
		}
		tlsVer = uint16(n)
		return
	case errors.Is(err, strconv.ErrSyntax):
		// It's a name; parse below.
		err = nil
	default:
		// A number, but unusable (e.g. out of range).
		return
	}

	// If we get here, it should be parsed as a version string.
	// Reduce it to the bare version number, then rebuild it in the form tls.VersionName uses
	// ("TLS 1.2", "SSLv3") since those are the keys of the lookup map.
	s = strings.ToUpper(s)
	s = strings.NewReplacer("_", " ", "V", " ").Replace(s)
	s = strings.TrimSpace(s)
	switch {
	case strings.HasPrefix(s, "SSL"):
		s = strings.TrimSpace(strings.TrimPrefix(s, "SSL"))
		s = "SSLv" + strings.TrimSuffix(s, ".0")
	case strings.HasPrefix(s, "TLS"):
		s = "TLS " + strings.TrimSpace(strings.TrimPrefix(s, "TLS"))
	default:
		s = "TLS " + s
	}

	// We build a dynamic map of version names to uint16s (if not already created).
	if tlsVerNmToUint == nil {
		tlsVerNmToUint = make(map[string]uint16)
		for i = MinTlsVer; i <= MaxTlsVer; i++ {
			if nm = tls.VersionName(i); !strings.HasPrefix(nm, "0x") {
				tlsVerNmToUint[nm] = i
			}
		}
	}

	if i, ok = tlsVerNmToUint[s]; ok {
		tlsVer = i
	}

	return
}
/*
ParseCA extracts the CA certificates from PEM bytes certRaw. Concatenated PEM files are supported.

Only "CERTIFICATE" PEM blocks are considered (so keys etc. are ignored), and of those only certificates
flagged as a CA are kept (so leaf certificates are ignored).

The remaining certificates are split by comparing issuer and subject:

  - if they are equal (self-signed), the certificate is a root; it is returned in rootCerts and added to certPool.
  - otherwise it is an intermediate and is returned in intermediateCerts.

Ordering from certRaw is preserved in both slices.
If no roots are found, certPool will be nil.
If any certificate block fails to parse, that error is returned and nothing else is populated.
*/
func ParseCA(certRaw []byte) (certPool *x509.CertPool, rootCerts []*x509.Certificate, intermediateCerts []*x509.Certificate, err error) {

	var blocks []*pem.Block
	var parsed *x509.Certificate
	var cas []*x509.Certificate

	blocks, err = SplitPem(certRaw)
	if err != nil {
		return
	}

	// First pass: collect every CA cert. Nothing is classified until all blocks have parsed cleanly.
	for _, blk := range blocks {
		if blk.Type != "CERTIFICATE" {
			continue
		}
		parsed, err = x509.ParseCertificate(blk.Bytes)
		if err != nil {
			return
		}
		if parsed.IsCA {
			cas = append(cas, parsed)
		}
	}

	// Second pass: roots are self-signed (issuer == subject), everything else is an intermediate.
	for _, ca := range cas {
		if !bytes.Equal(ca.RawIssuer, ca.RawSubject) {
			intermediateCerts = append(intermediateCerts, ca)
			continue
		}
		rootCerts = append(rootCerts, ca)
	}

	if len(rootCerts) == 0 {
		return
	}

	certPool = x509.NewCertPool()
	for _, root := range rootCerts {
		certPool.AddCert(root)
	}

	return
}
/*
ParseLeafCert parses PEM bytes from a (client) certificate file, iterates over a slice of
crypto.PrivateKey (finding one that matches), and returns one (or more) tls.Certificate.

The key may also be combined with the certificate in the same file; any PEM block in certRaw whose type
contains "PRIVATE KEY" is parsed via ParsePrivateKey and considered in addition to keys.

If no private key matches a certificate, or no certificate is found in certRaw, tlsCerts will be nil/missing
that certificate but no error will be returned.
This filtering can be avoided by passing a nil slice to keys, in which case certificates are returned
even without a matching key (their tls.Certificate.PrivateKey will then be nil).

Any additional/supplementary intermediates may be provided. Any intermediates present in certRaw
(a CA certificate that is not self-signed) will be included as well. Every returned tls.Certificate's chain
consists of the certificate itself followed by all of these intermediates. Nil entries in intermediates are skipped.
Root CAs (self-signed CA certificates) are never added to the chain. They should/can be extracted separately via ParseCA.

The parsed and paired certificates and keys can be found in each respective tls.Certificate.Leaf and tls.Certificate.PrivateKey.

Neither keys nor intermediates is modified.
Any error from parsing, or from IsMatchedPair, is returned immediately.
*/
func ParseLeafCert(certRaw []byte, keys []crypto.PrivateKey, intermediates ...*x509.Certificate) (tlsCerts []tls.Certificate, err error) {

	var pemBlocks []*pem.Block
	var cert *x509.Certificate
	var certs []*x509.Certificate
	var parsedKeys []crypto.PrivateKey
	var allKeys []crypto.PrivateKey
	var isMatched bool
	var foundKey crypto.PrivateKey
	var interBytes [][]byte
	var skipKeyPair bool = keys == nil
	var parsedKeysBuf *bytes.Buffer = new(bytes.Buffer)

	if pemBlocks, err = SplitPem(certRaw); err != nil {
		return
	}

	// Caller-supplied intermediates lead the chain. They're collected into a fresh slice
	// (with zero length!) so the caller's backing array is never appended into.
	if len(intermediates) != 0 {
		interBytes = make([][]byte, 0, len(intermediates))
		for _, i := range intermediates {
			if i == nil {
				continue
			}
			interBytes = append(interBytes, i.Raw)
		}
	}

	for _, b := range pemBlocks {
		if strings.Contains(b.Type, "PRIVATE KEY") {
			// Set aside for ParsePrivateKey below.
			parsedKeysBuf.Write(pem.EncodeToMemory(b))
			continue
		}
		if b.Type != "CERTIFICATE" {
			continue
		}
		if cert, err = x509.ParseCertificate(b.Bytes); err != nil {
			return
		}
		if cert.IsCA && !bytes.Equal(cert.RawIssuer, cert.RawSubject) {
			// An intermediate bundled in the file; it follows the caller-supplied ones in the chain.
			interBytes = append(interBytes, cert.Raw)
		}
		// NOTE(review): CA certs (roots included) are still kept as candidates here, same as before;
		// they only end up in tlsCerts if a key matches them (or if keys is nil).
		// That keeps self-signed certs flagged as a CA usable, but confirm it's intended for nil keys.
		certs = append(certs, cert)
	}

	// Likewise, combine the keys into a fresh slice rather than appending to the caller's.
	allKeys = make([]crypto.PrivateKey, 0, len(keys))
	allKeys = append(allKeys, keys...)
	if parsedKeysBuf.Len() != 0 {
		if parsedKeys, err = ParsePrivateKey(parsedKeysBuf.Bytes()); err != nil {
			return
		}
		allKeys = append(allKeys, parsedKeys...)
	}

	// Now pair the certs and keys, and combine as a tls.Certificate.
	for _, cert = range certs {
		foundKey = nil
		for _, k := range allKeys {
			if isMatched, err = IsMatchedPair(k, cert); err != nil {
				return
			}
			if isMatched {
				foundKey = k
				break
			}
		}
		if foundKey == nil && !skipKeyPair {
			continue
		}
		tlsCerts = append(
			tlsCerts,
			tls.Certificate{
				Certificate: append([][]byte{cert.Raw}, interBytes...),
				PrivateKey:  foundKey,
				Leaf:        cert,
			},
		)
	}

	return
}
/*
ParsePrivateKey parses every private key found in PEM bytes keyRaw. Multiple keys may be concatenated in the same file.

PEM blocks whose type does not contain "PRIVATE KEY" (public keys, certificates, etc.) are discarded.
The remaining blocks are parsed according to their type:

  - "RSA PRIVATE KEY": PKCS#1
  - "EC PRIVATE KEY": SEC 1, ASN.1 DER
  - "PRIVATE KEY": PKCS#8

Any other type containing "PRIVATE KEY" results in ErrUnknownKey.

Parsing stops at the first error; keys will then hold the keys parsed up to that point.
*/
func ParsePrivateKey(keyRaw []byte) (keys []crypto.PrivateKey, err error) {

	var blocks []*pem.Block

	blocks, err = SplitPem(keyRaw)
	if err != nil {
		return
	}

	for _, blk := range blocks {
		var key crypto.PrivateKey

		switch {
		case !strings.Contains(blk.Type, "PRIVATE KEY"):
			// Not a private key at all; skip it.
			continue
		case blk.Type == "RSA PRIVATE KEY":
			key, err = x509.ParsePKCS1PrivateKey(blk.Bytes)
		case blk.Type == "EC PRIVATE KEY":
			key, err = x509.ParseECPrivateKey(blk.Bytes)
		case blk.Type == "PRIVATE KEY":
			key, err = x509.ParsePKCS8PrivateKey(blk.Bytes)
		default:
			err = ErrUnknownKey
		}

		if err != nil {
			return
		}
		keys = append(keys, key)
	}

	return
}
/*
SplitPem splits pemRaw into its individual (encoding/)pem.Blocks, in the order they appear.

Decoding stops at the first point where pem.Decode finds no further block.
If pemRaw contains no PEM blocks, blocks will be nil.

err is never set by this function; it is always nil.
*/
func SplitPem(pemRaw []byte) (blocks []*pem.Block, err error) {

	var blk *pem.Block
	var remaining []byte = pemRaw

	for {
		if blk, remaining = pem.Decode(remaining); blk == nil {
			break
		}
		blocks = append(blocks, blk)
	}

	return
}