checking in
This commit is contained in:
23
cmd/gen/args.go
Normal file
23
cmd/gen/args.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"r00t2.io/vault_totp/internal"
|
||||
)
|
||||
|
||||
type (
|
||||
Args struct {
|
||||
internal.CommonArgs
|
||||
GenArgs
|
||||
}
|
||||
GenArgs struct {
|
||||
VaultTotpMnt string `env:"VTOTP_MNT" short:"m" long:"mount" default:"totp" description:"The Vault TOTP generator mount (a 'TOTP secrets' mount) to fetch a code for -k/--key from (or list key names from)."`
|
||||
KeyNm []string `env:"VTOTP_GENK" short:"k" long:"key" description:"Key name(s) to generate code(s) for."`
|
||||
Readable bool `env:"VTOTP_RD" short:"R" long:"readable" description:"If specified, the code will be spaced out to be a bit more readable."`
|
||||
Repeat int `env:"VTOTP_RPT" short:"r" long:"repeat" description:"If non-zero, repeat code generation this many times. A negative number will generate codes indefinitely until ctrl-c is pressed/the program is killed. The default is to not repeat (0 == a single code, 1 == 2 codes generated, etc.)."`
|
||||
// The below are hidden (hidden:"true") until either https://github.com/hashicorp/vault/issues/31684 and/or https://github.com/openbao/openbao/issues/2233
|
||||
// I kind of fudge it for now. TODO.
|
||||
NoCtr bool `hidden:"true" env:"VTOTP_NOCTR" short:"q" long:"no-ctr" description:"If specified, do not perform a countdown of validity; just print the generated code to the terminal and exit immediately after."`
|
||||
PrintExpiry bool `hidden:"true" env:"VTOTP_EXPIRY" short:"e" long:"expiry" description:"If -q/--no-ctr is specified, also print the validity duration and expiration time (but do not animate a countdown, just print the validity/expiration and code and exit). The validity is always printed if a counter is."`
|
||||
Plain []bool `hidden:"true" env:"VTOTP_PLN" short:"p" long:"plain" description:"If specified, use a countdown timer more friendly to non-unicode terminals. Can be repeated for to three levels of increasing 'plain-ness'. Has no effect if -q/--no-ctr is specified. (Level 3 plainness is restricted to ASCII; no UTF-8.)"`
|
||||
}
|
||||
)
|
||||
90
cmd/gen/consts.go
Normal file
90
cmd/gen/consts.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
`context`
|
||||
`os`
|
||||
`sync`
|
||||
|
||||
`github.com/hashicorp/vault-client-go`
|
||||
`github.com/jessevdk/go-flags`
|
||||
`github.com/pterm/pterm`
|
||||
`r00t2.io/goutils/logging`
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
args *Args = new(Args)
|
||||
parser *flags.Parser = flags.NewParser(args, flags.Default)
|
||||
)
|
||||
|
||||
var (
|
||||
vc *vault.Client
|
||||
wg sync.WaitGroup
|
||||
cancelFunc context.CancelFunc
|
||||
ctx context.Context = context.Background()
|
||||
// keyed on vault TOTP key name
|
||||
otpCfgs map[string]*otpCfg = make(map[string]*otpCfg)
|
||||
// keyed on Vault TOTP key name
|
||||
kinfo map[string]map[string]any = make(map[string]map[string]any)
|
||||
)
|
||||
|
||||
var (
|
||||
// This looks too goofy with the half-hour face.
|
||||
// It'd be better if I could do it in 15m increments,
|
||||
// but that doesn't exist in UTF-8 - half-hour-past is the closest fidelity we get...
|
||||
clocksFull []string = []string{
|
||||
"🕛", "🕧", "🕐", "🕜",
|
||||
"🕑", "🕝", "🕒", "🕞",
|
||||
"🕓", "🕟", "🕔", "🕠",
|
||||
"🕕", "🕡", "🕖", "🕢",
|
||||
"🕗", "🕣", "🕘", "🕤",
|
||||
"🕙", "🕥", "🕚", "🕦",
|
||||
}
|
||||
|
||||
// So do it hourly instead.
|
||||
clocksHourly []string = []string{
|
||||
"🕛", "🕐", "🕑", "🕒",
|
||||
"🕓", "🕔", "🕕", "🕖",
|
||||
"🕗", "🕘", "🕙", "🕚",
|
||||
}
|
||||
|
||||
// Level 1 "plain" (the clocks are level 0)
|
||||
plain1 []string = []string{
|
||||
"◷", "◶", "◵", "◴",
|
||||
}
|
||||
plain2 []string = []string{
|
||||
"⠈⠁", "⠈⠑", "⠈⠱", "⠈⡱", "⢀⡱", "⢄⡱", "⢄⡱", "⢆⡱", "⢎⡱",
|
||||
}
|
||||
// plain3 should restrict to pure ASCII
|
||||
plain3 []string = []string{
|
||||
"|", "/", "-", "\\", "-", "|",
|
||||
}
|
||||
|
||||
// indexed on level of plainness
|
||||
spinnerChars [][]string = [][]string{
|
||||
clocksHourly,
|
||||
plain1,
|
||||
plain2,
|
||||
plain3,
|
||||
}
|
||||
|
||||
// charSet is set to one of spinnerChars depending on args.GenArgs.Plain level.
|
||||
charSet []string
|
||||
|
||||
/*
|
||||
Previously I was going to use https://pkg.go.dev/github.com/chelnak/ysmrr for this.
|
||||
Namely because I *thought* it was the only spinner lib that can do multiple spinners at once.
|
||||
(Even submitted a PR, https://github.com/chelnak/ysmrr/pull/88)
|
||||
|
||||
HOWEVER, all spinners are synced to use the same animation... which means they all use the same
|
||||
frequency/rate of update.
|
||||
I want to sync it so the "animation" ends when the TOTP runs out (or thereabouts).
|
||||
|
||||
pterm to the rescue.
|
||||
*/
|
||||
codeMulti pterm.MultiPrinter = pterm.MultiPrinter{
|
||||
IsActive: false,
|
||||
Writer: os.Stdout,
|
||||
UpdateDelay: 0,
|
||||
}
|
||||
)
|
||||
4
cmd/gen/doc.go
Normal file
4
cmd/gen/doc.go
Normal file
@@ -0,0 +1,4 @@
|
||||
/*
|
||||
vault_totp_gen: Generate a TOTP code from a TOTP secret.
|
||||
*/
|
||||
package main
|
||||
180
cmd/gen/funcs.go
Normal file
180
cmd/gen/funcs.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
`context`
|
||||
`encoding/json`
|
||||
`fmt`
|
||||
`log`
|
||||
`net/http`
|
||||
`path`
|
||||
`strconv`
|
||||
`time`
|
||||
|
||||
`github.com/davecgh/go-spew/spew`
|
||||
`github.com/hashicorp/vault-client-go`
|
||||
`github.com/pterm/pterm`
|
||||
`r00t2.io/vault_totp/common`
|
||||
`r00t2.io/vault_totp/errs`
|
||||
)
|
||||
|
||||
// displayCode fetches a code (getCode) and displays it (or rather, sets up the display for it) according to the settings from args global.
|
||||
func displayCode(keyNm string) {
|
||||
|
||||
var err error
|
||||
var code string
|
||||
var cfg *otpCfg
|
||||
var itersLeft int
|
||||
var infinite bool
|
||||
var spinner pterm.SpinnerPrinter
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
infinite = args.GenArgs.Repeat < 0
|
||||
|
||||
if args.GenArgs.NoCtr {
|
||||
if args.GenArgs.Repeat
|
||||
}
|
||||
|
||||
if code, cfg, err = getCode(ctx, keyNm, args.GenArgs.VaultTotpMnt, args.GenArgs.Readable); err != nil {
|
||||
logger.Err("displayCode: Received error getting TOTP code and configuration: %v")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
expiryCb is a callback used to pull the ['Date' header] from the response.
|
||||
|
||||
See [hashicorp/vault#31684], [openbao/openbao#2233].
|
||||
|
||||
['Date' header]: https://datatracker.ietf.org/doc/html/rfc9110#section-6.6.1
|
||||
[hashicorp/vault#31684]: https://github.com/hashicorp/vault/issues/31684
|
||||
[openbao/openbao#2233]: https://github.com/openbao/openbao/issues/2233
|
||||
*/
|
||||
func expiryCb(req *http.Request, resp *http.Response) {
|
||||
|
||||
var err error
|
||||
var i any
|
||||
var n int
|
||||
var ok bool
|
||||
var keyNm string
|
||||
|
||||
keyNm = path.Base(req.URL.Path)
|
||||
|
||||
otpCfgs[keyNm] = &otpCfg{
|
||||
keyNm: keyNm,
|
||||
respDate: time.Time{},
|
||||
period: 0,
|
||||
timeStep: 0,
|
||||
expiry: time.Time{},
|
||||
}
|
||||
|
||||
if i, ok = kinfo["period"]; !ok {
|
||||
logger.Err("expiryCb: No period found for key '%s' in kinfo", keyNm)
|
||||
return
|
||||
}
|
||||
|
||||
switch t := i.(type) {
|
||||
case string:
|
||||
// If it's an int string, it's seconds.
|
||||
if n, err = strconv.Atoi(t); err != nil {
|
||||
logger.Warning("expiryCb: Invalid period integer for key '%s': %#v: %v", keyNm, i, err)
|
||||
// It's not a pure int string, so try a time.Duration string (e.g. "30s").
|
||||
if otpCfgs[keyNm].period, err = time.ParseDuration(t); err != nil {
|
||||
logger.Err("expiryCb: Invalid period duration for key '%s': %#v: %v", keyNm, i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
case json.Number:
|
||||
// But I think it's actually a json.Number...
|
||||
if n, err = strconv.Atoi(string(t)); err != nil {
|
||||
logger.Warning("expiryCb: Invalid period json.Number for key '%s': %#v: %v", i, keyNm, err)
|
||||
return
|
||||
}
|
||||
case int:
|
||||
n = t
|
||||
default:
|
||||
logger.Err("expiryCb: Invalid period type for key '%s' (%#T): %#v", keyNm, i, i)
|
||||
return
|
||||
}
|
||||
if otpCfgs[keyNm].period == 0 && n != 0 {
|
||||
// Golang is weird like this but basically time.Duration(n) isn't *actually* meaningful,
|
||||
// it's just necessary for type matching.
|
||||
otpCfgs[keyNm].period = time.Second * time.Duration(n)
|
||||
} else if n == 0 {
|
||||
logger.Err("expiryCb: Could not derive time primitive for key '%s' from '%#v'", keyNm, i)
|
||||
return
|
||||
}
|
||||
|
||||
if otpCfgs[keyNm].respDate, err = http.ParseTime(resp.Header.Get("Date")); err != nil {
|
||||
logger.Err(
|
||||
"expiryCb: received error parsing 'Date' header ('%s') for key '%s': %v",
|
||||
resp.Header.Get("Date"), keyNm, err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
otpCfgs[keyNm].timeStep = common.TimeStepFromTime(otpCfgs[keyNm].respDate, otpCfgs[keyNm].period)
|
||||
otpCfgs[keyNm].startTimeStep, otpCfgs[keyNm].expiry = common.TimeStepToTime(otpCfgs[keyNm].timeStep, otpCfgs[keyNm].period)
|
||||
|
||||
logger.Debug("expiryCb: Derived expiration for key '%s':\n%s", keyNm, spew.Sdump(otpCfgs[keyNm]))
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
getCode gets an OTP code from Vault (and populates the corresponding cfg).
|
||||
|
||||
The code will be split with a space if `readable` is true.
|
||||
*/
|
||||
func getCode(ctx context.Context, keyNm, mntPt string, readable bool) (code string, cfg *otpCfg, err error) {
|
||||
|
||||
var i any
|
||||
var ok bool
|
||||
var resp *vault.Response[map[string]interface{}]
|
||||
|
||||
if resp, err = vc.Secrets.TotpGenerateCode(
|
||||
ctx,
|
||||
keyNm,
|
||||
vault.WithMountPath(mntPt),
|
||||
vault.WithResponseCallbacks(expiryCb),
|
||||
); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
if i, ok = resp.Data["code"]; !ok {
|
||||
err = errs.ErrNoCode
|
||||
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg, ok = otpCfgs[keyNm]; !ok {
|
||||
err = errs.ErrNoCfg
|
||||
logger.Err("getCode: Key '%s': %v", keyNm, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch t := i.(type) {
|
||||
case string:
|
||||
code = t
|
||||
case json.Number:
|
||||
code = string(t)
|
||||
case int:
|
||||
code = strconv.Itoa(t)
|
||||
default:
|
||||
logger.Err("getCode: Invalid type for key '%s' (%#T): %#v", keyNm, i, i)
|
||||
err = errs.ErrNoCode
|
||||
return
|
||||
}
|
||||
|
||||
if readable {
|
||||
switch len(code) {
|
||||
case 6:
|
||||
code = fmt.Sprintf("%s %s", code[:3], code[3:])
|
||||
case 8:
|
||||
code = fmt.Sprintf("%s %s", code[:4], code[4:])
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
104
cmd/gen/main.go
Normal file
104
cmd/gen/main.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
`fmt`
|
||||
`log`
|
||||
`os`
|
||||
`os/signal`
|
||||
`sort`
|
||||
`strings`
|
||||
`time`
|
||||
|
||||
`github.com/chelnak/ysmrr`
|
||||
`github.com/chelnak/ysmrr/pkg/animations`
|
||||
`r00t2.io/vault_totp/common`
|
||||
`r00t2.io/vault_totp/internal`
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
var err error
|
||||
var idx int
|
||||
var rpt int
|
||||
var doExit bool
|
||||
var keyNm string
|
||||
var keyNms []string
|
||||
var ticker *time.Ticker
|
||||
var keys map[string]struct{} = make(map[string]struct{})
|
||||
|
||||
log.SetOutput(os.Stdout)
|
||||
|
||||
ctx, cancelFunc = signal.NotifyContext(ctx, common.ProgEndSigs...)
|
||||
|
||||
if doExit, err = internal.PrepParser("gen", args.CommonArgs, parser); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
if doExit {
|
||||
return
|
||||
}
|
||||
logger = internal.Logger
|
||||
|
||||
if err = internal.Validate(args); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
if vc, err = internal.GetVaultClient(&args.VaultArgs); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
if args.GenArgs.KeyNm == "" {
|
||||
if keys, err = internal.ListTotpKeys(ctx, vc, args.GenArgs.VaultTotpMnt); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
keyNms = make([]string, 0, len(keys))
|
||||
for keyNm, _ = range args.GenArgs.KeyNm {
|
||||
keyNms[idx] = keyNm
|
||||
idx++
|
||||
}
|
||||
sort.Strings(keyNms)
|
||||
fmt.Printf(
|
||||
"No key name provided.\n"+
|
||||
"Existing key names at mount '%s' are:\n\n"+
|
||||
"\t%s\n",
|
||||
args.GenArgs.VaultTotpMnt,
|
||||
strings.Join(keyNms, "\n\t"),
|
||||
)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if kinfo, err = internal.GetTotpKeys(ctx, vc, args.GenArgs.VaultTotpMnt); err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
if kinfo == nil || len(kinfo) == 0 {
|
||||
log.Panicln("no TOTP configuration found")
|
||||
}
|
||||
|
||||
if !args.GenArgs.NoCtr {
|
||||
if len(args.GenArgs.Plain) > 3 {
|
||||
args.GenArgs.Plain = args.GenArgs.Plain[:3]
|
||||
}
|
||||
charSet = spinnerChars[len(args.GenArgs.Plain)]
|
||||
}
|
||||
|
||||
if args.GenArgs.Repeat < 0 {
|
||||
// ticker = time.NewTicker(cfg.period + (time.Millisecond * time.Duration(500)))
|
||||
ticker = time.NewTicker(cfg.period)
|
||||
breakLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break breakLoop
|
||||
case <-ticker.C:
|
||||
// GET CODE
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for rpt = args.GenArgs.Repeat; rpt >= 0; rpt-- {
|
||||
fmt.Fprintf(os.Stderr, "(Issuance round #%d/%d)\n", args.GenArgs.Repeat-rpt, rpt)
|
||||
// GET CODE
|
||||
}
|
||||
}
|
||||
|
||||
// Force close any remaining timing loops, etc.
|
||||
cancelFunc()
|
||||
}
|
||||
39
cmd/gen/types.go
Normal file
39
cmd/gen/types.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
`time`
|
||||
)
|
||||
|
||||
type (
|
||||
// otpCfg is used to pull in several different TOTP information. TODO.
|
||||
otpCfg struct {
|
||||
// keyNm is the key name this configuration info is associated with.
|
||||
keyNm string
|
||||
/*
|
||||
respDate is pulled in *separately* from the `date` HTTP header during *code generation*.
|
||||
|
||||
It's very unlikely to be accurate to the TOTP code generation from Vault.
|
||||
|
||||
https://github.com/hashicorp/vault/issues/31684
|
||||
https://github.com/openbao/openbao/issues/2233
|
||||
*/
|
||||
respDate time.Time
|
||||
// period is fetched during *initialization*.
|
||||
period time.Duration
|
||||
/*
|
||||
timeStep is the time step identifier.
|
||||
|
||||
It's created via:
|
||||
|
||||
otpCfg.timeStep = int64(math.Floor(float64(otpCfg.respDate.Unix()) / float64(optCfg.period)))
|
||||
*/
|
||||
timeStep int64
|
||||
// startTimeStep is when beginning of the current period.
|
||||
startTimeStep time.Time
|
||||
/*
|
||||
expiry is the exact timestamp (at least to the level that is... reasonably determinable,
|
||||
currently, depending on the above issues under respDate) that the code expires.
|
||||
*/
|
||||
expiry time.Time
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user