1
0
Files
vault_totp/internal/funcs.go
2025-12-23 20:58:56 -05:00

373 lines
8.1 KiB
Go

package internal
import (
`context`
`errors`
`fmt`
`log`
`os`
`strings`
`sync`
`github.com/hashicorp/vault-client-go`
`github.com/hashicorp/vault-client-go/schema`
`github.com/jessevdk/go-flags`
`golang.org/x/term`
`r00t2.io/goutils/logging`
`r00t2.io/goutils/multierr`
`r00t2.io/vault_totp/errs`
`r00t2.io/vault_totp/version`
)
/*
GetTotpKey reads the configuration of the TOTP key `keyNm` from the TOTP secrets
engine mounted at `mntPt`, using Vault client `vc` and context `ctx`.

A blank (empty or whitespace-only) `mntPt` is replaced with "totp".

On success, `kinfo` is the Data map of the Vault response.
On failure, `kinfo` is nil and `err` is whatever the client returned.
*/
func GetTotpKey(ctx context.Context, keyNm, mntPt string, vc *vault.Client) (kinfo map[string]any, err error) {
	mount := mntPt
	if strings.TrimSpace(mount) == "" {
		mount = "totp"
	}

	resp, err := vc.Secrets.TotpReadKey(ctx, keyNm, vault.WithMountPath(mount))
	if err != nil {
		return nil, err
	}

	return resp.Data, nil
}
/*
GetTotpKeys is like [ListTotpKeys] except it returns the configuration info
for each key as well. (Except the secret - https://github.com/hashicorp/vault/issues/3043)

If `keyNms` is not specified, info is fetched for every key listed on the mountpoint
`mntPt`. If `mntPt` is blank, no mount path option is passed to the list call, and the
per-key reads fall back to GetTotpKey's "totp" default.

If listing returns HTTP 404 (no keys exist yet), `keyInfo` is an empty, non-nil map and
`err` is nil. Any other listing error is returned with a nil `keyInfo`.

Each key is fetched in its own goroutine. Per-key failures are collected into a
*multierr.MultiError, which is returned as `err`. `keyInfo` still contains every key
that was fetched successfully.
*/
func GetTotpKeys(ctx context.Context, vc *vault.Client, mntPt string, keyNms ...string) (keyInfo map[string]map[string]any, err error) {
	var mut sync.Mutex
	var wg sync.WaitGroup
	var errChan chan error
	var mErr *multierr.MultiError = multierr.NewMultiError(nil)
	var opts []vault.RequestOption
	var listResp *vault.Response[schema.StandardListResponse]
	var respErr *vault.ResponseError = new(vault.ResponseError)

	if strings.TrimSpace(mntPt) != "" {
		opts = append(opts, vault.WithMountPath(mntPt))
	}

	// len() of a nil slice is 0, so a single check covers "not specified".
	if len(keyNms) == 0 {
		if listResp, err = vc.Secrets.TotpListKeys(ctx, opts...); err != nil {
			if errors.As(err, &respErr) && respErr.StatusCode == 404 {
				// Is OK; no keys exist yet.
				keyInfo = make(map[string]map[string]any)
				err = nil
			}
			return
		}
		keyNms = listResp.Data.Keys
	}

	keyInfo = make(map[string]map[string]any, len(keyNms))
	if len(keyNms) == 0 {
		return
	}

	// Buffered to len(keyNms): each worker sends at most one error, so no send
	// blocks and wg.Wait() cannot stall before the channel is drained.
	errChan = make(chan error, len(keyNms))
	for _, totpNm := range keyNms {
		wg.Add(1)
		go getTotpKeyAsync(ctx, totpNm, mntPt, vc, &mut, errChan, &wg, keyInfo)
	}
	wg.Wait()
	close(errChan)

	for e := range errChan {
		if e != nil {
			mErr.AddError(e)
		}
	}
	if !mErr.IsEmpty() {
		err = mErr
	}
	return
}
/*
GetVaultClient builds a Vault client from `args`.

The token is taken from args.Token. If that is blank, GetVaultToken is used to obtain one.

A TLS configuration is attached only when args.Insecure is set or args.SniName is non-nil.
args.SniName, when present, becomes the TLS server name.

errs.ErrNilVault is returned for a nil `args`. On any failure the returned client is nil.
*/
func GetVaultClient(args *VaultArgs) (c *vault.Client, err error) {
	if args == nil {
		return nil, errs.ErrNilVault
	}

	token := args.Token
	if err = GetVaultToken(&token); err != nil {
		return nil, err
	}

	clientOpts := []vault.ClientOption{vault.WithAddress(args.Addr)}
	if args.Insecure || args.SniName != nil {
		tlsCfg := vault.TLSConfiguration{InsecureSkipVerify: args.Insecure}
		if args.SniName != nil {
			tlsCfg.ServerName = *args.SniName
		}
		clientOpts = append(clientOpts, vault.WithTLS(tlsCfg))
	}

	var client *vault.Client
	if client, err = vault.New(clientOpts...); err != nil {
		return nil, err
	}
	if err = client.SetToken(token); err != nil {
		return nil, err
	}

	return client, nil
}
/*
GetVaultToken standardizes the Vault token fetching/lookup.

If `tok` points to a non-blank string, nothing happens and nil is returned.

Otherwise:
  - The current state of the terminal on stdin is captured. If that fails, its error is returned.
  - A prompt is printed to stdout.
  - The token is read from stdin without echo.
  - The captured terminal state is restored when the function returns.

If `tok` is nil, the prompt and read still happen, but the entered token has nowhere
to be stored and is discarded.

An error from reading the token is always returned as-is. A failure to restore the
terminal is logged, and is returned only if no earlier error occurred.
*/
func GetVaultToken(tok *string) (err error) {
	var p1 []byte
	var oldState *term.State
	var termFd int = int(os.Stdin.Fd())

	if tok != nil && len(strings.TrimSpace(*tok)) > 0 {
		return
	}
	if oldState, err = term.GetState(termFd); err != nil {
		return
	}
	fmt.Println("Vault token needed.\nVault token (will not be echoed back):")
	defer func() {
		// Use a separate variable: assigning Restore's result straight to the
		// named `err` would replace (and so swallow) a ReadPassword error.
		if restoreErr := term.Restore(termFd, oldState); restoreErr != nil {
			log.Println("restore failed:", restoreErr)
			if err == nil {
				err = restoreErr
			}
		}
	}()
	if p1, err = term.ReadPassword(termFd); err != nil {
		return
	}
	// Reassigning a nil `tok` locally would not reach the caller, so only store
	// when there is somewhere to store to.
	if tok != nil {
		*tok = string(p1)
	}
	return
}
/*
ListTotpKeys returns a list (map, really, for easier lookup) of TOTP key names at `mntPt`
with [github.com/hashicorp/vault-client-go.Client] `vc` and [context.Context] `ctx`.

If `mntPt` is blank, no mount path option is passed and the default ("totp") will be used.

If no TOTP keys are found at the mount (Vault answers HTTP 404), `keyNms` will be empty
but not nil, and `err` will be nil. Any other listing error is returned with a nil `keyNms`.

See [GetTotpKeys] if you want additional information about each key.
*/
func ListTotpKeys(ctx context.Context, vc *vault.Client, mntPt string) (keyNms map[string]struct{}, err error) {
	var opts []vault.RequestOption
	var listResp *vault.Response[schema.StandardListResponse]
	var respErr *vault.ResponseError = new(vault.ResponseError)

	if strings.TrimSpace(mntPt) != "" {
		opts = append(opts, vault.WithMountPath(mntPt))
	}

	if listResp, err = vc.Secrets.TotpListKeys(ctx, opts...); err != nil {
		if errors.As(err, &respErr) && respErr.StatusCode == 404 {
			// Is OK; no keys exist yet.
			keyNms = make(map[string]struct{})
			err = nil
		}
		return
	}

	keyNms = make(map[string]struct{}, len(listResp.Data.Keys))
	for _, totpNm := range listResp.Data.Keys {
		keyNms[totpNm] = struct{}{}
	}
	return
}
/*
PrepParser properly initializes the parser and logger in a standardized way.

It performs these steps in order:
  - Sets the namespace delimiters on `p`.
  - Parses the command line.
  - Populates version.Ver.
  - Prints the version and exits if requested.
  - Configures the package-level Logger for subcommand `cmd`.

doExit is true when the caller has nothing left to do. That covers two cases:
  - go-flags reported help, a missing required command, or a missing required option.
    go-flags prints those messages itself, so err is nil.
  - A version flag was handled. Version takes precedence over DetailVersion.

Other parse errors, version lookup errors, and logger setup errors are returned in err.
*/
func PrepParser(cmd string, args CommonArgs, p *flags.Parser) (doExit bool, err error) {
	var logFlagsRuntime int = logFlags
	var flagsErr *flags.Error = new(flags.Error)

	p.NamespaceDelimiter = ParseNsDelim
	p.EnvNamespaceDelimiter = ParseEnvNsDelim

	// NOTE(review): `args` is a value copy made when this function was called,
	// i.e. before p.Parse() runs below. If `p` was built over the caller's
	// CommonArgs, values filled in by p.Parse() are not visible through this
	// copy. Confirm how call sites construct `p` and `args`.
	if _, err = p.Parse(); err != nil {
		if errors.As(err, &flagsErr) &&
			(errors.Is(flagsErr.Type, flags.ErrHelp) ||
				errors.Is(flagsErr.Type, flags.ErrCommandRequired) ||
				errors.Is(flagsErr.Type, flags.ErrRequired)) {
			// These print their relevant messages by themselves. There is nothing
			// valid to run afterwards, so tell the caller to exit.
			doExit = true
			err = nil
		}
		return
	}

	if version.Ver, err = version.Version(); err != nil {
		return
	}

	// If args.Version or args.DetailVersion are true, just print them and exit.
	switch {
	case args.Version:
		doExit = true
		fmt.Println(version.Ver.Short())
		return
	case args.DetailVersion:
		doExit = true
		fmt.Println(version.Ver.Detail())
		return
	}

	if args.DoDebug {
		logFlagsRuntime = logFlagsDebug
	}
	Logger = logging.GetMultiLogger(
		args.DoDebug,
		fmt.Sprintf(
			"Vault TOTP [%s_%s]",
			cmdPfx, cmd,
		),
	)
	// Report logger setup failures through the existing error return instead of panicking.
	if err = Logger.AddDefaultLogger(
		"default",
		logFlagsRuntime,
		"/var/log/vault_totp/vault_totp.log", "~/logs/vault_totp.log",
	); err != nil {
		err = fmt.Errorf("adding default logger: %w", err)
		return
	}
	if err = Logger.Setup(); err != nil {
		err = fmt.Errorf("setting up logger: %w", err)
		return
	}
	Logger.Info("main: Vault TOTP version %v", version.Ver.Short())
	Logger.Debug("main: Vault TOTP version (extended):\n%v", version.Ver.Detail())
	return
}
/*
SplitVaultPathspec splits a <mount>[:<path>] specifier at its first ":".

Everything before the colon is the mount, returned verbatim (untrimmed, possibly empty).
Everything after the colon is the path, and may itself contain further colons.

If there is no colon, or the text after it is blank (empty or whitespace-only),
the path is "/".
*/
func SplitVaultPathspec(spec string) (mount, secretPath string) {
	var rest string
	mount, rest, _ = strings.Cut(spec, ":")
	if strings.TrimSpace(rest) == "" {
		secretPath = "/"
		return
	}
	secretPath = rest
	return
}
/*
SplitVaultPathspec2 splits a [<mount>:]<path> specifier at its first ":".

If there is no colon, the mount is empty and the whole spec is returned verbatim as the
path. No "/" default is applied in that case, even for an empty spec.

If there is a colon, the text before it is the mount. The text after it is the path,
or "/" when that text is blank (empty or whitespace-only).
*/
func SplitVaultPathspec2(spec string) (mount, secretPath string) {
	before, after, hasMount := strings.Cut(spec, ":")
	if !hasMount {
		return "", spec
	}
	if strings.TrimSpace(after) == "" {
		after = "/"
	}
	return before, after
}
// Validate runs `s` through validate.Struct, the package-level struct validator,
// and hands back its result unchanged. A nil err means validation succeeded.
func Validate(s any) (err error) {
	return validate.Struct(s)
}
/*
getTotpKeyAsync fetches key info for the key named `keyNm` at mountpoint `mntPt`
from Vault `vc`, via GetTotpKey. It is launched as one goroutine per key by GetTotpKeys.

On success, the info is stored in `m[keyNm]` while holding `mut`.

On failure, `m` is left untouched. Exactly one error, wrapped with the key name, is sent
on `errChan`. The send blocks if the channel has no free buffer and no reader, so the
caller must size it accordingly.

`wg.Done()` is always called on return.
*/
func getTotpKeyAsync(
	ctx context.Context,
	keyNm, mntPt string,
	vc *vault.Client,
	mut *sync.Mutex, errChan chan<- error, wg *sync.WaitGroup,
	m map[string]map[string]any,
) {
	var err error
	var kinfo map[string]any

	defer wg.Done()

	if kinfo, err = GetTotpKey(ctx, keyNm, mntPt, vc); err != nil {
		// Keys are fetched concurrently and aggregated, so name the failing key.
		// %w keeps the underlying error (e.g. a *vault.ResponseError) reachable.
		errChan <- fmt.Errorf("reading TOTP key %q: %w", keyNm, err)
		return
	}

	// We can wait to hold the lock until we're actually inserting into the map.
	mut.Lock()
	defer mut.Unlock()
	m[keyNm] = kinfo
}