package main import ( `context` `encoding/json` `fmt` `log` `net/http` `path` `strconv` `time` `github.com/davecgh/go-spew/spew` `github.com/hashicorp/vault-client-go` `github.com/pterm/pterm` `r00t2.io/vault_totp/common` `r00t2.io/vault_totp/errs` ) // displayCode fetches a code (getCode) and displays it (or rather, sets up the display for it) according to the settings from args global. func displayCode(keyNm string) { var err error var code string var cfg *otpCfg var itersLeft int var infinite bool var spinner pterm.SpinnerPrinter defer wg.Done() infinite = args.GenArgs.Repeat < 0 if args.GenArgs.NoCtr { if args.GenArgs.Repeat } if code, cfg, err = getCode(ctx, keyNm, args.GenArgs.VaultTotpMnt, args.GenArgs.Readable); err != nil { logger.Err("displayCode: Received error getting TOTP code and configuration: %v") return } return } /* expiryCb is a callback used to pull the ['Date' header] from the response. See [hashicorp/vault#31684], [openbao/openbao#2233]. ['Date' header]: https://datatracker.ietf.org/doc/html/rfc9110#section-6.6.1 [hashicorp/vault#31684]: https://github.com/hashicorp/vault/issues/31684 [openbao/openbao#2233]: https://github.com/openbao/openbao/issues/2233 */ func expiryCb(req *http.Request, resp *http.Response) { var err error var i any var n int var ok bool var keyNm string keyNm = path.Base(req.URL.Path) otpCfgs[keyNm] = &otpCfg{ keyNm: keyNm, respDate: time.Time{}, period: 0, timeStep: 0, expiry: time.Time{}, } if i, ok = kinfo["period"]; !ok { logger.Err("expiryCb: No period found for key '%s' in kinfo", keyNm) return } switch t := i.(type) { case string: // If it's an int string, it's seconds. if n, err = strconv.Atoi(t); err != nil { logger.Warning("expiryCb: Invalid period integer for key '%s': %#v: %v", keyNm, i, err) // It's not a pure int string, so try a time.Duration string (e.g. "30s"). if otpCfgs[keyNm].period, err = time.ParseDuration(t); err != nil { logger.Err("expiryCb: Invalid period duration for key '%s': %#v: %v", keyNm, i, err) return } } case json.Number: // But I think it's actually a json.Number... if n, err = strconv.Atoi(string(t)); err != nil { logger.Warning("expiryCb: Invalid period json.Number for key '%s': %#v: %v", i, keyNm, err) return } case int: n = t default: logger.Err("expiryCb: Invalid period type for key '%s' (%#T): %#v", keyNm, i, i) return } if otpCfgs[keyNm].period == 0 && n != 0 { // Golang is weird like this but basically time.Duration(n) isn't *actually* meaningful, // it's just necessary for type matching. otpCfgs[keyNm].period = time.Second * time.Duration(n) } else if n == 0 { logger.Err("expiryCb: Could not derive time primitive for key '%s' from '%#v'", keyNm, i) return } if otpCfgs[keyNm].respDate, err = http.ParseTime(resp.Header.Get("Date")); err != nil { logger.Err( "expiryCb: received error parsing 'Date' header ('%s') for key '%s': %v", resp.Header.Get("Date"), keyNm, err, ) return } otpCfgs[keyNm].timeStep = common.TimeStepFromTime(otpCfgs[keyNm].respDate, otpCfgs[keyNm].period) otpCfgs[keyNm].startTimeStep, otpCfgs[keyNm].expiry = common.TimeStepToTime(otpCfgs[keyNm].timeStep, otpCfgs[keyNm].period) logger.Debug("expiryCb: Derived expiration for key '%s':\n%s", keyNm, spew.Sdump(otpCfgs[keyNm])) } /* getCode gets an OTP code from Vault (and populates the corresponding cfg). The code will be split with a space if `readable` is true. */ func getCode(ctx context.Context, keyNm, mntPt string, readable bool) (code string, cfg *otpCfg, err error) { var i any var ok bool var resp *vault.Response[map[string]interface{}] if resp, err = vc.Secrets.TotpGenerateCode( ctx, keyNm, vault.WithMountPath(mntPt), vault.WithResponseCallbacks(expiryCb), ); err != nil { log.Panicln(err) } if i, ok = resp.Data["code"]; !ok { err = errs.ErrNoCode logger.Err("getCode: Key '%s': %v", keyNm, err) return } if cfg, ok = otpCfgs[keyNm]; !ok { err = errs.ErrNoCfg logger.Err("getCode: Key '%s': %v", keyNm, err) return } switch t := i.(type) { case string: code = t case json.Number: code = string(t) case int: code = strconv.Itoa(t) default: logger.Err("getCode: Invalid type for key '%s' (%#T): %#v", keyNm, i, i) err = errs.ErrNoCode return } if readable { switch len(code) { case 6: code = fmt.Sprintf("%s %s", code[:3], code[3:]) case 8: code = fmt.Sprintf("%s %s", code[:4], code[4:]) } } return }