checking in
This commit is contained in:
@@ -2,102 +2,371 @@ package internal
|
||||
|
||||
import (
|
||||
`context`
|
||||
`errors`
|
||||
`fmt`
|
||||
`log`
|
||||
`os`
|
||||
`strings`
|
||||
`sync`
|
||||
|
||||
`github.com/hashicorp/vault-client-go`
|
||||
`r00t2.io/gosecret`
|
||||
`github.com/hashicorp/vault-client-go/schema`
|
||||
`github.com/jessevdk/go-flags`
|
||||
`golang.org/x/term`
|
||||
`r00t2.io/goutils/logging`
|
||||
`r00t2.io/goutils/multierr`
|
||||
`r00t2.io/vault_totp/errs`
|
||||
`r00t2.io/vault_totp/version`
|
||||
)
|
||||
|
||||
func New(vaultTok, vaultAddr, vaultMnt, collNm string) (c *Client, err error) {
|
||||
/*
|
||||
GetTotpKey fetches the key info for a key named `keyNm` at TOTP secrets mountpoint `mntPt` using Vault client `vc`.
|
||||
|
||||
c = &Client{
|
||||
// lastIdx: 0,
|
||||
vtok: vaultTok,
|
||||
vaddr: vaultAddr,
|
||||
scollNm: collNm,
|
||||
vmnt: vaultMnt,
|
||||
errsDone: make(chan bool, 1),
|
||||
errChan: make(chan error),
|
||||
// vc: nil,
|
||||
wg: sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
// ssvc: nil,
|
||||
// scoll: nil,
|
||||
mErr: multierr.NewMultiError(nil),
|
||||
// inSS: nil,
|
||||
// inVault: nil,
|
||||
If `mntPt` is empty, it will be set to "totp".
|
||||
*/
|
||||
func GetTotpKey(ctx context.Context, keyNm, mntPt string, vc *vault.Client) (kinfo map[string]any, err error) {
|
||||
|
||||
var resp *vault.Response[map[string]interface{}]
|
||||
|
||||
if strings.TrimSpace(mntPt) == "" {
|
||||
mntPt = "totp"
|
||||
}
|
||||
|
||||
if c.vc, err = vault.New(vault.WithAddress(c.vaddr)); err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.vc.SetToken(c.vtok); err != nil {
|
||||
if resp, err = vc.Secrets.TotpReadKey(
|
||||
ctx,
|
||||
keyNm,
|
||||
vault.WithMountPath(mntPt),
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.ssvc, err = gosecret.NewService(); err != nil {
|
||||
return
|
||||
}
|
||||
if c.scoll, err = c.ssvc.GetCollection(collNm); err != nil {
|
||||
return
|
||||
kinfo = resp.Data
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
GetTotpKeys is like [ListTotpKeys] except it returns the configuration info
|
||||
for each key as well. (Except the secret - https://github.com/hashicorp/vault/issues/3043)
|
||||
|
||||
keyNms, if not specified, will fetch info for all keys on the mountpoint `mntPt`.
|
||||
*/
|
||||
func GetTotpKeys(ctx context.Context, vc *vault.Client, mntPt string, keyNms ...string) (keyInfo map[string]map[string]any, err error) {
|
||||
|
||||
var totpNm string
|
||||
var mut sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var errChan chan error
|
||||
var mErr *multierr.MultiError = multierr.NewMultiError(nil)
|
||||
var opts []vault.RequestOption = make([]vault.RequestOption, 0)
|
||||
var listResp *vault.Response[schema.StandardListResponse]
|
||||
var respErr *vault.ResponseError = new(vault.ResponseError)
|
||||
|
||||
if strings.TrimSpace(mntPt) != "" {
|
||||
opts = append(opts, vault.WithMountPath(mntPt))
|
||||
}
|
||||
|
||||
go c.readErrs()
|
||||
if keyNms == nil || len(keyNms) == 0 {
|
||||
if listResp, err = vc.Secrets.TotpListKeys(
|
||||
ctx,
|
||||
opts...,
|
||||
); err != nil {
|
||||
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
|
||||
// Is OK; no keys exist yet.
|
||||
keyInfo = make(map[string]map[string]any)
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
keyNms = listResp.Data.Keys
|
||||
}
|
||||
|
||||
c.wg.Add(2)
|
||||
go c.getSS()
|
||||
go c.getVault()
|
||||
keyInfo = make(map[string]map[string]any)
|
||||
if keyNms == nil || len(keyNms) == 0 {
|
||||
return
|
||||
}
|
||||
errChan = make(chan error, len(keyNms))
|
||||
for _, totpNm = range keyNms {
|
||||
wg.Add(1)
|
||||
go getTotpKeyAsync(ctx, totpNm, mntPt, vc, &mut, errChan, &wg, keyInfo)
|
||||
}
|
||||
|
||||
c.wg.Wait()
|
||||
|
||||
if !c.mErr.IsEmpty() {
|
||||
err = c.mErr
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for err = range errChan {
|
||||
if err != nil {
|
||||
mErr.AddError(err)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if !mErr.IsEmpty() {
|
||||
err = mErr
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func normalizeVaultNm(nm string) (normalized string) {
|
||||
// GetVaultClient returns a Vault client from the provided args.
|
||||
func GetVaultClient(args *VaultArgs) (c *vault.Client, err error) {
|
||||
|
||||
var c rune
|
||||
var idx int
|
||||
var last rune
|
||||
var repl rune = '_'
|
||||
var reduced []rune = make([]rune, 0)
|
||||
var norm []rune = make([]rune, 0, len(nm))
|
||||
var tok string
|
||||
var vc *vault.Client
|
||||
var opts []vault.ClientOption
|
||||
var vaultTls vault.TLSConfiguration
|
||||
|
||||
for _, c = range nm {
|
||||
// If it's "safe" chars, it's fine
|
||||
if (c == '-' || c == '.') || // 0x2d, 0x2e
|
||||
(c >= '0' && c <= '9') || // 0x30 to 0x39
|
||||
(c == '@') || // 0x40
|
||||
(c >= 'A' && c <= 'Z') || // 0x41 to 0x5a
|
||||
(c == '_') || // 0x5f
|
||||
(c >= 'a' && c <= 'z') { // 0x61 to 0x7a
|
||||
norm = append(norm, c)
|
||||
continue
|
||||
}
|
||||
// Otherwise normalize it to a safe char
|
||||
norm = append(norm, repl)
|
||||
if args == nil {
|
||||
err = errs.ErrNilVault
|
||||
return
|
||||
}
|
||||
|
||||
// And remove repeating sequential replacers.
|
||||
for idx, c = range norm[:] {
|
||||
if idx == 0 {
|
||||
last = c
|
||||
reduced = append(reduced, c)
|
||||
continue
|
||||
}
|
||||
if c == last && last == repl {
|
||||
continue
|
||||
}
|
||||
reduced = append(reduced, c)
|
||||
last = c
|
||||
tok = args.Token
|
||||
if err = GetVaultToken(&tok); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
normalized = string(reduced)
|
||||
opts = []vault.ClientOption{vault.WithAddress(args.Addr)}
|
||||
if args.Insecure || args.SniName != nil {
|
||||
vaultTls = vault.TLSConfiguration{
|
||||
InsecureSkipVerify: args.Insecure,
|
||||
}
|
||||
if args.SniName != nil {
|
||||
vaultTls.ServerName = *args.SniName
|
||||
}
|
||||
opts = append(
|
||||
opts,
|
||||
vault.WithTLS(vaultTls),
|
||||
)
|
||||
}
|
||||
|
||||
if vc, err = vault.New(opts...); err != nil {
|
||||
return
|
||||
}
|
||||
if err = vc.SetToken(tok); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c = vc
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetVaultToken standardizes the Vault token fetching/lookup.
|
||||
func GetVaultToken(tok *string) (err error) {
|
||||
|
||||
var p1 []byte
|
||||
var oldState *term.State
|
||||
var termFd int = int(os.Stdin.Fd())
|
||||
|
||||
if tok != nil && len(strings.TrimSpace(*tok)) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if oldState, err = term.GetState(termFd); err != nil {
|
||||
return
|
||||
}
|
||||
fmt.Println("Vault token needed.\nVault token (will not be echoed back):")
|
||||
defer func() {
|
||||
if err = term.Restore(termFd, oldState); err != nil {
|
||||
log.Println("restore failed:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if p1, err = term.ReadPassword(termFd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tok == nil {
|
||||
tok = new(string)
|
||||
}
|
||||
*tok = string(p1)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
ListTotpKeys returns a list (map, really, for easier lookup) of TOTP key names at `mntpt`
|
||||
with [github.com/hashicorp/vault-client-go.Client] `vc` and [context.Context] `ctx`.
|
||||
If `mntpt` is empty, the default ("totp") will be used.
|
||||
|
||||
If no TOTP keys are found at the mount, `keyNms` will be empty but not nil.
|
||||
|
||||
See [ListTotpKeys] if you want additional information about each key.
|
||||
*/
|
||||
func ListTotpKeys(ctx context.Context, vc *vault.Client, mntPt string) (keyNms map[string]struct{}, err error) {
|
||||
|
||||
var totpNm string
|
||||
var opts []vault.RequestOption = make([]vault.RequestOption, 0)
|
||||
var listResp *vault.Response[schema.StandardListResponse]
|
||||
var respErr *vault.ResponseError = new(vault.ResponseError)
|
||||
|
||||
if strings.TrimSpace(mntPt) != "" {
|
||||
opts = append(opts, vault.WithMountPath(mntPt))
|
||||
}
|
||||
|
||||
if listResp, err = vc.Secrets.TotpListKeys(
|
||||
ctx,
|
||||
opts...,
|
||||
); err != nil {
|
||||
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
|
||||
// Is OK; no keys exist yet.
|
||||
keyNms = make(map[string]struct{})
|
||||
err = nil
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keyNms = make(map[string]struct{})
|
||||
for _, totpNm = range listResp.Data.Keys {
|
||||
keyNms[totpNm] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// PrepParser properly initializes the parser and logger in a standardized way.
|
||||
func PrepParser(cmd string, args CommonArgs, p *flags.Parser) (doExit bool, err error) {
|
||||
|
||||
var logFlagsRuntime int = logFlags
|
||||
var flagsErr *flags.Error = new(flags.Error)
|
||||
|
||||
p.NamespaceDelimiter = ParseNsDelim
|
||||
p.EnvNamespaceDelimiter = ParseEnvNsDelim
|
||||
|
||||
if _, err = p.Parse(); err != nil {
|
||||
switch {
|
||||
case errors.As(err, &flagsErr):
|
||||
switch {
|
||||
case errors.Is(flagsErr.Type, flags.ErrHelp),
|
||||
errors.Is(flagsErr.Type, flags.ErrCommandRequired),
|
||||
errors.Is(flagsErr.Type, flags.ErrRequired):
|
||||
// These print their relevant messages by themselves.
|
||||
err = nil
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if version.Ver, err = version.Version(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// If args.Version or args.DetailVersion are true, just print them and exit.
|
||||
if args.DetailVersion || args.Version {
|
||||
doExit = true
|
||||
if args.Version {
|
||||
fmt.Println(version.Ver.Short())
|
||||
return
|
||||
} else if args.DetailVersion {
|
||||
fmt.Println(version.Ver.Detail())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if args.DoDebug {
|
||||
logFlagsRuntime = logFlagsDebug
|
||||
}
|
||||
Logger = logging.GetMultiLogger(
|
||||
args.DoDebug,
|
||||
fmt.Sprintf(
|
||||
"Vault TOTP [%s_%s]",
|
||||
cmdPfx, cmd,
|
||||
),
|
||||
)
|
||||
if err = Logger.AddDefaultLogger(
|
||||
"default",
|
||||
logFlagsRuntime,
|
||||
"/var/log/vault_totp/vault_totp.log", "~/logs/vault_totp.log",
|
||||
); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
if err = Logger.Setup(); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
Logger.Info("main: Vault TOTP version %v", version.Ver.Short())
|
||||
Logger.Debug("main: Vault TOTP version (extended):\n%v", version.Ver.Detail())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SplitVaultPathspec splits a <mount>[:<path>] into separate components. If no path is provided or it is empty, "/" will be used.
|
||||
func SplitVaultPathspec(spec string) (mount, secretPath string) {
|
||||
|
||||
var spl []string = strings.SplitN(spec, ":", 2)
|
||||
|
||||
mount = spl[0]
|
||||
switch len(spl) {
|
||||
case 1:
|
||||
secretPath = "/"
|
||||
case 2:
|
||||
if strings.TrimSpace(spl[1]) == "" {
|
||||
secretPath = "/"
|
||||
} else {
|
||||
secretPath = spl[1]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SplitVaultPathspec2 splits a [<mount>:]<path> into separate components.
|
||||
func SplitVaultPathspec2(spec string) (mount, secretPath string) {
|
||||
|
||||
var spl []string = strings.SplitN(spec, ":", 2)
|
||||
|
||||
switch len(spl) {
|
||||
case 1:
|
||||
secretPath = spl[0]
|
||||
case 2:
|
||||
mount = spl[0]
|
||||
if strings.TrimSpace(spl[1]) == "" {
|
||||
secretPath = "/"
|
||||
} else {
|
||||
secretPath = spl[1]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate validates the passed struct `s`. A nil err means validation succeeded.
|
||||
func Validate(s any) (err error) {
|
||||
|
||||
if err = validate.Struct(s); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// getTotpKeyAsync fetches key info for key named `nm` from Vault `vc`.
|
||||
func getTotpKeyAsync(
|
||||
ctx context.Context,
|
||||
keyNm, mntPt string,
|
||||
vc *vault.Client,
|
||||
mut *sync.Mutex, errChan chan error, wg *sync.WaitGroup,
|
||||
m map[string]map[string]any,
|
||||
) {
|
||||
|
||||
var err error
|
||||
var kinfo map[string]any
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
if kinfo, err = GetTotpKey(ctx, keyNm, mntPt, vc); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// We can wait to hold the lock until we're actually inserting into the map.
|
||||
mut.Lock()
|
||||
defer mut.Unlock()
|
||||
m[keyNm] = kinfo
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user