1
0
Files
vault_totp/cmd/gen/funcs.go
2025-12-23 20:58:56 -05:00

181 lines
4.5 KiB
Go

package main
import (
`context`
`encoding/json`
`fmt`
`log`
`net/http`
`path`
`strconv`
`time`
`github.com/davecgh/go-spew/spew`
`github.com/hashicorp/vault-client-go`
`github.com/pterm/pterm`
`r00t2.io/vault_totp/common`
`r00t2.io/vault_totp/errs`
)
// displayCode fetches a code (getCode) and displays it (or rather, sets up the display for it) according to the settings from args global.
func displayCode(keyNm string) {
var err error
var code string
var cfg *otpCfg
var itersLeft int
var infinite bool
var spinner pterm.SpinnerPrinter
defer wg.Done()
infinite = args.GenArgs.Repeat < 0
if args.GenArgs.NoCtr {
if args.GenArgs.Repeat
}
if code, cfg, err = getCode(ctx, keyNm, args.GenArgs.VaultTotpMnt, args.GenArgs.Readable); err != nil {
logger.Err("displayCode: Received error getting TOTP code and configuration: %v")
return
}
return
}
/*
expiryCb is a callback used to pull the ['Date' header] from the response.
See [hashicorp/vault#31684], [openbao/openbao#2233].
['Date' header]: https://datatracker.ietf.org/doc/html/rfc9110#section-6.6.1
[hashicorp/vault#31684]: https://github.com/hashicorp/vault/issues/31684
[openbao/openbao#2233]: https://github.com/openbao/openbao/issues/2233
*/
func expiryCb(req *http.Request, resp *http.Response) {
var err error
var i any
var n int
var ok bool
var keyNm string
keyNm = path.Base(req.URL.Path)
otpCfgs[keyNm] = &otpCfg{
keyNm: keyNm,
respDate: time.Time{},
period: 0,
timeStep: 0,
expiry: time.Time{},
}
if i, ok = kinfo["period"]; !ok {
logger.Err("expiryCb: No period found for key '%s' in kinfo", keyNm)
return
}
switch t := i.(type) {
case string:
// If it's an int string, it's seconds.
if n, err = strconv.Atoi(t); err != nil {
logger.Warning("expiryCb: Invalid period integer for key '%s': %#v: %v", keyNm, i, err)
// It's not a pure int string, so try a time.Duration string (e.g. "30s").
if otpCfgs[keyNm].period, err = time.ParseDuration(t); err != nil {
logger.Err("expiryCb: Invalid period duration for key '%s': %#v: %v", keyNm, i, err)
return
}
}
case json.Number:
// But I think it's actually a json.Number...
if n, err = strconv.Atoi(string(t)); err != nil {
logger.Warning("expiryCb: Invalid period json.Number for key '%s': %#v: %v", i, keyNm, err)
return
}
case int:
n = t
default:
logger.Err("expiryCb: Invalid period type for key '%s' (%#T): %#v", keyNm, i, i)
return
}
if otpCfgs[keyNm].period == 0 && n != 0 {
// Golang is weird like this but basically time.Duration(n) isn't *actually* meaningful,
// it's just necessary for type matching.
otpCfgs[keyNm].period = time.Second * time.Duration(n)
} else if n == 0 {
logger.Err("expiryCb: Could not derive time primitive for key '%s' from '%#v'", keyNm, i)
return
}
if otpCfgs[keyNm].respDate, err = http.ParseTime(resp.Header.Get("Date")); err != nil {
logger.Err(
"expiryCb: received error parsing 'Date' header ('%s') for key '%s': %v",
resp.Header.Get("Date"), keyNm, err,
)
return
}
otpCfgs[keyNm].timeStep = common.TimeStepFromTime(otpCfgs[keyNm].respDate, otpCfgs[keyNm].period)
otpCfgs[keyNm].startTimeStep, otpCfgs[keyNm].expiry = common.TimeStepToTime(otpCfgs[keyNm].timeStep, otpCfgs[keyNm].period)
logger.Debug("expiryCb: Derived expiration for key '%s':\n%s", keyNm, spew.Sdump(otpCfgs[keyNm]))
}
/*
getCode gets an OTP code for key keyNm from the TOTP secrets engine mounted at mntPt.
The expiryCb response callback populates the corresponding cfg.

If `readable` is true, the code is split with a space: 6-digit codes as "123 456" and
8-digit codes as "1234 5678". Codes of other lengths are returned unchanged.

Error handling:
  - An error from the Vault call is logged and returned rather than panicked on.
  - errs.ErrNoCode is returned if the response has no usable "code" field.
  - errs.ErrNoCfg is returned if no config was stored for keyNm.

On error, code is empty and cfg is nil.
*/
func getCode(ctx context.Context, keyNm, mntPt string, readable bool) (code string, cfg *otpCfg, err error) {
	var i any
	var ok bool
	var resp *vault.Response[map[string]interface{}]

	if resp, err = vc.Secrets.TotpGenerateCode(
		ctx,
		keyNm,
		vault.WithMountPath(mntPt),
		vault.WithResponseCallbacks(expiryCb),
	); err != nil {
		// Return the error to the caller (displayCode handles it) instead of crashing the process.
		logger.Err("getCode: Key '%s': %v", keyNm, err)
		return
	}

	if i, ok = resp.Data["code"]; !ok {
		err = errs.ErrNoCode
		logger.Err("getCode: Key '%s': %v", keyNm, err)
		return
	}
	if cfg, ok = otpCfgs[keyNm]; !ok {
		err = errs.ErrNoCfg
		logger.Err("getCode: Key '%s': %v", keyNm, err)
		return
	}

	switch t := i.(type) {
	case string:
		code = t
	case json.Number:
		code = string(t)
	case int:
		// NOTE(review): a numeric code loses any leading zeros; Vault is expected to send a string.
		code = strconv.Itoa(t)
	default:
		logger.Err("getCode: Invalid type for key '%s' (%#T): %#v", keyNm, i, i)
		err = errs.ErrNoCode
		cfg = nil
		return
	}

	if readable {
		switch len(code) {
		case 6:
			code = fmt.Sprintf("%s %s", code[:3], code[3:])
		case 8:
			code = fmt.Sprintf("%s %s", code[:4], code[4:])
		}
	}

	return
}