181 lines
4.5 KiB
Go
181 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
`context`
|
|
`encoding/json`
|
|
`fmt`
|
|
`log`
|
|
`net/http`
|
|
`path`
|
|
`strconv`
|
|
`time`
|
|
|
|
`github.com/davecgh/go-spew/spew`
|
|
`github.com/hashicorp/vault-client-go`
|
|
`github.com/pterm/pterm`
|
|
`r00t2.io/vault_totp/common`
|
|
`r00t2.io/vault_totp/errs`
|
|
)
|
|
|
|
// displayCode fetches a code (getCode) and displays it (or rather, sets up the display for it) according to the settings from args global.
|
|
func displayCode(keyNm string) {
|
|
|
|
var err error
|
|
var code string
|
|
var cfg *otpCfg
|
|
var itersLeft int
|
|
var infinite bool
|
|
var spinner pterm.SpinnerPrinter
|
|
|
|
defer wg.Done()
|
|
|
|
infinite = args.GenArgs.Repeat < 0
|
|
|
|
if args.GenArgs.NoCtr {
|
|
if args.GenArgs.Repeat
|
|
}
|
|
|
|
if code, cfg, err = getCode(ctx, keyNm, args.GenArgs.VaultTotpMnt, args.GenArgs.Readable); err != nil {
|
|
logger.Err("displayCode: Received error getting TOTP code and configuration: %v")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
expiryCb is a callback used to pull the ['Date' header] from the response.
|
|
|
|
See [hashicorp/vault#31684], [openbao/openbao#2233].
|
|
|
|
['Date' header]: https://datatracker.ietf.org/doc/html/rfc9110#section-6.6.1
|
|
[hashicorp/vault#31684]: https://github.com/hashicorp/vault/issues/31684
|
|
[openbao/openbao#2233]: https://github.com/openbao/openbao/issues/2233
|
|
*/
|
|
func expiryCb(req *http.Request, resp *http.Response) {
|
|
|
|
var err error
|
|
var i any
|
|
var n int
|
|
var ok bool
|
|
var keyNm string
|
|
|
|
keyNm = path.Base(req.URL.Path)
|
|
|
|
otpCfgs[keyNm] = &otpCfg{
|
|
keyNm: keyNm,
|
|
respDate: time.Time{},
|
|
period: 0,
|
|
timeStep: 0,
|
|
expiry: time.Time{},
|
|
}
|
|
|
|
if i, ok = kinfo["period"]; !ok {
|
|
logger.Err("expiryCb: No period found for key '%s' in kinfo", keyNm)
|
|
return
|
|
}
|
|
|
|
switch t := i.(type) {
|
|
case string:
|
|
// If it's an int string, it's seconds.
|
|
if n, err = strconv.Atoi(t); err != nil {
|
|
logger.Warning("expiryCb: Invalid period integer for key '%s': %#v: %v", keyNm, i, err)
|
|
// It's not a pure int string, so try a time.Duration string (e.g. "30s").
|
|
if otpCfgs[keyNm].period, err = time.ParseDuration(t); err != nil {
|
|
logger.Err("expiryCb: Invalid period duration for key '%s': %#v: %v", keyNm, i, err)
|
|
return
|
|
}
|
|
}
|
|
case json.Number:
|
|
// But I think it's actually a json.Number...
|
|
if n, err = strconv.Atoi(string(t)); err != nil {
|
|
logger.Warning("expiryCb: Invalid period json.Number for key '%s': %#v: %v", i, keyNm, err)
|
|
return
|
|
}
|
|
case int:
|
|
n = t
|
|
default:
|
|
logger.Err("expiryCb: Invalid period type for key '%s' (%#T): %#v", keyNm, i, i)
|
|
return
|
|
}
|
|
if otpCfgs[keyNm].period == 0 && n != 0 {
|
|
// Golang is weird like this but basically time.Duration(n) isn't *actually* meaningful,
|
|
// it's just necessary for type matching.
|
|
otpCfgs[keyNm].period = time.Second * time.Duration(n)
|
|
} else if n == 0 {
|
|
logger.Err("expiryCb: Could not derive time primitive for key '%s' from '%#v'", keyNm, i)
|
|
return
|
|
}
|
|
|
|
if otpCfgs[keyNm].respDate, err = http.ParseTime(resp.Header.Get("Date")); err != nil {
|
|
logger.Err(
|
|
"expiryCb: received error parsing 'Date' header ('%s') for key '%s': %v",
|
|
resp.Header.Get("Date"), keyNm, err,
|
|
)
|
|
return
|
|
}
|
|
|
|
otpCfgs[keyNm].timeStep = common.TimeStepFromTime(otpCfgs[keyNm].respDate, otpCfgs[keyNm].period)
|
|
otpCfgs[keyNm].startTimeStep, otpCfgs[keyNm].expiry = common.TimeStepToTime(otpCfgs[keyNm].timeStep, otpCfgs[keyNm].period)
|
|
|
|
logger.Debug("expiryCb: Derived expiration for key '%s':\n%s", keyNm, spew.Sdump(otpCfgs[keyNm]))
|
|
|
|
}
|
|
|
|
/*
|
|
getCode gets an OTP code from Vault (and populates the corresponding cfg).
|
|
|
|
The code will be split with a space if `readable` is true.
|
|
*/
|
|
func getCode(ctx context.Context, keyNm, mntPt string, readable bool) (code string, cfg *otpCfg, err error) {
|
|
|
|
var i any
|
|
var ok bool
|
|
var resp *vault.Response[map[string]interface{}]
|
|
|
|
if resp, err = vc.Secrets.TotpGenerateCode(
|
|
ctx,
|
|
keyNm,
|
|
vault.WithMountPath(mntPt),
|
|
vault.WithResponseCallbacks(expiryCb),
|
|
); err != nil {
|
|
log.Panicln(err)
|
|
}
|
|
|
|
if i, ok = resp.Data["code"]; !ok {
|
|
err = errs.ErrNoCode
|
|
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
|
return
|
|
}
|
|
|
|
if cfg, ok = otpCfgs[keyNm]; !ok {
|
|
err = errs.ErrNoCfg
|
|
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
|
return
|
|
}
|
|
|
|
switch t := i.(type) {
|
|
case string:
|
|
code = t
|
|
case json.Number:
|
|
code = string(t)
|
|
case int:
|
|
code = strconv.Itoa(t)
|
|
default:
|
|
logger.Err("getCode: Invalid type for key '%s' (%#T): %#v", keyNm, i, i)
|
|
err = errs.ErrNoCode
|
|
return
|
|
}
|
|
|
|
if readable {
|
|
switch len(code) {
|
|
case 6:
|
|
code = fmt.Sprintf("%s %s", code[:3], code[3:])
|
|
case 8:
|
|
code = fmt.Sprintf("%s %s", code[:4], code[4:])
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|