Files
amcs/internal/app/app.go
Hein 56c84df342 feat(auth): implement OAuth 2.0 authorization code flow and dynamic client registration
- Add OAuth 2.0 support with authorization code flow and dynamic client registration.
- Introduce new handlers for OAuth metadata, client registration, authorization, and token issuance.
- Enhance authentication middleware to support OAuth client credentials.
- Create in-memory stores for authorization codes and tokens.
- Update configuration to include OAuth client details.
- Ensure validation checks for OAuth clients in the configuration.
2026-03-26 21:17:55 +02:00

168 lines
5.1 KiB
Go

package app
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"time"

	"git.warky.dev/wdevs/amcs/internal/ai"
	"git.warky.dev/wdevs/amcs/internal/auth"
	"git.warky.dev/wdevs/amcs/internal/config"
	"git.warky.dev/wdevs/amcs/internal/mcpserver"
	"git.warky.dev/wdevs/amcs/internal/observability"
	"git.warky.dev/wdevs/amcs/internal/session"
	"git.warky.dev/wdevs/amcs/internal/store"
	"git.warky.dev/wdevs/amcs/internal/tools"
)
// Run loads the configuration from configPath and wires up the remaining
// components: logger, database, AI provider, authentication stores and HTTP
// routes. It then serves HTTP until ctx is cancelled or the listener fails.
//
// When ctx is cancelled, the server is shut down with a 10-second deadline
// and the result of Shutdown is returned. When the listener fails, its error
// is returned wrapped as "run server: ...". Errors from any setup step are
// returned unchanged.
func Run(ctx context.Context, configPath string) error {
	cfg, loadedFrom, err := config.Load(configPath)
	if err != nil {
		return err
	}

	logger, err := observability.NewLogger(cfg.Logging)
	if err != nil {
		return err
	}
	logger.Info("loaded configuration",
		slog.String("path", loadedFrom),
		slog.String("provider", cfg.AI.Provider),
	)

	db, err := store.New(ctx, cfg.Database)
	if err != nil {
		return err
	}
	defer db.Close()

	if err := db.VerifyRequirements(ctx); err != nil {
		return err
	}

	httpClient := &http.Client{Timeout: 30 * time.Second}
	provider, err := ai.NewProvider(cfg.AI, httpClient, logger)
	if err != nil {
		return err
	}

	// Static API keys and OAuth clients are both optional. Each one stays nil
	// when it is not configured. routes only registers the OAuth endpoints
	// when both oauthRegistry and tokenStore are non-nil.
	var (
		keyring       *auth.Keyring
		oauthRegistry *auth.OAuthRegistry
		tokenStore    *auth.TokenStore
	)
	if len(cfg.Auth.Keys) > 0 {
		keyring, err = auth.NewKeyring(cfg.Auth.Keys)
		if err != nil {
			return err
		}
	}
	if len(cfg.Auth.OAuth.Clients) > 0 {
		oauthRegistry, err = auth.NewOAuthRegistry(cfg.Auth.OAuth.Clients)
		if err != nil {
			return err
		}
		// NOTE(review): the meaning of the 0 argument (presumably "use the
		// default") is defined in package auth — confirm there.
		tokenStore = auth.NewTokenStore(0)
	}
	authCodes := auth.NewAuthCodeStore()
	dynClients := auth.NewDynamicClientStore()
	activeProjects := session.NewActiveProjects()

	logger.Info("database connection verified",
		slog.String("provider", provider.Name()),
	)

	// Accept the host either bare ("::1") or already bracketed ("[::1]") so
	// existing configs keep working. net.JoinHostPort re-adds the brackets
	// that IPv6 literals need. Plain "%s:%d" formatting produces an address
	// such as "::1:8080", which the listener cannot parse.
	host := cfg.Server.Host
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}

	server := &http.Server{
		Addr:         net.JoinHostPort(host, fmt.Sprintf("%d", cfg.Server.Port)),
		Handler:      routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
		ReadTimeout:  cfg.Server.ReadTimeout,
		WriteTimeout: cfg.Server.WriteTimeout,
		IdleTimeout:  cfg.Server.IdleTimeout,
	}

	// Buffered so the serving goroutine can always deliver its error and exit,
	// even if Run has already returned through the ctx.Done branch.
	errCh := make(chan error, 1)
	go func() {
		logger.Info("starting HTTP server",
			slog.String("addr", server.Addr),
			slog.String("mcp_path", cfg.MCP.Path),
		)
		// ErrServerClosed is the expected result of Shutdown, not a failure.
		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			errCh <- err
		}
	}()

	select {
	case <-ctx.Done():
		// ctx is already cancelled, so derive the shutdown deadline from a
		// fresh background context.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		logger.Info("shutting down HTTP server")
		return server.Shutdown(shutdownCtx)
	case err := <-errCh:
		return fmt.Errorf("run server: %w", err)
	}
}
// routes builds the HTTP handler for the server.
//
// It registers the following routes:
//   - cfg.MCP.Path: the MCP handler, wrapped in auth.Middleware.
//   - The OAuth 2.0 endpoints (/.well-known/oauth-authorization-server,
//     /oauth/register, /oauth/authorize, /oauth/token). These are registered
//     only when both oauthRegistry and tokenStore are non-nil.
//   - /favicon.ico and /llm.
//   - /healthz: always 200 "ok".
//   - /readyz: 200 "ready", or 503 when db.Ready fails.
//   - /: a status banner for the exact root path; 404 for any other
//     unmatched path.
//
// The mux is wrapped with observability.Chain using the RequestID, Recover,
// AccessLog and Timeout(cfg.Server.WriteTimeout) middlewares.
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) http.Handler {
	mux := http.NewServeMux()

	toolSet := mcpserver.ToolSet{
		Capture:   tools.NewCaptureTool(db, provider, cfg.Capture, activeProjects, logger),
		Search:    tools.NewSearchTool(db, provider, cfg.Search, activeProjects),
		List:      tools.NewListTool(db, cfg.Search, activeProjects),
		Stats:     tools.NewStatsTool(db),
		Get:       tools.NewGetTool(db),
		Update:    tools.NewUpdateTool(db, provider, cfg.Capture, logger),
		Delete:    tools.NewDeleteTool(db),
		Archive:   tools.NewArchiveTool(db),
		Projects:  tools.NewProjectsTool(db, activeProjects),
		Context:   tools.NewContextTool(db, provider, cfg.Search, activeProjects),
		Recall:    tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
		Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
		Links:     tools.NewLinksTool(db, provider, cfg.Search),
	}
	mcpHandler := mcpserver.New(cfg.MCP, toolSet)
	mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger)(mcpHandler))

	// The OAuth endpoints need both the client registry and the token store.
	// When either is missing, the endpoints are not exposed at all.
	if oauthRegistry != nil && tokenStore != nil {
		mux.HandleFunc("/.well-known/oauth-authorization-server", oauthMetadataHandler())
		mux.HandleFunc("/oauth/register", oauthRegisterHandler(dynClients, logger))
		mux.HandleFunc("/oauth/authorize", oauthAuthorizeHandler(dynClients, authCodes, logger))
		mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
	}

	mux.HandleFunc("/favicon.ico", serveFavicon)
	mux.HandleFunc("/llm", serveLLMInstructions)

	mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	})

	mux.HandleFunc("/readyz", func(w http.ResponseWriter, r *http.Request) {
		if err := db.Ready(r.Context()); err != nil {
			logger.Error("readiness check failed", slog.String("error", err.Error()))
			http.Error(w, "not ready", http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ready"))
	})

	// In a ServeMux, "/" is the catch-all and receives every path not matched
	// above. Only the exact root gets the status banner; anything else is a
	// 404. Clients probing optional endpoints (e.g. OAuth discovery documents
	// under /.well-known/ when OAuth is disabled) then see "not found" rather
	// than a misleading 200 with a plain-text body.
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/" {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("amcs is running"))
	})

	return observability.Chain(
		mux,
		observability.RequestID(),
		observability.Recover(logger),
		observability.AccessLog(logger),
		observability.Timeout(cfg.Server.WriteTimeout),
	)
}