feat: Phase 1 — config, auth, OAuth2 PKCE, CLI scaffold, token store
This commit is contained in:
286
internal/auth/auth.go
Normal file
286
internal/auth/auth.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"git.warky.dev/wdevs/gocalgoo/internal/store"
|
||||
)
|
||||
|
||||
type AuthStatus struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
Account string `json:"account,omitempty"`
|
||||
Expiry time.Time `json:"expiry,omitempty"`
|
||||
Expired bool `json:"expired"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
cfg ManagerConfig
|
||||
tokenStore *store.TokenStore
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
type ManagerConfig struct {
|
||||
ClientCredentialsFile string
|
||||
TokenStoreFile string
|
||||
Scopes []string
|
||||
DefaultPort int
|
||||
OpenBrowser bool
|
||||
CallbackPath string
|
||||
}
|
||||
|
||||
func NewManager(cfg ManagerConfig, tokenStore *store.TokenStore, log *zap.Logger) *Manager {
|
||||
return &Manager{cfg: cfg, tokenStore: tokenStore, log: log}
|
||||
}
|
||||
|
||||
func (m *Manager) Status(ctx context.Context) (AuthStatus, error) {
|
||||
token, err := m.tokenStore.Load()
|
||||
if err != nil {
|
||||
return AuthStatus{}, fmt.Errorf("load token: %w", err)
|
||||
}
|
||||
if token == nil {
|
||||
return AuthStatus{Authenticated: false}, nil
|
||||
}
|
||||
return AuthStatus{
|
||||
Authenticated: true,
|
||||
Account: token.Account,
|
||||
Expiry: token.Expiry,
|
||||
Expired: token.IsExpired(),
|
||||
Scopes: token.Scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Logout(ctx context.Context) error {
|
||||
if err := m.tokenStore.Delete(); err != nil {
|
||||
return fmt.Errorf("delete token: %w", err)
|
||||
}
|
||||
m.log.Info("logged out")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) LoginLoopback(ctx context.Context, port int) error {
|
||||
oauthCfg, err := m.loadOAuthConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkce, err := NewPKCEChallenge()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate pkce: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate state: %w", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("bind callback listener: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
actualPort := ln.Addr().(*net.TCPAddr).Port
|
||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", actualPort, m.cfg.CallbackPath)
|
||||
oauthCfg.RedirectURL = redirectURI
|
||||
|
||||
authURL := oauthCfg.AuthCodeURL(state,
|
||||
oauth2.AccessTypeOffline,
|
||||
oauth2.SetAuthURLParam("code_challenge", pkce.Challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", pkce.Method),
|
||||
)
|
||||
|
||||
m.log.Info("starting OAuth2 loopback flow",
|
||||
zap.Int("port", actualPort),
|
||||
zap.String("redirect_uri", redirectURI),
|
||||
)
|
||||
|
||||
if port == 0 {
|
||||
fmt.Printf("Listening on port %d\n", actualPort)
|
||||
fmt.Printf("Redirect URI: %s\n", redirectURI)
|
||||
}
|
||||
|
||||
if m.cfg.OpenBrowser {
|
||||
if err := openBrowser(authURL); err != nil {
|
||||
m.log.Warn("could not open browser", zap.Error(err))
|
||||
fmt.Printf("Open this URL in your browser:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Open this URL in your browser:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
codeCh := make(chan string, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
srv := &http.Server{}
|
||||
srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
if q.Get("state") != state {
|
||||
http.Error(w, "invalid state", http.StatusBadRequest)
|
||||
errCh <- fmt.Errorf("oauth state mismatch")
|
||||
return
|
||||
}
|
||||
code := q.Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
errCh <- fmt.Errorf("no code in callback")
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w, "Authentication successful. You may close this tab.")
|
||||
codeCh <- code
|
||||
})
|
||||
|
||||
go func() {
|
||||
if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
errCh <- fmt.Errorf("callback server: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var code string
|
||||
select {
|
||||
case code = <-codeCh:
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
shutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(shutCtx)
|
||||
|
||||
return m.exchangeAndStore(ctx, oauthCfg, code, pkce.Verifier)
|
||||
}
|
||||
|
||||
func (m *Manager) LoginManual(ctx context.Context, port int) error {
|
||||
oauthCfg, err := m.loadOAuthConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkce, err := NewPKCEChallenge()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate pkce: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate state: %w", err)
|
||||
}
|
||||
|
||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, m.cfg.CallbackPath)
|
||||
oauthCfg.RedirectURL = redirectURI
|
||||
|
||||
authURL := oauthCfg.AuthCodeURL(state,
|
||||
oauth2.AccessTypeOffline,
|
||||
oauth2.SetAuthURLParam("code_challenge", pkce.Challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", pkce.Method),
|
||||
)
|
||||
|
||||
fmt.Println("Open this URL in your browser:")
|
||||
fmt.Println(authURL)
|
||||
fmt.Println()
|
||||
fmt.Print("Paste the redirect URL or authorization code: ")
|
||||
|
||||
var input string
|
||||
if _, err := fmt.Scanln(&input); err != nil {
|
||||
return fmt.Errorf("read input: %w", err)
|
||||
}
|
||||
|
||||
code := extractCode(input, state)
|
||||
if code == "" {
|
||||
return fmt.Errorf("could not extract authorization code from input")
|
||||
}
|
||||
|
||||
return m.exchangeAndStore(ctx, oauthCfg, code, pkce.Verifier)
|
||||
}
|
||||
|
||||
func (m *Manager) loadOAuthConfig() (*oauth2.Config, error) {
|
||||
data, err := readFile(m.cfg.ClientCredentialsFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read credentials file %q: %w", m.cfg.ClientCredentialsFile, err)
|
||||
}
|
||||
cfg, err := google.ConfigFromJSON(data, m.cfg.Scopes...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse credentials: %w", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *Manager) exchangeAndStore(ctx context.Context, cfg *oauth2.Config, code, verifier string) error {
|
||||
token, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exchange code: %w", err)
|
||||
}
|
||||
|
||||
ts := &store.TokenSet{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenType: token.TokenType,
|
||||
Expiry: token.Expiry,
|
||||
Scopes: m.cfg.Scopes,
|
||||
}
|
||||
|
||||
if err := m.tokenStore.Save(ts); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
m.log.Info("authentication successful")
|
||||
fmt.Println("Authentication successful.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate state: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func extractCode(input, expectedState string) string {
|
||||
if len(input) > 4 && input[:4] == "http" {
|
||||
u, err := parseURL(input)
|
||||
if err == nil {
|
||||
q := u.Query()
|
||||
if expectedState != "" && q.Get("state") != expectedState {
|
||||
return ""
|
||||
}
|
||||
if code := q.Get("code"); code != "" {
|
||||
return code
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
var cmd string
|
||||
var args []string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = "open"
|
||||
args = []string{url}
|
||||
case "windows":
|
||||
cmd = "rundll32"
|
||||
args = []string{"url.dll,FileProtocolHandler", url}
|
||||
default:
|
||||
cmd = "xdg-open"
|
||||
args = []string{url}
|
||||
}
|
||||
return exec.Command(cmd, args...).Start()
|
||||
}
|
||||
19
internal/auth/helpers.go
Normal file
19
internal/auth/helpers.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
)
|
||||
|
||||
func readFile(path string) ([]byte, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func parseURL(raw string) (*url.URL, error) {
|
||||
return url.Parse(raw)
|
||||
}
|
||||
33
internal/auth/pkce.go
Normal file
33
internal/auth/pkce.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type PKCEChallenge struct {
|
||||
Verifier string
|
||||
Challenge string
|
||||
Method string
|
||||
}
|
||||
|
||||
func NewPKCEChallenge() (*PKCEChallenge, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, fmt.Errorf("generate pkce verifier: %w", err)
|
||||
}
|
||||
verifier := base64.RawURLEncoding.EncodeToString(b)
|
||||
challenge := computeChallenge(verifier)
|
||||
return &PKCEChallenge{
|
||||
Verifier: verifier,
|
||||
Challenge: challenge,
|
||||
Method: "S256",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func computeChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||
}
|
||||
32
internal/auth/pkce_test.go
Normal file
32
internal/auth/pkce_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPKCEChallenge(t *testing.T) {
|
||||
p, err := NewPKCEChallenge()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, p.Verifier)
|
||||
assert.NotEmpty(t, p.Challenge)
|
||||
assert.Equal(t, "S256", p.Method)
|
||||
assert.NotEqual(t, p.Verifier, p.Challenge)
|
||||
}
|
||||
|
||||
func TestPKCEChallengeUniqueness(t *testing.T) {
|
||||
p1, err := NewPKCEChallenge()
|
||||
require.NoError(t, err)
|
||||
p2, err := NewPKCEChallenge()
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, p1.Verifier, p2.Verifier)
|
||||
assert.NotEqual(t, p1.Challenge, p2.Challenge)
|
||||
}
|
||||
|
||||
func TestComputeChallenge(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
assert.Equal(t, expected, computeChallenge(verifier))
|
||||
}
|
||||
Reference in New Issue
Block a user