checking in
This commit is contained in:
180
cmd/gen/funcs.go
Normal file
180
cmd/gen/funcs.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
`context`
|
||||
`encoding/json`
|
||||
`fmt`
|
||||
`log`
|
||||
`net/http`
|
||||
`path`
|
||||
`strconv`
|
||||
`time`
|
||||
|
||||
`github.com/davecgh/go-spew/spew`
|
||||
`github.com/hashicorp/vault-client-go`
|
||||
`github.com/pterm/pterm`
|
||||
`r00t2.io/vault_totp/common`
|
||||
`r00t2.io/vault_totp/errs`
|
||||
)
|
||||
|
||||
// displayCode fetches a code (getCode) and displays it (or rather, sets up the display for it) according to the settings from args global.
|
||||
func displayCode(keyNm string) {
|
||||
|
||||
var err error
|
||||
var code string
|
||||
var cfg *otpCfg
|
||||
var itersLeft int
|
||||
var infinite bool
|
||||
var spinner pterm.SpinnerPrinter
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
infinite = args.GenArgs.Repeat < 0
|
||||
|
||||
if args.GenArgs.NoCtr {
|
||||
if args.GenArgs.Repeat
|
||||
}
|
||||
|
||||
if code, cfg, err = getCode(ctx, keyNm, args.GenArgs.VaultTotpMnt, args.GenArgs.Readable); err != nil {
|
||||
logger.Err("displayCode: Received error getting TOTP code and configuration: %v")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
expiryCb is a callback used to pull the ['Date' header] from the response.
|
||||
|
||||
See [hashicorp/vault#31684], [openbao/openbao#2233].
|
||||
|
||||
['Date' header]: https://datatracker.ietf.org/doc/html/rfc9110#section-6.6.1
|
||||
[hashicorp/vault#31684]: https://github.com/hashicorp/vault/issues/31684
|
||||
[openbao/openbao#2233]: https://github.com/openbao/openbao/issues/2233
|
||||
*/
|
||||
func expiryCb(req *http.Request, resp *http.Response) {
|
||||
|
||||
var err error
|
||||
var i any
|
||||
var n int
|
||||
var ok bool
|
||||
var keyNm string
|
||||
|
||||
keyNm = path.Base(req.URL.Path)
|
||||
|
||||
otpCfgs[keyNm] = &otpCfg{
|
||||
keyNm: keyNm,
|
||||
respDate: time.Time{},
|
||||
period: 0,
|
||||
timeStep: 0,
|
||||
expiry: time.Time{},
|
||||
}
|
||||
|
||||
if i, ok = kinfo["period"]; !ok {
|
||||
logger.Err("expiryCb: No period found for key '%s' in kinfo", keyNm)
|
||||
return
|
||||
}
|
||||
|
||||
switch t := i.(type) {
|
||||
case string:
|
||||
// If it's an int string, it's seconds.
|
||||
if n, err = strconv.Atoi(t); err != nil {
|
||||
logger.Warning("expiryCb: Invalid period integer for key '%s': %#v: %v", keyNm, i, err)
|
||||
// It's not a pure int string, so try a time.Duration string (e.g. "30s").
|
||||
if otpCfgs[keyNm].period, err = time.ParseDuration(t); err != nil {
|
||||
logger.Err("expiryCb: Invalid period duration for key '%s': %#v: %v", keyNm, i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
case json.Number:
|
||||
// But I think it's actually a json.Number...
|
||||
if n, err = strconv.Atoi(string(t)); err != nil {
|
||||
logger.Warning("expiryCb: Invalid period json.Number for key '%s': %#v: %v", i, keyNm, err)
|
||||
return
|
||||
}
|
||||
case int:
|
||||
n = t
|
||||
default:
|
||||
logger.Err("expiryCb: Invalid period type for key '%s' (%#T): %#v", keyNm, i, i)
|
||||
return
|
||||
}
|
||||
if otpCfgs[keyNm].period == 0 && n != 0 {
|
||||
// Golang is weird like this but basically time.Duration(n) isn't *actually* meaningful,
|
||||
// it's just necessary for type matching.
|
||||
otpCfgs[keyNm].period = time.Second * time.Duration(n)
|
||||
} else if n == 0 {
|
||||
logger.Err("expiryCb: Could not derive time primitive for key '%s' from '%#v'", keyNm, i)
|
||||
return
|
||||
}
|
||||
|
||||
if otpCfgs[keyNm].respDate, err = http.ParseTime(resp.Header.Get("Date")); err != nil {
|
||||
logger.Err(
|
||||
"expiryCb: received error parsing 'Date' header ('%s') for key '%s': %v",
|
||||
resp.Header.Get("Date"), keyNm, err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
otpCfgs[keyNm].timeStep = common.TimeStepFromTime(otpCfgs[keyNm].respDate, otpCfgs[keyNm].period)
|
||||
otpCfgs[keyNm].startTimeStep, otpCfgs[keyNm].expiry = common.TimeStepToTime(otpCfgs[keyNm].timeStep, otpCfgs[keyNm].period)
|
||||
|
||||
logger.Debug("expiryCb: Derived expiration for key '%s':\n%s", keyNm, spew.Sdump(otpCfgs[keyNm]))
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
getCode gets an OTP code from Vault (and populates the corresponding cfg).
|
||||
|
||||
The code will be split with a space if `readable` is true.
|
||||
*/
|
||||
func getCode(ctx context.Context, keyNm, mntPt string, readable bool) (code string, cfg *otpCfg, err error) {
|
||||
|
||||
var i any
|
||||
var ok bool
|
||||
var resp *vault.Response[map[string]interface{}]
|
||||
|
||||
if resp, err = vc.Secrets.TotpGenerateCode(
|
||||
ctx,
|
||||
keyNm,
|
||||
vault.WithMountPath(mntPt),
|
||||
vault.WithResponseCallbacks(expiryCb),
|
||||
); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
if i, ok = resp.Data["code"]; !ok {
|
||||
err = errs.ErrNoCode
|
||||
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg, ok = otpCfgs[keyNm]; !ok {
|
||||
err = errs.ErrNoCfg
|
||||
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch t := i.(type) {
|
||||
case string:
|
||||
code = t
|
||||
case json.Number:
|
||||
code = string(t)
|
||||
case int:
|
||||
code = strconv.Itoa(t)
|
||||
default:
|
||||
logger.Err("getCode: Invalid type for key '%s' (%#T): %#v", keyNm, i, i)
|
||||
err = errs.ErrNoCode
|
||||
return
|
||||
}
|
||||
|
||||
if readable {
|
||||
switch len(code) {
|
||||
case 6:
|
||||
code = fmt.Sprintf("%s %s", code[:3], code[3:])
|
||||
case 8:
|
||||
code = fmt.Sprintf("%s %s", code[:4], code[4:])
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user