373 lines
8.1 KiB
Go
373 lines
8.1 KiB
Go
package internal
|
|
|
|
import (
|
|
`context`
|
|
`errors`
|
|
`fmt`
|
|
`log`
|
|
`os`
|
|
`strings`
|
|
`sync`
|
|
|
|
`github.com/hashicorp/vault-client-go`
|
|
`github.com/hashicorp/vault-client-go/schema`
|
|
`github.com/jessevdk/go-flags`
|
|
`golang.org/x/term`
|
|
`r00t2.io/goutils/logging`
|
|
`r00t2.io/goutils/multierr`
|
|
`r00t2.io/vault_totp/errs`
|
|
`r00t2.io/vault_totp/version`
|
|
)
|
|
|
|
/*
|
|
GetTotpKey fetches the key info for a key named `keyNm` at TOTP secrets mountpoint `mntPt` using Vault client `vc`.
|
|
|
|
If `mntPt` is empty, it will be set to "totp".
|
|
*/
|
|
func GetTotpKey(ctx context.Context, keyNm, mntPt string, vc *vault.Client) (kinfo map[string]any, err error) {
|
|
|
|
var resp *vault.Response[map[string]interface{}]
|
|
|
|
if strings.TrimSpace(mntPt) == "" {
|
|
mntPt = "totp"
|
|
}
|
|
|
|
if resp, err = vc.Secrets.TotpReadKey(
|
|
ctx,
|
|
keyNm,
|
|
vault.WithMountPath(mntPt),
|
|
); err != nil {
|
|
return
|
|
}
|
|
|
|
kinfo = resp.Data
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
GetTotpKeys is like [ListTotpKeys] except it returns the configuration info
|
|
for each key as well. (Except the secret - https://github.com/hashicorp/vault/issues/3043)
|
|
|
|
keyNms, if not specified, will fetch info for all keys on the mountpoint `mntPt`.
|
|
*/
|
|
func GetTotpKeys(ctx context.Context, vc *vault.Client, mntPt string, keyNms ...string) (keyInfo map[string]map[string]any, err error) {
|
|
|
|
var totpNm string
|
|
var mut sync.Mutex
|
|
var wg sync.WaitGroup
|
|
var errChan chan error
|
|
var mErr *multierr.MultiError = multierr.NewMultiError(nil)
|
|
var opts []vault.RequestOption = make([]vault.RequestOption, 0)
|
|
var listResp *vault.Response[schema.StandardListResponse]
|
|
var respErr *vault.ResponseError = new(vault.ResponseError)
|
|
|
|
if strings.TrimSpace(mntPt) != "" {
|
|
opts = append(opts, vault.WithMountPath(mntPt))
|
|
}
|
|
|
|
if keyNms == nil || len(keyNms) == 0 {
|
|
if listResp, err = vc.Secrets.TotpListKeys(
|
|
ctx,
|
|
opts...,
|
|
); err != nil {
|
|
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
|
|
// Is OK; no keys exist yet.
|
|
keyInfo = make(map[string]map[string]any)
|
|
err = nil
|
|
}
|
|
return
|
|
}
|
|
keyNms = listResp.Data.Keys
|
|
}
|
|
|
|
keyInfo = make(map[string]map[string]any)
|
|
if keyNms == nil || len(keyNms) == 0 {
|
|
return
|
|
}
|
|
errChan = make(chan error, len(keyNms))
|
|
for _, totpNm = range keyNms {
|
|
wg.Add(1)
|
|
go getTotpKeyAsync(ctx, totpNm, mntPt, vc, &mut, errChan, &wg, keyInfo)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
for err = range errChan {
|
|
if err != nil {
|
|
mErr.AddError(err)
|
|
err = nil
|
|
}
|
|
}
|
|
if !mErr.IsEmpty() {
|
|
err = mErr
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// GetVaultClient returns a Vault client from the provided args.
|
|
func GetVaultClient(args *VaultArgs) (c *vault.Client, err error) {
|
|
|
|
var tok string
|
|
var vc *vault.Client
|
|
var opts []vault.ClientOption
|
|
var vaultTls vault.TLSConfiguration
|
|
|
|
if args == nil {
|
|
err = errs.ErrNilVault
|
|
return
|
|
}
|
|
|
|
tok = args.Token
|
|
if err = GetVaultToken(&tok); err != nil {
|
|
return
|
|
}
|
|
|
|
opts = []vault.ClientOption{vault.WithAddress(args.Addr)}
|
|
if args.Insecure || args.SniName != nil {
|
|
vaultTls = vault.TLSConfiguration{
|
|
InsecureSkipVerify: args.Insecure,
|
|
}
|
|
if args.SniName != nil {
|
|
vaultTls.ServerName = *args.SniName
|
|
}
|
|
opts = append(
|
|
opts,
|
|
vault.WithTLS(vaultTls),
|
|
)
|
|
}
|
|
|
|
if vc, err = vault.New(opts...); err != nil {
|
|
return
|
|
}
|
|
if err = vc.SetToken(tok); err != nil {
|
|
return
|
|
}
|
|
|
|
c = vc
|
|
|
|
return
|
|
}
|
|
|
|
// GetVaultToken standardizes the Vault token fetching/lookup.
|
|
func GetVaultToken(tok *string) (err error) {
|
|
|
|
var p1 []byte
|
|
var oldState *term.State
|
|
var termFd int = int(os.Stdin.Fd())
|
|
|
|
if tok != nil && len(strings.TrimSpace(*tok)) > 0 {
|
|
return
|
|
}
|
|
|
|
if oldState, err = term.GetState(termFd); err != nil {
|
|
return
|
|
}
|
|
fmt.Println("Vault token needed.\nVault token (will not be echoed back):")
|
|
defer func() {
|
|
if err = term.Restore(termFd, oldState); err != nil {
|
|
log.Println("restore failed:", err)
|
|
}
|
|
}()
|
|
|
|
if p1, err = term.ReadPassword(termFd); err != nil {
|
|
return
|
|
}
|
|
|
|
if tok == nil {
|
|
tok = new(string)
|
|
}
|
|
*tok = string(p1)
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
ListTotpKeys returns a list (map, really, for easier lookup) of TOTP key names at `mntpt`
|
|
with [github.com/hashicorp/vault-client-go.Client] `vc` and [context.Context] `ctx`.
|
|
If `mntpt` is empty, the default ("totp") will be used.
|
|
|
|
If no TOTP keys are found at the mount, `keyNms` will be empty but not nil.
|
|
|
|
See [ListTotpKeys] if you want additional information about each key.
|
|
*/
|
|
func ListTotpKeys(ctx context.Context, vc *vault.Client, mntPt string) (keyNms map[string]struct{}, err error) {
|
|
|
|
var totpNm string
|
|
var opts []vault.RequestOption = make([]vault.RequestOption, 0)
|
|
var listResp *vault.Response[schema.StandardListResponse]
|
|
var respErr *vault.ResponseError = new(vault.ResponseError)
|
|
|
|
if strings.TrimSpace(mntPt) != "" {
|
|
opts = append(opts, vault.WithMountPath(mntPt))
|
|
}
|
|
|
|
if listResp, err = vc.Secrets.TotpListKeys(
|
|
ctx,
|
|
opts...,
|
|
); err != nil {
|
|
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
|
|
// Is OK; no keys exist yet.
|
|
keyNms = make(map[string]struct{})
|
|
err = nil
|
|
} else {
|
|
return
|
|
}
|
|
} else {
|
|
keyNms = make(map[string]struct{})
|
|
for _, totpNm = range listResp.Data.Keys {
|
|
keyNms[totpNm] = struct{}{}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// PrepParser properly initializes the parser and logger in a standardized way.
|
|
func PrepParser(cmd string, args CommonArgs, p *flags.Parser) (doExit bool, err error) {
|
|
|
|
var logFlagsRuntime int = logFlags
|
|
var flagsErr *flags.Error = new(flags.Error)
|
|
|
|
p.NamespaceDelimiter = ParseNsDelim
|
|
p.EnvNamespaceDelimiter = ParseEnvNsDelim
|
|
|
|
if _, err = p.Parse(); err != nil {
|
|
switch {
|
|
case errors.As(err, &flagsErr):
|
|
switch {
|
|
case errors.Is(flagsErr.Type, flags.ErrHelp),
|
|
errors.Is(flagsErr.Type, flags.ErrCommandRequired),
|
|
errors.Is(flagsErr.Type, flags.ErrRequired):
|
|
// These print their relevant messages by themselves.
|
|
err = nil
|
|
return
|
|
default:
|
|
return
|
|
}
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
|
|
if version.Ver, err = version.Version(); err != nil {
|
|
return
|
|
}
|
|
|
|
// If args.Version or args.DetailVersion are true, just print them and exit.
|
|
if args.DetailVersion || args.Version {
|
|
doExit = true
|
|
if args.Version {
|
|
fmt.Println(version.Ver.Short())
|
|
return
|
|
} else if args.DetailVersion {
|
|
fmt.Println(version.Ver.Detail())
|
|
return
|
|
}
|
|
}
|
|
|
|
if args.DoDebug {
|
|
logFlagsRuntime = logFlagsDebug
|
|
}
|
|
Logger = logging.GetMultiLogger(
|
|
args.DoDebug,
|
|
fmt.Sprintf(
|
|
"Vault TOTP [%s_%s]",
|
|
cmdPfx, cmd,
|
|
),
|
|
)
|
|
if err = Logger.AddDefaultLogger(
|
|
"default",
|
|
logFlagsRuntime,
|
|
"/var/log/vault_totp/vault_totp.log", "~/logs/vault_totp.log",
|
|
); err != nil {
|
|
log.Panicln(err)
|
|
}
|
|
if err = Logger.Setup(); err != nil {
|
|
log.Panicln(err)
|
|
}
|
|
Logger.Info("main: Vault TOTP version %v", version.Ver.Short())
|
|
Logger.Debug("main: Vault TOTP version (extended):\n%v", version.Ver.Detail())
|
|
|
|
return
|
|
}
|
|
|
|
// SplitVaultPathspec splits a spec of the form <mount>[:<path>] into its mount and path.
//
// Only the first ":" separates the two parts, so the path may itself contain colons.
// When there is no ":" or the path portion is blank (empty or whitespace-only),
// secretPath is "/". A non-blank path is returned exactly as given (it is not trimmed).
func SplitVaultPathspec(spec string) (mount, secretPath string) {
	before, after, found := strings.Cut(spec, ":")

	mount = before
	secretPath = "/"
	if found && strings.TrimSpace(after) != "" {
		secretPath = after
	}

	return mount, secretPath
}
|
|
|
|
// SplitVaultPathspec2 splits a spec of the form [<mount>:]<path> into its mount and path.
//
// Only the first ":" separates the two parts. Without a ":", the whole spec is the path
// and mount is empty (the path is returned as-is, even if blank). With a ":", a blank
// (empty or whitespace-only) path portion becomes "/"; otherwise it is returned untrimmed.
func SplitVaultPathspec2(spec string) (mount, secretPath string) {
	head, tail, hasMount := strings.Cut(spec, ":")
	if !hasMount {
		return "", spec
	}
	if strings.TrimSpace(tail) == "" {
		return head, "/"
	}
	return head, tail
}
|
|
|
|
// Validate validates the passed struct `s`. A nil err means validation succeeded.
|
|
func Validate(s any) (err error) {
|
|
|
|
if err = validate.Struct(s); err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// getTotpKeyAsync fetches key info for key named `nm` from Vault `vc`.
|
|
func getTotpKeyAsync(
|
|
ctx context.Context,
|
|
keyNm, mntPt string,
|
|
vc *vault.Client,
|
|
mut *sync.Mutex, errChan chan error, wg *sync.WaitGroup,
|
|
m map[string]map[string]any,
|
|
) {
|
|
|
|
var err error
|
|
var kinfo map[string]any
|
|
|
|
defer wg.Done()
|
|
|
|
if kinfo, err = GetTotpKey(ctx, keyNm, mntPt, vc); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
|
|
// We can wait to hold the lock until we're actually inserting into the map.
|
|
mut.Lock()
|
|
defer mut.Unlock()
|
|
m[keyNm] = kinfo
|
|
|
|
return
|
|
}
|