diff --git a/go.mod b/go.mod index 3335ba41..a70100c4 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/hashicorp/go-version v1.8.0 github.com/ktr0731/go-fuzzyfinder v0.9.0 github.com/manifoldco/promptui v0.9.0 + github.com/nxadm/tail v1.4.11 github.com/opencontainers/image-spec v1.1.1 github.com/pelletier/go-toml/v2 v2.2.4 github.com/spf13/cobra v1.10.2 @@ -30,6 +31,7 @@ require ( github.com/testcontainers/testcontainers-go v0.40.0 github.com/zalando/go-keyring v0.2.6 go.etcd.io/bbolt v1.4.0-alpha.1 + go.uber.org/multierr v1.11.0 golang.org/x/term v0.37.0 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 @@ -185,7 +187,6 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nsf/termbox-go v1.1.1 // indirect - github.com/nxadm/tail v1.4.11 // indirect github.com/oasisprotocol/curve25519-voi v0.0.0-20230904125328-1f23a7beb09a // indirect github.com/oklog/run v1.1.0 // indirect github.com/onsi/ginkgo v1.16.5 // indirect @@ -236,7 +237,6 @@ require ( go.opentelemetry.io/otel/trace v1.39.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.uber.org/mock v0.6.0 // indirect - go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/arch v0.15.0 // indirect golang.org/x/crypto v0.45.0 // indirect diff --git a/go.sum b/go.sum index 15a8b0d1..fd34654b 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,8 @@ github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= +github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -115,6 +117,8 @@ github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7 github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= diff --git a/internal/application/devnet/provision.go b/internal/application/devnet/provision.go index 428090cf..856c8085 100644 --- a/internal/application/devnet/provision.go +++ b/internal/application/devnet/provision.go @@ -31,6 +31,17 @@ type 
ProvisionUseCase struct { logger ports.Logger } +type provisionPipelineState struct { + metadata *ports.DevnetMetadata + rpcEndpoint string + genesis []byte + chainID string + bech32Prefix string + accountsDir string + nodes []*ports.NodeMetadata + validators []ports.ValidatorInfo +} + // NewProvisionUseCase creates a new ProvisionUseCase. func NewProvisionUseCase( devnetRepo ports.DevnetRepository, @@ -58,20 +69,36 @@ func NewProvisionUseCase( func (uc *ProvisionUseCase) Execute(ctx context.Context, input dto.ProvisionInput) (*dto.ProvisionOutput, error) { uc.logger.Info("Provisioning devnet...") - // Check if devnet already exists + pipelineState, err := uc.prepareMetadata(input) + if err != nil { + return nil, err + } + + if err := uc.fetchGenesis(ctx, input, pipelineState); err != nil { + return nil, err + } + + if err := uc.initializeKeysAndNodes(ctx, input, pipelineState); err != nil { + return nil, err + } + + if err := uc.patchGenesis(ctx, input, pipelineState); err != nil { + return nil, err + } + + return uc.persistState(ctx, input, pipelineState) +} + +func (uc *ProvisionUseCase) prepareMetadata(input dto.ProvisionInput) (*provisionPipelineState, error) { if uc.devnetRepo.Exists(input.HomeDir) { return nil, fmt.Errorf("devnet already exists at %s", input.HomeDir) } - // Determine execution mode - var execMode types.ExecutionMode + execMode := types.ExecutionModeLocal if input.Mode == string(types.ExecutionModeDocker) { execMode = types.ExecutionModeDocker - } else { - execMode = types.ExecutionModeLocal } - // Create metadata metadata := &ports.DevnetMetadata{ HomeDir: input.HomeDir, NetworkName: input.Network, @@ -83,163 +110,171 @@ func (uc *ProvisionUseCase) Execute(ctx context.Context, input dto.ProvisionInpu Status: ports.StateCreated, DockerImage: input.DockerImage, CustomBinaryPath: input.CustomBinaryPath, - InitialVersion: input.StableVersion, // Store the deployed version - CurrentVersion: input.StableVersion, // Initially same as deployed version + InitialVersion: input.StableVersion, + CurrentVersion: input.StableVersion, CreatedAt: time.Now(), } - // Get RPC endpoint for fetching genesis - rpcEndpoint := "" - if uc.networkModule != nil { - rpcEndpoint = uc.networkModule.RPCEndpoint(input.Network) + if uc.networkModule == nil { + return nil, fmt.Errorf("no RPC endpoint available for network: %s", input.Network) } - - // Fetch genesis from RPC (required for initial provisioning) + rpcEndpoint := uc.networkModule.RPCEndpoint(input.Network) if rpcEndpoint == "" { return nil, fmt.Errorf("no RPC endpoint available for network: %s", input.Network) } - uc.logger.Info("Fetching genesis from RPC %s...", rpcEndpoint) - rpcGenesis, err := uc.genesisSvc.FetchFromRPC(ctx, rpcEndpoint) + return &provisionPipelineState{ + metadata: metadata, + rpcEndpoint: rpcEndpoint, + bech32Prefix: uc.networkModule.Bech32Prefix(), + accountsDir: paths.DevnetAccountsPath(input.HomeDir), + }, nil +} + +func (uc *ProvisionUseCase) fetchGenesis(ctx context.Context, input dto.ProvisionInput, state *provisionPipelineState) error { + uc.logger.Info("Fetching genesis from RPC %s...", state.rpcEndpoint) + rpcGenesis, err := uc.genesisSvc.FetchFromRPC(ctx, state.rpcEndpoint) if err != nil { - return nil, fmt.Errorf("failed to fetch genesis from RPC: %w", err) + return fmt.Errorf("failed to fetch genesis from RPC: %w", err) } - // Use snapshot-based export if requested - var genesis []byte + genesis := rpcGenesis if input.UseSnapshot && uc.stateExportSvc != nil { uc.logger.Info("Exporting genesis from 
snapshot state...") genesis, err = uc.exportGenesisFromSnapshot(ctx, input, rpcGenesis) if err != nil { - return nil, fmt.Errorf("failed to export genesis from snapshot: %w", err) + return fmt.Errorf("failed to export genesis from snapshot: %w", err) } - } else { - genesis = rpcGenesis } - // Determine chain ID to use from genesis chainID, _ := extractChainID(genesis) - metadata.ChainID = chainID + state.metadata.ChainID = chainID + state.genesis = genesis + state.chainID = chainID + + return nil +} - // Step 1: Create account keys for validators (for transaction signing) +func (uc *ProvisionUseCase) initializeKeysAndNodes(ctx context.Context, input dto.ProvisionInput, state *provisionPipelineState) error { uc.logger.Info("Creating validator account keys...") - accountsDir := paths.DevnetAccountsPath(input.HomeDir) - accountKeys, err := uc.createAccountKeys(ctx, accountsDir, input.NumValidators, input.UseTestMnemonic) + accountKeys, err := uc.createAccountKeys(ctx, state.accountsDir, input.NumValidators, input.UseTestMnemonic) if err != nil { - return nil, fmt.Errorf("failed to create account keys: %w", err) + return fmt.Errorf("failed to create account keys: %w", err) } - // Step 2: Initialize nodes to generate consensus keys (for block signing) uc.logger.Info("Initializing validator nodes...") - nodes, err := uc.initializeNodes(ctx, input, chainID) + nodes, err := uc.initializeNodes(ctx, input, state.chainID) if err != nil { - return nil, fmt.Errorf("failed to initialize nodes: %w", err) + return fmt.Errorf("failed to initialize nodes: %w", err) } - // Step 2.1: Save validator key information to JSON files for export-keys command uc.logger.Debug("Saving validator key information...") - if err := uc.saveValidatorKeys(input.HomeDir, accountKeys, uc.networkModule.Bech32Prefix()); err != nil { - return nil, fmt.Errorf("failed to save validator keys: %w", err) + if err := uc.saveValidatorKeys(input.HomeDir, accountKeys, state.bech32Prefix); err != nil { + return fmt.Errorf("failed to save validator keys: %w", err) } - // Step 2.2: Create and save additional account keys (for testing/transactions) if input.NumAccounts > 0 { uc.logger.Info("Creating %d additional account keys...", input.NumAccounts) - additionalAccounts, err := uc.createAdditionalAccountKeys(ctx, accountsDir, input.NumAccounts, input.UseTestMnemonic, input.NumValidators) + additionalAccounts, err := uc.createAdditionalAccountKeys(ctx, state.accountsDir, input.NumAccounts, input.UseTestMnemonic, input.NumValidators) if err != nil { - return nil, fmt.Errorf("failed to create additional account keys: %w", err) + return fmt.Errorf("failed to create additional account keys: %w", err) } uc.logger.Debug("Saving account key information...") if err := uc.saveAccountKeys(input.HomeDir, additionalAccounts); err != nil { - return nil, fmt.Errorf("failed to save account keys: %w", err) + return fmt.Errorf("failed to save account keys: %w", err) } } - // Step 2.5: Configure nodes with network-specific settings (config.toml, app.toml) uc.logger.Info("Configuring node settings...") - if err := uc.configureNodes(ctx, nodes, chainID, input.NumValidators); err != nil { - return nil, fmt.Errorf("failed to configure nodes: %w", err) + if err := uc.configureNodes(ctx, nodes, state.chainID, input.NumValidators); err != nil { + return fmt.Errorf("failed to configure nodes: %w", err) } - // Step 3: Build validator info combining consensus and account keys uc.logger.Info("Building validator info...") - validators, err := uc.buildValidatorInfo(nodes, 
accountKeys, uc.networkModule.Bech32Prefix()) + validators, err := uc.buildValidatorInfo(nodes, accountKeys, state.bech32Prefix) if err != nil { - return nil, fmt.Errorf("failed to build validator info: %w", err) + return fmt.Errorf("failed to build validator info: %w", err) } - // Step 4: Modify genesis with validators - uc.logger.Info("Modifying genesis for devnet (chainID: %s)...", chainID) + state.nodes = nodes + state.validators = validators + return nil +} + +func (uc *ProvisionUseCase) patchGenesis(ctx context.Context, input dto.ProvisionInput, state *provisionPipelineState) error { + genesis := state.genesis + + uc.logger.Info("Modifying genesis for devnet (chainID: %s)...", state.chainID) if uc.networkModule != nil { opts := ports.GenesisModifyOptions{ - ChainID: chainID, + ChainID: state.chainID, NumValidators: input.NumValidators, - AddValidators: validators, + AddValidators: state.validators, } - // Check genesis size - gRPC has 4MB default limit - const grpcSizeLimit = 4 * 1024 * 1024 // 4MB + const grpcSizeLimit = 4 * 1024 * 1024 if len(genesis) > grpcSizeLimit { - // Use file-based modification for large genesis (e.g., exported mainnet ~90MB) uc.logger.Info("Using file-based genesis modification (size: %.1f MB)", float64(len(genesis))/(1024*1024)) modifiedGenesis, err := uc.modifyGenesisViaFile(ctx, genesis, opts, input.HomeDir) if err != nil { - return nil, fmt.Errorf("failed to modify genesis via file: %w", err) + return fmt.Errorf("failed to modify genesis via file: %w", err) } genesis = modifiedGenesis } else { - // Use standard in-memory modification for small genesis modifiedGenesis, err := uc.networkModule.ModifyGenesis(genesis, opts) if err != nil { - return nil, fmt.Errorf("failed to modify genesis: %w", err) + return fmt.Errorf("failed to modify genesis: %w", err) } genesis = modifiedGenesis } - uc.logger.Debug("Genesis modified with %d validators", len(validators)) + uc.logger.Debug("Genesis modified with %d validators", len(state.validators)) } - // Step 4: Write modified genesis to all nodes - for _, node := range nodes { + for _, node := range state.nodes { genesisPath := filepath.Join(node.HomeDir, "config", "genesis.json") if err := os.WriteFile(genesisPath, genesis, 0644); err != nil { - return nil, fmt.Errorf("failed to write genesis to node %d: %w", node.Index, err) + return fmt.Errorf("failed to write genesis to node %d: %w", node.Index, err) } } - // Set genesis path in metadata - if len(nodes) > 0 { - metadata.GenesisPath = filepath.Join(nodes[0].HomeDir, "config", "genesis.json") + if len(state.nodes) > 0 { + state.metadata.GenesisPath = filepath.Join(state.nodes[0].HomeDir, "config", "genesis.json") } - // Update metadata - metadata.Status = ports.StateProvisioned + state.metadata.Status = ports.StateProvisioned now := time.Now() - metadata.LastProvisioned = &now + state.metadata.LastProvisioned = &now + state.genesis = genesis + + return nil +} - // Save metadata - if err := uc.devnetRepo.Save(ctx, metadata); err != nil { +func (uc *ProvisionUseCase) persistState( + ctx context.Context, + input dto.ProvisionInput, + state *provisionPipelineState, +) (*dto.ProvisionOutput, error) { + if err := uc.devnetRepo.Save(ctx, state.metadata); err != nil { return nil, fmt.Errorf("failed to save metadata: %w", err) } - // Save nodes - for _, node := range nodes { + for _, node := range state.nodes { if err := uc.nodeRepo.Save(ctx, node); err != nil { uc.logger.Warn("Failed to save node %d: %v", node.Index, err) } } - // Build output output := 
&dto.ProvisionOutput{ HomeDir: input.HomeDir, - ChainID: metadata.ChainID, - GenesisPath: metadata.GenesisPath, + ChainID: state.metadata.ChainID, + GenesisPath: state.metadata.GenesisPath, NumValidators: input.NumValidators, NumAccounts: input.NumAccounts, - Nodes: make([]dto.NodeInfo, len(nodes)), + Nodes: make([]dto.NodeInfo, len(state.nodes)), } - for i, node := range nodes { + for i, node := range state.nodes { output.Nodes[i] = dto.NodeInfo{ Index: node.Index, Name: node.Name, diff --git a/internal/application/devnet/provision_pipeline_test.go b/internal/application/devnet/provision_pipeline_test.go new file mode 100644 index 00000000..898c9dde --- /dev/null +++ b/internal/application/devnet/provision_pipeline_test.go @@ -0,0 +1,264 @@ +package devnet + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/altuslabsxyz/devnet-builder/internal/application/dto" + "github.com/altuslabsxyz/devnet-builder/internal/application/ports" + "github.com/altuslabsxyz/devnet-builder/types" +) + +type provisionTestLogger struct{} + +func (provisionTestLogger) Info(string, ...interface{}) {} +func (provisionTestLogger) Warn(string, ...interface{}) {} +func (provisionTestLogger) Error(string, ...interface{}) {} +func (provisionTestLogger) Debug(string, ...interface{}) {} +func (provisionTestLogger) Success(string, ...interface{}) {} +func (provisionTestLogger) Print(string, ...interface{}) {} +func (provisionTestLogger) Println(string, ...interface{}) {} +func (provisionTestLogger) SetVerbose(bool) {} +func (provisionTestLogger) IsVerbose() bool { return false } +func (provisionTestLogger) Writer() io.Writer { return io.Discard } +func (provisionTestLogger) ErrWriter() io.Writer { return io.Discard } + +type provisionTestDevnetRepo struct { + ports.DevnetRepository + exists bool + saveErr error + saved []*ports.DevnetMetadata +} + +func (m *provisionTestDevnetRepo) Exists(string) bool { return m.exists } +func (m *provisionTestDevnetRepo) Save(_ context.Context, metadata *ports.DevnetMetadata) error { + if m.saveErr != nil { + return m.saveErr + } + m.saved = append(m.saved, metadata) + return nil +} + +type provisionTestNodeRepo struct { + ports.NodeRepository + saveErr error + saved []*ports.NodeMetadata +} + +func (m *provisionTestNodeRepo) Save(_ context.Context, node *ports.NodeMetadata) error { + if m.saveErr != nil { + return m.saveErr + } + m.saved = append(m.saved, node) + return nil +} + +type provisionTestGenesisFetcher struct { + ports.GenesisFetcher + genesis []byte + err error +} + +func (m provisionTestGenesisFetcher) FetchFromRPC(context.Context, string) ([]byte, error) { + if m.err != nil { + return nil, m.err + } + return m.genesis, nil +} + +type provisionTestNetworkModule struct { + ports.NetworkModule + rpcEndpoint string + bech32 string +} + +func (m provisionTestNetworkModule) RPCEndpoint(string) string { return m.rpcEndpoint } +func (m provisionTestNetworkModule) Bech32Prefix() string { return m.bech32 } + +type provisionTestNodeInitializer struct { + ports.NodeInitializer + createAccountKeyErr error +} + +func (m provisionTestNodeInitializer) CreateAccountKey(context.Context, string, string) (*ports.AccountKeyInfo, error) { + if m.createAccountKeyErr != nil { + return nil, m.createAccountKeyErr + } + return &ports.AccountKeyInfo{}, nil +} + +func TestPrepareMetadata(t *testing.T) { + t.Run("returns error when devnet already exists", func(t *testing.T) { + uc := &ProvisionUseCase{ + devnetRepo: 
&provisionTestDevnetRepo{exists: true}, + logger: provisionTestLogger{}, + } + + _, err := uc.prepareMetadata(dto.ProvisionInput{HomeDir: "/tmp/existing"}) + if err == nil || !strings.Contains(err.Error(), "devnet already exists") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("resolves docker mode and RPC endpoint", func(t *testing.T) { + uc := &ProvisionUseCase{ + devnetRepo: &provisionTestDevnetRepo{}, + networkModule: provisionTestNetworkModule{rpcEndpoint: "https://rpc.example", bech32: "stable"}, + logger: provisionTestLogger{}, + } + + state, err := uc.prepareMetadata(dto.ProvisionInput{ + HomeDir: "/tmp/devnet", + Mode: string(types.ExecutionModeDocker), + Network: "mainnet", + StableVersion: "v1.0.0", + }) + if err != nil { + t.Fatalf("prepareMetadata failed: %v", err) + } + if state.metadata.ExecutionMode != types.ExecutionModeDocker { + t.Fatalf("execution mode mismatch: %s", state.metadata.ExecutionMode) + } + if state.rpcEndpoint != "https://rpc.example" { + t.Fatalf("rpc endpoint mismatch: %s", state.rpcEndpoint) + } + }) +} + +func TestFetchGenesis(t *testing.T) { + uc := &ProvisionUseCase{ + genesisSvc: provisionTestGenesisFetcher{ + genesis: []byte(`{"chain_id":"stable-devnet-1"}`), + }, + logger: provisionTestLogger{}, + } + + state := &provisionPipelineState{ + metadata: &ports.DevnetMetadata{}, + rpcEndpoint: "https://rpc.example", + } + + if err := uc.fetchGenesis(context.Background(), dto.ProvisionInput{}, state); err != nil { + t.Fatalf("fetchGenesis failed: %v", err) + } + if state.chainID != "stable-devnet-1" { + t.Fatalf("chain id mismatch: %s", state.chainID) + } + if state.metadata.ChainID != state.chainID { + t.Fatalf("metadata chain id mismatch: %s", state.metadata.ChainID) + } + if !bytes.Equal(state.genesis, []byte(`{"chain_id":"stable-devnet-1"}`)) { + t.Fatal("expected genesis to match fetched rpc genesis when snapshot is disabled") + } +} + +func TestInitializeKeysAndNodes_FailsOnCreateAccountKeys(t *testing.T) { + uc := &ProvisionUseCase{ + nodeInitializer: provisionTestNodeInitializer{ + createAccountKeyErr: errors.New("keyring unavailable"), + }, + logger: provisionTestLogger{}, + } + + state := &provisionPipelineState{ + accountsDir: t.TempDir(), + chainID: "stable-devnet-1", + } + + err := uc.initializeKeysAndNodes(context.Background(), dto.ProvisionInput{ + HomeDir: t.TempDir(), + NumValidators: 1, + UseTestMnemonic: false, + }, state) + if err == nil || !strings.Contains(err.Error(), "failed to create account keys") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPatchGenesis(t *testing.T) { + uc := &ProvisionUseCase{logger: provisionTestLogger{}} + nodeHome := t.TempDir() + if err := os.MkdirAll(filepath.Join(nodeHome, "config"), 0755); err != nil { + t.Fatalf("failed to create node config dir: %v", err) + } + + state := &provisionPipelineState{ + metadata: &ports.DevnetMetadata{}, + genesis: []byte(`{"chain_id":"stable-devnet-1"}`), + chainID: "stable-devnet-1", + nodes: []*ports.NodeMetadata{ + {Index: 0, HomeDir: nodeHome}, + }, + } + + if err := uc.patchGenesis(context.Background(), dto.ProvisionInput{ + HomeDir: nodeHome, + NumValidators: 1, + }, state); err != nil { + t.Fatalf("patchGenesis failed: %v", err) + } + + if state.metadata.Status != ports.StateProvisioned { + t.Fatalf("unexpected metadata status: %s", state.metadata.Status) + } + if state.metadata.LastProvisioned == nil { + t.Fatal("expected LastProvisioned to be set") + } + if state.metadata.GenesisPath == "" { + t.Fatal("expected genesis path to be 
set") + } + + written, err := os.ReadFile(state.metadata.GenesisPath) + if err != nil { + t.Fatalf("failed to read written genesis: %v", err) + } + if !bytes.Equal(written, state.genesis) { + t.Fatalf("written genesis mismatch: %s", string(written)) + } +} + +func TestPersistState(t *testing.T) { + devnetRepo := &provisionTestDevnetRepo{} + nodeRepo := &provisionTestNodeRepo{} + uc := &ProvisionUseCase{ + devnetRepo: devnetRepo, + nodeRepo: nodeRepo, + logger: provisionTestLogger{}, + } + + state := &provisionPipelineState{ + metadata: &ports.DevnetMetadata{ + ChainID: "stable-devnet-1", + GenesisPath: "/tmp/genesis.json", + }, + nodes: []*ports.NodeMetadata{ + {Index: 0, Name: "node0", HomeDir: "/tmp/node0", NodeID: "id-0"}, + {Index: 1, Name: "node1", HomeDir: "/tmp/node1", NodeID: "id-1"}, + }, + } + + output, err := uc.persistState(context.Background(), dto.ProvisionInput{ + HomeDir: "/tmp/devnet", + NumValidators: 2, + NumAccounts: 1, + }, state) + if err != nil { + t.Fatalf("persistState failed: %v", err) + } + + if len(devnetRepo.saved) != 1 { + t.Fatalf("expected one metadata save, got %d", len(devnetRepo.saved)) + } + if len(nodeRepo.saved) != 2 { + t.Fatalf("expected two node saves, got %d", len(nodeRepo.saved)) + } + if output.ChainID != "stable-devnet-1" || len(output.Nodes) != 2 { + t.Fatalf("unexpected output: %+v", output) + } +} diff --git a/internal/application/upgrade/resumable_execute.go b/internal/application/upgrade/resumable_execute.go index 9e1afad7..592b553c 100644 --- a/internal/application/upgrade/resumable_execute.go +++ b/internal/application/upgrade/resumable_execute.go @@ -25,8 +25,34 @@ type ResumableExecuteUpgradeUseCase struct { exportUC ports.ExportUseCase devnetRepo ports.DevnetRepository logger ports.Logger + ops *resumableUpgradeOps } +type resumableUpgradeOps struct { + executeProposal func(context.Context, dto.ProposeInput) (*dto.ProposeOutput, error) + executeVote func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.VoteOutput, error) + waitForUpgradeHeight func(context.Context, int64) error + waitForChainHalt func(context.Context, int64) error + executeSwitchBinary func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.SwitchBinaryOutput, error) + verifyChainResumed func(context.Context, string) (int64, error) + executeExport func(context.Context, dto.ExportInput) (interface{}, error) + updateCurrentVersion func(context.Context, string, string) error + deleteState func(context.Context) error + transitionAndSave func(context.Context, *ports.UpgradeState, ports.ResumableStage, string) error +} + +type resumableGovStageResult struct { + err error + preserveOutputOnError bool +} + +type resumableGovStageHandler func( + context.Context, + dto.ExecuteUpgradeInput, + *ports.UpgradeState, + *dto.ExecuteUpgradeOutput, +) resumableGovStageResult + // NewResumableExecuteUpgradeUseCase creates a new ResumableExecuteUpgradeUseCase. 
func NewResumableExecuteUpgradeUseCase( executeUC *ExecuteUpgradeUseCase, @@ -41,7 +67,7 @@ func NewResumableExecuteUpgradeUseCase( devnetRepo ports.DevnetRepository, logger ports.Logger, ) *ResumableExecuteUpgradeUseCase { - return &ResumableExecuteUpgradeUseCase{ + uc := &ResumableExecuteUpgradeUseCase{ executeUC: executeUC, proposeUC: proposeUC, voteUC: voteUC, @@ -54,6 +80,50 @@ func NewResumableExecuteUpgradeUseCase( devnetRepo: devnetRepo, logger: logger, } + uc.ops = uc.defaultOps() + return uc +} + +func (uc *ResumableExecuteUpgradeUseCase) defaultOps() *resumableUpgradeOps { + return &resumableUpgradeOps{ + executeProposal: func(ctx context.Context, input dto.ProposeInput) (*dto.ProposeOutput, error) { + return uc.proposeUC.Execute(ctx, input) + }, + executeVote: func(ctx context.Context, input dto.ExecuteUpgradeInput, state *ports.UpgradeState) (*dto.VoteOutput, error) { + return uc.executeVoting(ctx, input, state) + }, + waitForUpgradeHeight: func(ctx context.Context, height int64) error { + return uc.executeUC.waitForUpgradeHeight(ctx, height) + }, + waitForChainHalt: func(ctx context.Context, height int64) error { + return uc.executeUC.waitForChainHalt(ctx, height) + }, + executeSwitchBinary: func(ctx context.Context, input dto.ExecuteUpgradeInput, state *ports.UpgradeState) (*dto.SwitchBinaryOutput, error) { + return uc.executeSwitchBinary(ctx, input, state) + }, + verifyChainResumed: func(ctx context.Context, homeDir string) (int64, error) { + return uc.executeUC.verifyChainResumed(ctx, homeDir) + }, + executeExport: func(ctx context.Context, input dto.ExportInput) (interface{}, error) { + return uc.exportUC.Execute(ctx, input) + }, + updateCurrentVersion: func(ctx context.Context, homeDir, version string) error { + return uc.executeUC.updateCurrentVersion(ctx, homeDir, version) + }, + deleteState: func(ctx context.Context) error { + return uc.stateManager.DeleteState(ctx) + }, + transitionAndSave: func(ctx context.Context, state *ports.UpgradeState, target ports.ResumableStage, reason string) error { + return uc.transitionAndSave(ctx, state, target, reason) + }, + } +} + +func (uc *ResumableExecuteUpgradeUseCase) getOps() *resumableUpgradeOps { + if uc.ops == nil { + uc.ops = uc.defaultOps() + } + return uc.ops } // Execute performs the upgrade workflow with state persistence. 
@@ -174,225 +244,291 @@ func (uc *ResumableExecuteUpgradeUseCase) executeWithGovResumable( startTime time.Time, ) (*dto.ExecuteUpgradeOutput, error) { output := &dto.ExecuteUpgradeOutput{} + if err := uc.runPreUpgradeExport(ctx, input, state, output); err != nil { + output.Error = err + return output, err + } - // Pre-upgrade export (only if starting fresh and enabled) - if state.Stage == ports.ResumableStageInitialized && input.WithExport { - uc.logger.Info("Pre-upgrade: Exporting state before upgrade...") - exportInput := dto.ExportInput{ - HomeDir: input.HomeDir, - OutputDir: input.GenesisDir, - Force: false, + handlers := uc.govResumableStageHandlers() + for { + if terminalOutput, terminalErr, done := uc.handleGovTerminalStage(ctx, input, state, startTime, output); done { + return terminalOutput, terminalErr } - preExportResultRaw, err := uc.exportUC.Execute(ctx, exportInput) - if err != nil { - uc.logger.Error("Pre-upgrade export failed: %v", err) - output.Error = fmt.Errorf("pre-upgrade export failed: %w", err) - return output, output.Error + handler, ok := handlers[state.Stage] + if !ok { + return nil, fmt.Errorf("cannot resume from stage: %s", state.Stage) } - if preExportResult, ok := preExportResultRaw.(*dto.ExportOutput); ok { - output.PreGenesisPath = preExportResult.ExportPath - uc.logger.Success("Pre-upgrade export complete: %s", preExportResult.ExportPath) - } - } - // Resume from current stage - switch state.Stage { - case ports.ResumableStageInitialized: - // Step 1: Submit proposal - uc.logger.Info("Step 1/5: Submitting upgrade proposal...") - proposeResult, err := uc.proposeUC.Execute(ctx, dto.ProposeInput{ - HomeDir: input.HomeDir, - UpgradeName: input.UpgradeName, - UpgradeHeight: input.UpgradeHeight, - VotingPeriod: input.VotingPeriod, - HeightBuffer: input.HeightBuffer, - }) - if err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) + outcome := handler(ctx, input, state, output) + if outcome.err != nil { + if outcome.preserveOutputOnError { + return output, outcome.err } - output.Error = err - return output, err + return nil, outcome.err } + } +} - // Update state with proposal info - state.ProposalID = proposeResult.ProposalID - state.UpgradeHeight = proposeResult.UpgradeHeight - output.ProposalID = proposeResult.ProposalID - output.UpgradeHeight = proposeResult.UpgradeHeight +func (uc *ResumableExecuteUpgradeUseCase) runPreUpgradeExport( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) error { + if state.Stage != ports.ResumableStageInitialized || !input.WithExport { + return nil + } + ops := uc.getOps() - // Transition to ProposalSubmitted - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageProposalSubmitted, fmt.Sprintf("proposal %d submitted", proposeResult.ProposalID)); err != nil { - return nil, err - } - fallthrough + uc.logger.Info("Pre-upgrade: Exporting state before upgrade...") + exportInput := dto.ExportInput{ + HomeDir: input.HomeDir, + OutputDir: input.GenesisDir, + Force: false, + } - case ports.ResumableStageProposalSubmitted: - // Transition to Voting (deposit period complete in devnet context) - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageVoting, "voting period started"); err != nil { - return nil, err - } - fallthrough + preExportResultRaw, err := ops.executeExport(ctx, exportInput) + if err != nil { + 
uc.logger.Error("Pre-upgrade export failed: %v", err) + return fmt.Errorf("pre-upgrade export failed: %w", err) + } + if preExportResult, ok := preExportResultRaw.(*dto.ExportOutput); ok { + output.PreGenesisPath = preExportResult.ExportPath + uc.logger.Success("Pre-upgrade export complete: %s", preExportResult.ExportPath) + } - case ports.ResumableStageVoting: - // Step 2: Vote from all validators - uc.logger.Info("Step 2/5: Voting from all validators...") - output.ProposalID = state.ProposalID - output.UpgradeHeight = state.UpgradeHeight + return nil +} - voteResult, err := uc.executeVoting(ctx, input, state) - if err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) - } - output.Error = err - return output, err - } +func (uc *ResumableExecuteUpgradeUseCase) govResumableStageHandlers() map[ports.ResumableStage]resumableGovStageHandler { + return map[ports.ResumableStage]resumableGovStageHandler{ + ports.ResumableStageInitialized: uc.handleGovStageInitialized, + ports.ResumableStageProposalSubmitted: uc.handleGovStageProposalSubmitted, + ports.ResumableStageVoting: uc.handleGovStageVoting, + ports.ResumableStageWaitingForHeight: uc.handleGovStageWaitingForHeight, + ports.ResumableStageChainHalted: uc.handleGovStageChainHalted, + ports.ResumableStageSwitchingBinary: uc.handleGovStageSwitchingBinary, + ports.ResumableStageVerifyingResume: uc.handleGovStageVerifyingResume, + } +} - if voteResult.VotesCast != voteResult.TotalVoters { - err := fmt.Errorf("not all votes cast: %d/%d", voteResult.VotesCast, voteResult.TotalVoters) - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) +func (uc *ResumableExecuteUpgradeUseCase) handleGovTerminalStage( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + startTime time.Time, + output *dto.ExecuteUpgradeOutput, +) (*dto.ExecuteUpgradeOutput, error, bool) { + ops := uc.getOps() + + switch state.Stage { + case ports.ResumableStageCompleted: + if input.TargetVersion != "" { + if err := ops.updateCurrentVersion(ctx, input.HomeDir, input.TargetVersion); err != nil { + uc.logger.Warn("Failed to update version in metadata: %v", err) } - output.Error = err - return output, err } - // Transition to WaitingForHeight - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageWaitingForHeight, "voting complete, proposal passed"); err != nil { - return nil, err + if err := ops.deleteState(ctx); err != nil { + uc.logger.Warn("Failed to delete state file: %v", err) } - fallthrough - case ports.ResumableStageWaitingForHeight: - // Step 3: Wait for upgrade height - uc.logger.Info("Step 3/5: Waiting for upgrade height %d...", state.UpgradeHeight) - output.ProposalID = state.ProposalID - output.UpgradeHeight = state.UpgradeHeight + output.Success = true + output.Duration = time.Since(startTime) + uc.logger.Success("Upgrade complete! 
Duration: %v", output.Duration) + return output, nil, true + case ports.ResumableStageFailed: + return nil, fmt.Errorf("upgrade previously failed: %s (use --force-restart to start fresh)", state.Error), true + case ports.ResumableStageProposalRejected: + return nil, fmt.Errorf("proposal was rejected (use --force-restart to start fresh)"), true + default: + return nil, nil, false + } +} - if err := uc.executeUC.waitForUpgradeHeight(ctx, state.UpgradeHeight); err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) - } - output.Error = err - return output, err - } +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageInitialized( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + uc.logger.Info("Step 1/5: Submitting upgrade proposal...") - // Transition to ChainHalted - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageChainHalted, "upgrade height reached"); err != nil { - return nil, err - } - fallthrough + proposeResult, err := ops.executeProposal(ctx, dto.ProposeInput{ + HomeDir: input.HomeDir, + UpgradeName: input.UpgradeName, + UpgradeHeight: input.UpgradeHeight, + VotingPeriod: input.VotingPeriod, + HeightBuffer: input.HeightBuffer, + }) + if err != nil { + return uc.failGovStage(ctx, state, output, err) + } - case ports.ResumableStageChainHalted: - // Step 4: Wait for chain halt - uc.logger.Info("Step 4/5: Waiting for chain to halt...") - output.ProposalID = state.ProposalID - output.UpgradeHeight = state.UpgradeHeight + state.ProposalID = proposeResult.ProposalID + state.UpgradeHeight = proposeResult.UpgradeHeight + output.ProposalID = proposeResult.ProposalID + output.UpgradeHeight = proposeResult.UpgradeHeight + + return uc.advanceGovStage( + ctx, + state, + ports.ResumableStageProposalSubmitted, + fmt.Sprintf("proposal %d submitted", proposeResult.ProposalID), + ) +} - if err := uc.executeUC.waitForChainHalt(ctx, state.UpgradeHeight); err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) - } - output.Error = err - return output, err - } +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageProposalSubmitted( + ctx context.Context, + _ dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + _ *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + return uc.advanceGovStage(ctx, state, ports.ResumableStageVoting, "voting period started") +} - // Transition to SwitchingBinary - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageSwitchingBinary, "chain halted at upgrade height"); err != nil { - return nil, err - } - fallthrough +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageVoting( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + uc.logger.Info("Step 2/5: Voting from all validators...") + output.ProposalID = state.ProposalID + output.UpgradeHeight = state.UpgradeHeight + + voteResult, err := ops.executeVote(ctx, input, state) + if err != nil { + return uc.failGovStage(ctx, state, output, err) + } - case ports.ResumableStageSwitchingBinary: - // Step 5: Switch binary - uc.logger.Info("Step 5/5: Switching binary...") - output.ProposalID = state.ProposalID 
- output.UpgradeHeight = state.UpgradeHeight + if voteResult.VotesCast != voteResult.TotalVoters { + return uc.failGovStage(ctx, state, output, fmt.Errorf("not all votes cast: %d/%d", voteResult.VotesCast, voteResult.TotalVoters)) + } - switchResult, err := uc.executeSwitchBinary(ctx, input, state) - if err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) - } - output.Error = err - return output, err - } - output.NewBinary = switchResult.NewBinary + return uc.advanceGovStage(ctx, state, ports.ResumableStageWaitingForHeight, "voting complete, proposal passed") +} - // Transition to VerifyingResume - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageVerifyingResume, "binary switch complete"); err != nil { - return nil, err - } - fallthrough +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageWaitingForHeight( + ctx context.Context, + _ dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + uc.logger.Info("Step 3/5: Waiting for upgrade height %d...", state.UpgradeHeight) + output.ProposalID = state.ProposalID + output.UpgradeHeight = state.UpgradeHeight + + if err := ops.waitForUpgradeHeight(ctx, state.UpgradeHeight); err != nil { + return uc.failGovStage(ctx, state, output, err) + } - case ports.ResumableStageVerifyingResume: - // Verify chain resumed - output.ProposalID = state.ProposalID - output.UpgradeHeight = state.UpgradeHeight + return uc.advanceGovStage(ctx, state, ports.ResumableStageChainHalted, "upgrade height reached") +} - postHeight, err := uc.executeUC.verifyChainResumed(ctx, input.HomeDir) - if err != nil { - if saveErr := uc.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { - uc.logger.Warn("Failed to save failed state: %v", saveErr) - } - output.Error = err - return output, err - } - output.PostUpgradeHeight = postHeight +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageChainHalted( + ctx context.Context, + _ dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + uc.logger.Info("Step 4/5: Waiting for chain to halt...") + output.ProposalID = state.ProposalID + output.UpgradeHeight = state.UpgradeHeight + + if err := ops.waitForChainHalt(ctx, state.UpgradeHeight); err != nil { + return uc.failGovStage(ctx, state, output, err) + } - // Post-upgrade export (if enabled) - if input.WithExport { - uc.logger.Info("Post-upgrade: Exporting state after upgrade...") - exportInput := dto.ExportInput{ - HomeDir: input.HomeDir, - OutputDir: input.GenesisDir, - Force: false, - } + return uc.advanceGovStage(ctx, state, ports.ResumableStageSwitchingBinary, "chain halted at upgrade height") +} - postExportResultRaw, err := uc.exportUC.Execute(ctx, exportInput) - if err != nil { - uc.logger.Warn("Post-upgrade export failed: %v", err) - } else if postExportResult, ok := postExportResultRaw.(*dto.ExportOutput); ok { - output.PostGenesisPath = postExportResult.ExportPath - uc.logger.Success("Post-upgrade export complete: %s", postExportResult.ExportPath) - } - } +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageSwitchingBinary( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + uc.logger.Info("Step 5/5: 
Switching binary...") + output.ProposalID = state.ProposalID + output.UpgradeHeight = state.UpgradeHeight + + switchResult, err := ops.executeSwitchBinary(ctx, input, state) + if err != nil { + return uc.failGovStage(ctx, state, output, err) + } + output.NewBinary = switchResult.NewBinary - // Transition to Completed - if err := uc.transitionAndSave(ctx, state, ports.ResumableStageCompleted, "chain verified healthy"); err != nil { - return nil, err - } - fallthrough + return uc.advanceGovStage(ctx, state, ports.ResumableStageVerifyingResume, "binary switch complete") +} - case ports.ResumableStageCompleted: - // Update metadata version - if input.TargetVersion != "" { - if err := uc.executeUC.updateCurrentVersion(ctx, input.HomeDir, input.TargetVersion); err != nil { - uc.logger.Warn("Failed to update version in metadata: %v", err) - } - } +func (uc *ResumableExecuteUpgradeUseCase) handleGovStageVerifyingResume( + ctx context.Context, + input dto.ExecuteUpgradeInput, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, +) resumableGovStageResult { + ops := uc.getOps() + output.ProposalID = state.ProposalID + output.UpgradeHeight = state.UpgradeHeight - // Delete state file on success - if err := uc.stateManager.DeleteState(ctx); err != nil { - uc.logger.Warn("Failed to delete state file: %v", err) + postHeight, err := ops.verifyChainResumed(ctx, input.HomeDir) + if err != nil { + return uc.failGovStage(ctx, state, output, err) + } + output.PostUpgradeHeight = postHeight + + if input.WithExport { + uc.logger.Info("Post-upgrade: Exporting state after upgrade...") + exportInput := dto.ExportInput{ + HomeDir: input.HomeDir, + OutputDir: input.GenesisDir, + Force: false, } - output.Success = true - output.Duration = time.Since(startTime) - uc.logger.Success("Upgrade complete! 
Duration: %v", output.Duration) - return output, nil + postExportResultRaw, err := ops.executeExport(ctx, exportInput) + if err != nil { + uc.logger.Warn("Post-upgrade export failed: %v", err) + } else if postExportResult, ok := postExportResultRaw.(*dto.ExportOutput); ok { + output.PostGenesisPath = postExportResult.ExportPath + uc.logger.Success("Post-upgrade export complete: %s", postExportResult.ExportPath) + } + } - case ports.ResumableStageFailed: - return nil, fmt.Errorf("upgrade previously failed: %s (use --force-restart to start fresh)", state.Error) + return uc.advanceGovStage(ctx, state, ports.ResumableStageCompleted, "chain verified healthy") +} - case ports.ResumableStageProposalRejected: - return nil, fmt.Errorf("proposal was rejected (use --force-restart to start fresh)") +func (uc *ResumableExecuteUpgradeUseCase) failGovStage( + ctx context.Context, + state *ports.UpgradeState, + output *dto.ExecuteUpgradeOutput, + err error, +) resumableGovStageResult { + ops := uc.getOps() + if saveErr := ops.transitionAndSave(ctx, state, ports.ResumableStageFailed, err.Error()); saveErr != nil { + uc.logger.Warn("Failed to save failed state: %v", saveErr) + } + output.Error = err + return resumableGovStageResult{err: err, preserveOutputOnError: true} +} - default: - return nil, fmt.Errorf("cannot resume from stage: %s", state.Stage) +func (uc *ResumableExecuteUpgradeUseCase) advanceGovStage( + ctx context.Context, + state *ports.UpgradeState, + target ports.ResumableStage, + reason string, +) resumableGovStageResult { + ops := uc.getOps() + if err := ops.transitionAndSave(ctx, state, target, reason); err != nil { + return resumableGovStageResult{err: err} } + return resumableGovStageResult{} } // executeSwitchBinary handles binary switching with per-node tracking. 
diff --git a/internal/application/upgrade/resumable_execute_state_machine_test.go b/internal/application/upgrade/resumable_execute_state_machine_test.go new file mode 100644 index 00000000..3c95d813 --- /dev/null +++ b/internal/application/upgrade/resumable_execute_state_machine_test.go @@ -0,0 +1,255 @@ +package upgrade + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/altuslabsxyz/devnet-builder/internal/application/dto" + "github.com/altuslabsxyz/devnet-builder/internal/application/ports" +) + +type upgradeTestLogger struct{} + +func (upgradeTestLogger) Info(string, ...interface{}) {} +func (upgradeTestLogger) Warn(string, ...interface{}) {} +func (upgradeTestLogger) Error(string, ...interface{}) {} +func (upgradeTestLogger) Debug(string, ...interface{}) {} +func (upgradeTestLogger) Success(string, ...interface{}) {} +func (upgradeTestLogger) Print(string, ...interface{}) {} +func (upgradeTestLogger) Println(string, ...interface{}) {} +func (upgradeTestLogger) SetVerbose(bool) {} +func (upgradeTestLogger) IsVerbose() bool { return false } +func (upgradeTestLogger) Writer() io.Writer { return io.Discard } +func (upgradeTestLogger) ErrWriter() io.Writer { return io.Discard } + +func newResumableUCTestHarness(ops *resumableUpgradeOps) *ResumableExecuteUpgradeUseCase { + if ops.transitionAndSave == nil { + ops.transitionAndSave = func(_ context.Context, state *ports.UpgradeState, target ports.ResumableStage, _ string) error { + state.Stage = target + return nil + } + } + return &ResumableExecuteUpgradeUseCase{ + logger: upgradeTestLogger{}, + ops: ops, + } +} + +func TestGovResumableStageHandlersCoverage(t *testing.T) { + uc := newResumableUCTestHarness(&resumableUpgradeOps{}) + handlers := uc.govResumableStageHandlers() + + expected := []ports.ResumableStage{ + ports.ResumableStageInitialized, + ports.ResumableStageProposalSubmitted, + ports.ResumableStageVoting, + ports.ResumableStageWaitingForHeight, + ports.ResumableStageChainHalted, + ports.ResumableStageSwitchingBinary, + ports.ResumableStageVerifyingResume, + } + + if len(handlers) != len(expected) { + t.Fatalf("handler count mismatch: got %d want %d", len(handlers), len(expected)) + } + + for _, stage := range expected { + if _, ok := handlers[stage]; !ok { + t.Fatalf("missing handler for stage %s", stage) + } + } +} + +func TestHandleGovStageInitialized(t *testing.T) { + uc := newResumableUCTestHarness(&resumableUpgradeOps{ + executeProposal: func(context.Context, dto.ProposeInput) (*dto.ProposeOutput, error) { + return &dto.ProposeOutput{ProposalID: 42, UpgradeHeight: 777}, nil + }, + }) + state := ports.NewUpgradeState("upgrade", "local", false) + output := &dto.ExecuteUpgradeOutput{} + + outcome := uc.handleGovStageInitialized(context.Background(), dto.ExecuteUpgradeInput{ + HomeDir: "/tmp/home", + UpgradeName: "upgrade", + }, state, output) + + if outcome.err != nil { + t.Fatalf("unexpected error: %v", outcome.err) + } + if state.Stage != ports.ResumableStageProposalSubmitted { + t.Fatalf("stage mismatch: got %s", state.Stage) + } + if output.ProposalID != 42 || output.UpgradeHeight != 777 { + t.Fatalf("unexpected proposal output: %+v", output) + } +} + +func TestHandleGovStageVoting_PartialVotesMarksFailed(t *testing.T) { + transitions := make([]ports.ResumableStage, 0, 1) + uc := newResumableUCTestHarness(&resumableUpgradeOps{ + executeVote: func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.VoteOutput, error) { + return &dto.VoteOutput{VotesCast: 1, TotalVoters: 2}, nil + 
}, + transitionAndSave: func(_ context.Context, state *ports.UpgradeState, target ports.ResumableStage, _ string) error { + transitions = append(transitions, target) + state.Stage = target + return nil + }, + }) + state := ports.NewUpgradeState("upgrade", "local", false) + state.Stage = ports.ResumableStageVoting + state.ProposalID = 10 + state.UpgradeHeight = 1234 + output := &dto.ExecuteUpgradeOutput{} + + outcome := uc.handleGovStageVoting(context.Background(), dto.ExecuteUpgradeInput{}, state, output) + if outcome.err == nil { + t.Fatal("expected error for partial votes") + } + if !outcome.preserveOutputOnError { + t.Fatal("expected preserveOutputOnError=true on stage execution failure") + } + if output.Error == nil || !strings.Contains(output.Error.Error(), "not all votes cast") { + t.Fatalf("unexpected output error: %v", output.Error) + } + if len(transitions) != 1 || transitions[0] != ports.ResumableStageFailed { + t.Fatalf("unexpected transitions: %+v", transitions) + } +} + +func TestHandleGovStageSwitchingBinary(t *testing.T) { + uc := newResumableUCTestHarness(&resumableUpgradeOps{ + executeSwitchBinary: func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.SwitchBinaryOutput, error) { + return &dto.SwitchBinaryOutput{NewBinary: "/tmp/newd"}, nil + }, + }) + state := ports.NewUpgradeState("upgrade", "local", false) + state.Stage = ports.ResumableStageSwitchingBinary + state.ProposalID = 9 + state.UpgradeHeight = 999 + output := &dto.ExecuteUpgradeOutput{} + + outcome := uc.handleGovStageSwitchingBinary(context.Background(), dto.ExecuteUpgradeInput{}, state, output) + if outcome.err != nil { + t.Fatalf("unexpected error: %v", outcome.err) + } + if state.Stage != ports.ResumableStageVerifyingResume { + t.Fatalf("stage mismatch: got %s", state.Stage) + } + if output.NewBinary != "/tmp/newd" { + t.Fatalf("new binary mismatch: %s", output.NewBinary) + } +} + +func TestHandleGovStageVerifyingResume_WithExport(t *testing.T) { + exportCalls := 0 + uc := newResumableUCTestHarness(&resumableUpgradeOps{ + verifyChainResumed: func(context.Context, string) (int64, error) { + return 2048, nil + }, + executeExport: func(context.Context, dto.ExportInput) (interface{}, error) { + exportCalls++ + return &dto.ExportOutput{ExportPath: "/tmp/post-export.json"}, nil + }, + }) + state := ports.NewUpgradeState("upgrade", "local", false) + state.Stage = ports.ResumableStageVerifyingResume + state.ProposalID = 11 + state.UpgradeHeight = 2047 + output := &dto.ExecuteUpgradeOutput{} + + outcome := uc.handleGovStageVerifyingResume(context.Background(), dto.ExecuteUpgradeInput{ + HomeDir: "/tmp/home", + GenesisDir: "/tmp/genesis", + WithExport: true, + UpgradeName: "upgrade", + }, state, output) + if outcome.err != nil { + t.Fatalf("unexpected error: %v", outcome.err) + } + if state.Stage != ports.ResumableStageCompleted { + t.Fatalf("stage mismatch: got %s", state.Stage) + } + if output.PostUpgradeHeight != 2048 { + t.Fatalf("post height mismatch: got %d", output.PostUpgradeHeight) + } + if output.PostGenesisPath != "/tmp/post-export.json" { + t.Fatalf("post export path mismatch: %s", output.PostGenesisPath) + } + if exportCalls != 1 { + t.Fatalf("unexpected export calls: %d", exportCalls) + } +} + +func TestExecuteWithGovResumable_StateMachineLoop(t *testing.T) { + var transitions []ports.ResumableStage + exportCalls := 0 + updateCalls := 0 + deleteCalls := 0 + + uc := newResumableUCTestHarness(&resumableUpgradeOps{ + executeProposal: func(context.Context, dto.ProposeInput) 
(*dto.ProposeOutput, error) { + return &dto.ProposeOutput{ProposalID: 7, UpgradeHeight: 150}, nil + }, + executeVote: func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.VoteOutput, error) { + return &dto.VoteOutput{VotesCast: 4, TotalVoters: 4}, nil + }, + waitForUpgradeHeight: func(context.Context, int64) error { return nil }, + waitForChainHalt: func(context.Context, int64) error { return nil }, + executeSwitchBinary: func(context.Context, dto.ExecuteUpgradeInput, *ports.UpgradeState) (*dto.SwitchBinaryOutput, error) { + return &dto.SwitchBinaryOutput{NewBinary: "/tmp/new-binary"}, nil + }, + verifyChainResumed: func(context.Context, string) (int64, error) { return 151, nil }, + executeExport: func(context.Context, dto.ExportInput) (interface{}, error) { + exportCalls++ + if exportCalls == 1 { + return &dto.ExportOutput{ExportPath: "/tmp/pre.json"}, nil + } + return &dto.ExportOutput{ExportPath: "/tmp/post.json"}, nil + }, + updateCurrentVersion: func(context.Context, string, string) error { + updateCalls++ + return nil + }, + deleteState: func(context.Context) error { + deleteCalls++ + return nil + }, + transitionAndSave: func(_ context.Context, state *ports.UpgradeState, target ports.ResumableStage, _ string) error { + transitions = append(transitions, target) + state.Stage = target + return nil + }, + }) + + state := ports.NewUpgradeState("upgrade", "docker", false) + output, err := uc.executeWithGovResumable(context.Background(), dto.ExecuteUpgradeInput{ + HomeDir: "/tmp/home", + UpgradeName: "upgrade", + TargetVersion: "v2.0.0", + GenesisDir: "/tmp/genesis", + WithExport: true, + }, state, time.Now()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !output.Success { + t.Fatal("expected success output") + } + if output.PreGenesisPath != "/tmp/pre.json" || output.PostGenesisPath != "/tmp/post.json" { + t.Fatalf("unexpected export paths: pre=%s post=%s", output.PreGenesisPath, output.PostGenesisPath) + } + if output.NewBinary != "/tmp/new-binary" || output.PostUpgradeHeight != 151 { + t.Fatalf("unexpected output values: %+v", output) + } + if updateCalls != 1 || deleteCalls != 1 { + t.Fatalf("unexpected finalize calls update=%d delete=%d", updateCalls, deleteCalls) + } + if len(transitions) != 7 { + t.Fatalf("unexpected transitions length: %d (%+v)", len(transitions), transitions) + } +} diff --git a/internal/daemon/server/server.go b/internal/daemon/server/server.go index c6f89aac..941e6e72 100644 --- a/internal/daemon/server/server.go +++ b/internal/daemon/server/server.go @@ -26,6 +26,7 @@ import ( "github.com/altuslabsxyz/devnet-builder/internal/daemon/subnet" "github.com/altuslabsxyz/devnet-builder/internal/daemon/types" "github.com/altuslabsxyz/devnet-builder/internal/daemon/upgrader" + "go.uber.org/multierr" "google.golang.org/grpc" ) @@ -107,16 +108,58 @@ type Server struct { shutdownCancel context.CancelFunc } +type ServerBuilder struct { + config *Config + server *Server + + orchFactory *OrchestratorFactory + devnetProv *provisioner.DevnetProvisioner + + cleanupStack []func() error + err error +} + +// NewServerBuilder creates a new server builder. +func NewServerBuilder(config *Config) *ServerBuilder { + return &ServerBuilder{ + config: config, + server: &Server{config: config}, + } +} + // New creates a new server. 
func New(config *Config) (*Server, error) { - // Ensure data directory exists first (needed for log file) - if err := os.MkdirAll(config.DataDir, 0755); err != nil { - return nil, fmt.Errorf("failed to create data directory: %w", err) + return NewServerBuilder(config). + WithDataDir(). + WithLogger(). + WithPlugins(). + WithStoreAndSubnet(). + WithControllers(). + WithRuntime(). + WithGRPCServices(). + Build() +} + +// WithDataDir ensures the data directory exists. +func (b *ServerBuilder) WithDataDir() *ServerBuilder { + if b.err != nil { + return b + } + + if err := os.MkdirAll(b.config.DataDir, 0755); err != nil { + b.err = fmt.Errorf("failed to create data directory: %w", err) + } + return b +} + +// WithLogger initializes persistent logging. +func (b *ServerBuilder) WithLogger() *ServerBuilder { + if b.err != nil { + return b } - // Set up logger - write to both stdout and log file for debugging level := slog.LevelInfo - switch config.LogLevel { + switch b.config.LogLevel { case "debug": level = slog.LevelDebug case "warn": @@ -125,86 +168,126 @@ func New(config *Config) (*Server, error) { level = slog.LevelError } - // Create log file for persistent logging (used by 'dvb daemon logs') - logFilePath := filepath.Join(config.DataDir, "daemon.log") + logFilePath := filepath.Join(b.config.DataDir, "daemon.log") logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - return nil, fmt.Errorf("failed to open daemon log file: %w", err) + b.err = fmt.Errorf("failed to open daemon log file: %w", err) + return b } - // Write logs to both stdout and file multiWriter := io.MultiWriter(os.Stdout, logFile) - logger := slog.New(slog.NewTextHandler(multiWriter, &slog.HandlerOptions{Level: level})) + b.server.logger = slog.New(slog.NewTextHandler(multiWriter, &slog.HandlerOptions{Level: level})) + b.server.logFile = logFile + b.pushCleanup(func() error { return logFile.Close() }) + return b +} + +// WithPlugins loads and registers network plugins. +func (b *ServerBuilder) WithPlugins() *ServerBuilder { + if b.err != nil { + return b + } - // Load network plugins from plugin directories - // Plugins are discovered from ~/.devnet-builder/plugins/ and registered - // with the global network registry so they can be queried via NetworkService pluginMgr := NewPluginManager(PluginManagerConfig{ - PluginDirs: []string{filepath.Join(config.DataDir, "plugins")}, - Logger: logger, + PluginDirs: []string{filepath.Join(b.config.DataDir, "plugins")}, + Logger: b.server.logger, }) result, err := pluginMgr.LoadAndRegister() if err != nil { - return nil, fmt.Errorf("failed to load plugins: %w", err) + b.err = fmt.Errorf("failed to load plugins: %w", err) + return b } if len(result.Loaded) > 0 { - logger.Info("network plugins loaded", + b.server.logger.Info("network plugins loaded", "count", len(result.Loaded), "plugins", result.Loaded) } if len(result.Errors) > 0 { for _, e := range result.Errors { - logger.Warn("plugin load error", + b.server.logger.Warn("plugin load error", "plugin", e.Name, "error", e.Error) } } - // Open state store - dbPath := filepath.Join(config.DataDir, "devnetd.db") + b.server.pluginManager = pluginMgr + b.pushCleanup(func() error { + pluginMgr.Close() + return nil + }) + return b +} + +// WithStoreAndSubnet initializes state storage and subnet allocation. 
+func (b *ServerBuilder) WithStoreAndSubnet() *ServerBuilder { + if b.err != nil { + return b + } + + dbPath := filepath.Join(b.config.DataDir, "devnetd.db") st, err := store.NewBoltStore(dbPath) if err != nil { - pluginMgr.Close() - return nil, fmt.Errorf("failed to open state store: %w", err) + b.err = fmt.Errorf("failed to open state store: %w", err) + return b } - // Initialize subnet allocator for loopback network aliasing - subnetAllocatorPath := filepath.Join(config.DataDir, "subnets.json") + b.server.store = st + b.pushCleanup(func() error { return st.Close() }) + + subnetAllocatorPath := filepath.Join(b.config.DataDir, "subnets.json") subnetAlloc, err := subnet.LoadOrCreate(subnetAllocatorPath) if err != nil { - st.Close() - pluginMgr.Close() - return nil, fmt.Errorf("failed to initialize subnet allocator: %w", err) + b.err = fmt.Errorf("failed to initialize subnet allocator: %w", err) + return b + } + b.server.subnetAllocator = subnetAlloc + b.server.logger.Info("subnet allocator initialized", "path", subnetAllocatorPath) + return b +} + +// WithControllers registers controller stack. +func (b *ServerBuilder) WithControllers() *ServerBuilder { + if b.err != nil { + return b } - logger.Info("subnet allocator initialized", "path", subnetAllocatorPath) - // Create controller manager + b.setupControllerManager() + b.setupOrchestratorAndProvisioner() + b.registerDevnetController() + b.registerHealthController() + b.registerUpgradeController() + b.registerTransactionController() + return b +} + +func (b *ServerBuilder) setupControllerManager() { mgr := controller.NewManager() - mgr.SetLogger(logger) - - // Create orchestrator factory for full provisioning flow (build, fork, init) - orchFactory := NewOrchestratorFactory(config.DataDir, logger) - - // Create devnet provisioner with orchestrator factory and subnet allocator - // The factory enables full provisioning (build, fork, init) before creating Node resources - // The subnet allocator assigns unique loopback subnets to each devnet - devnetProv := provisioner.NewDevnetProvisioner(st, provisioner.Config{ - DataDir: config.DataDir, - Logger: logger, - OrchestratorFactory: orchFactory, - SubnetAllocator: subnetAlloc, + mgr.SetLogger(b.server.logger) + b.server.manager = mgr +} + +func (b *ServerBuilder) setupOrchestratorAndProvisioner() { + b.orchFactory = NewOrchestratorFactory(b.config.DataDir, b.server.logger) + b.devnetProv = provisioner.NewDevnetProvisioner(b.server.store, provisioner.Config{ + DataDir: b.config.DataDir, + Logger: b.server.logger, + OrchestratorFactory: b.orchFactory, + SubnetAllocator: b.server.subnetAllocator, }) +} - // Register controllers - devnetCtrl := controller.NewDevnetController(st, devnetProv) - devnetCtrl.SetLogger(logger) - devnetCtrl.SetManager(mgr) - mgr.Register("devnets", devnetCtrl) +func (b *ServerBuilder) registerDevnetController() { + devnetCtrl := controller.NewDevnetController(b.server.store, b.devnetProv) + devnetCtrl.SetLogger(b.server.logger) + devnetCtrl.SetManager(b.server.manager) + b.server.manager.Register("devnets", devnetCtrl) + b.attachProvisionProgressReporter(devnetCtrl) +} - // Wire step progress reporter to broadcast provision logs to CLI clients - devnetProv.SetStepProgressReporterFactory(func(namespace, name string) ports.ProgressReporter { +func (b *ServerBuilder) attachProvisionProgressReporter(devnetCtrl *controller.DevnetController) { + b.devnetProv.SetStepProgressReporterFactory(func(namespace, name string) ports.ProgressReporter { return ports.ProgressFunc(func(step 
ports.StepProgress) { devnetCtrl.BroadcastProvisionLog(namespace, name, &controller.ProvisionLogEntry{ Timestamp: time.Now(), @@ -221,14 +304,44 @@ func New(config *Config) (*Server, error) { }) }) }) +} + +func (b *ServerBuilder) registerHealthController() { + healthChecker := checker.NewRPCHealthChecker(checker.Config{ + Logger: b.server.logger, + Timeout: b.config.HealthCheckTimeout, + }) + healthConfig := controller.DefaultHealthControllerConfig() + healthCtrl := controller.NewHealthController(b.server.store, healthChecker, b.server.manager, healthConfig) + healthCtrl.SetLogger(b.server.logger) + b.server.manager.Register("health", healthCtrl) + b.server.healthCtrl = healthCtrl +} - // Select node runtime based on RuntimeMode. - // Backward compat: --docker flag overrides RuntimeMode. - runtimeMode := config.RuntimeMode +func (b *ServerBuilder) registerUpgradeController() { + upgradeRuntime := upgrader.NewRuntime(b.server.store, upgrader.Config{Logger: b.server.logger}) + upgradeCtrl := controller.NewUpgradeController(b.server.store, upgradeRuntime) + upgradeCtrl.SetLogger(b.server.logger) + b.server.manager.Register("upgrades", upgradeCtrl) +} + +func (b *ServerBuilder) registerTransactionController() { + txCtrl := controller.NewTxController(b.server.store, nil) + txCtrl.SetLogger(b.server.logger) + b.server.manager.Register("transactions", txCtrl) +} + +// WithRuntime initializes node runtime and registers node controller. +func (b *ServerBuilder) WithRuntime() *ServerBuilder { + if b.err != nil { + return b + } + + runtimeMode := b.config.RuntimeMode if runtimeMode == "" { runtimeMode = "process" } - if config.EnableDocker { + if b.config.EnableDocker { runtimeMode = "docker" } @@ -236,142 +349,144 @@ func New(config *Config) (*Server, error) { switch runtimeMode { case "docker": dockerRuntime, err := runtime.NewDockerRuntime(runtime.DockerConfig{ - DefaultImage: config.DockerImage, - Logger: logger, + DefaultImage: b.config.DockerImage, + Logger: b.server.logger, }) if err != nil { - return nil, fmt.Errorf("failed to create docker runtime: %w", err) + b.err = fmt.Errorf("failed to create docker runtime: %w", err) + return b } nodeRuntime = dockerRuntime - logger.Info("docker runtime enabled", "image", config.DockerImage) + b.server.logger.Info("docker runtime enabled", "image", b.config.DockerImage) case "service": svcRuntime, err := runtime.NewServiceRuntime(runtime.ServiceRuntimeConfig{ - DataDir: config.DataDir, - Logger: logger, - PluginRuntimeProvider: orchFactory.AsPluginRuntimeProvider(), + DataDir: b.config.DataDir, + Logger: b.server.logger, + PluginRuntimeProvider: b.orchFactory.AsPluginRuntimeProvider(), }) if err != nil { - return nil, fmt.Errorf("failed to create service runtime: %w", err) + b.err = fmt.Errorf("failed to create service runtime: %w", err) + return b } nodeRuntime = svcRuntime - logger.Info("service runtime enabled (OS service manager)") - default: // "process" + b.server.logger.Info("service runtime enabled (OS service manager)") + default: nodeRuntime = runtime.NewProcessRuntime(runtime.ProcessRuntimeConfig{ - DataDir: config.DataDir, - Logger: logger, - PluginRuntimeProvider: orchFactory.AsPluginRuntimeProvider(), + DataDir: b.config.DataDir, + Logger: b.server.logger, + PluginRuntimeProvider: b.orchFactory.AsPluginRuntimeProvider(), }) - logger.Info("process runtime enabled for local mode") + b.server.logger.Info("process runtime enabled for local mode") } - nodeCtrl := controller.NewNodeController(st, nodeRuntime) - nodeCtrl.SetLogger(logger) - 
mgr.Register("nodes", nodeCtrl) + b.server.nodeRuntime = nodeRuntime + nodeCtrl := controller.NewNodeController(b.server.store, nodeRuntime) + nodeCtrl.SetLogger(b.server.logger) + b.server.manager.Register("nodes", nodeCtrl) + return b +} - // Create health checker - healthChecker := checker.NewRPCHealthChecker(checker.Config{ - Logger: logger, - Timeout: config.HealthCheckTimeout, - }) +// WithGRPCServices initializes gRPC server and registers services. +func (b *ServerBuilder) WithGRPCServices() *ServerBuilder { + if b.err != nil { + return b + } - // Create and register health controller - healthConfig := controller.DefaultHealthControllerConfig() - healthCtrl := controller.NewHealthController(st, healthChecker, mgr, healthConfig) - healthCtrl.SetLogger(logger) - mgr.Register("health", healthCtrl) + grpcServer := b.buildGRPCServer() + b.server.grpcServer = grpcServer - // Create upgrade runtime - upgradeRuntime := upgrader.NewRuntime(st, upgrader.Config{ - Logger: logger, - }) + networkSvc := b.newNetworkService() + anteHandler := ante.New(b.server.store, networkSvc) + shutdownCtx := b.initShutdownContext() + b.registerGRPCServices(grpcServer, networkSvc, anteHandler, shutdownCtx) + return b +} - // Create and register upgrade controller - upgradeCtrl := controller.NewUpgradeController(st, upgradeRuntime) - upgradeCtrl.SetLogger(logger) - mgr.Register("upgrades", upgradeCtrl) - - // Create and register transaction controller - // TxRuntime is nil for now - will be connected when network plugins are loaded - txCtrl := controller.NewTxController(st, nil) - txCtrl.SetLogger(logger) - mgr.Register("transactions", txCtrl) - - // Create gRPC server with optional auth interceptors for remote mode - var grpcServer *grpc.Server - if config.Listen != "" && config.AuthEnabled { - // Load API key store for authentication. - // NOTE: Keys are loaded once at startup. After creating or revoking keys - // with `devnetd keys create/revoke`, the server must be restarted for - // changes to take effect. Consider implementing hot-reload in the future. 
- keysFile := config.AuthKeysFile - if keysFile == "" { - keysFile = filepath.Join(config.DataDir, "api-keys.yaml") - } - keyStore := auth.NewFileKeyStore(keysFile) - if err := keyStore.Load(); err != nil { - logger.Warn("failed to load API keys, starting with empty key store", "error", err) - } +func (b *ServerBuilder) buildGRPCServer() *grpc.Server { + if b.config.Listen == "" || !b.config.AuthEnabled { + return grpc.NewServer() + } - // Create gRPC server with auth interceptors - grpcServer = grpc.NewServer( - grpc.ChainUnaryInterceptor(auth.NewAuthInterceptor(keyStore, IsLocalConnection)), - grpc.ChainStreamInterceptor(auth.NewStreamAuthInterceptor(keyStore, IsLocalConnection)), - ) - logger.Info("authentication enabled for remote connections") - } else { - grpcServer = grpc.NewServer() + keysFile := b.config.AuthKeysFile + if keysFile == "" { + keysFile = filepath.Join(b.config.DataDir, "api-keys.yaml") + } + keyStore := auth.NewFileKeyStore(keysFile) + if err := keyStore.Load(); err != nil { + b.server.logger.Warn("failed to load API keys, starting with empty key store", "error", err) } - // Create network service first (needed by ante handler) - githubFactory := NewDefaultGitHubClientFactory(config.DataDir, logger) - networkSvc := NewNetworkService(githubFactory) - networkSvc.SetLogger(logger) + b.server.logger.Info("authentication enabled for remote connections") + return grpc.NewServer( + grpc.ChainUnaryInterceptor(auth.NewAuthInterceptor(keyStore, IsLocalConnection)), + grpc.ChainStreamInterceptor(auth.NewStreamAuthInterceptor(keyStore, IsLocalConnection)), + ) +} - // Create ante handler for request validation - anteHandler := ante.New(st, networkSvc) +func (b *ServerBuilder) newNetworkService() *NetworkService { + githubFactory := NewDefaultGitHubClientFactory(b.config.DataDir, b.server.logger) + networkSvc := NewNetworkService(githubFactory) + networkSvc.SetLogger(b.server.logger) + return networkSvc +} - // Create shutdown context for terminating long-running streaming RPCs during shutdown. - // This context is cancelled before GracefulStop() to unblock streams that would - // otherwise prevent graceful shutdown (e.g., log streaming blocked waiting for input). 
+func (b *ServerBuilder) initShutdownContext() context.Context { shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + b.server.shutdownCtx = shutdownCtx + b.server.shutdownCancel = shutdownCancel + b.pushCleanup(func() error { + shutdownCancel() + return nil + }) + return shutdownCtx +} - // Register services - devnetSvc := NewDevnetServiceWithAnte(st, mgr, anteHandler, subnetAlloc, devnetProv) - devnetSvc.SetLogger(logger) +func (b *ServerBuilder) registerGRPCServices( + grpcServer *grpc.Server, + networkSvc *NetworkService, + anteHandler *ante.AnteHandler, + shutdownCtx context.Context, +) { + devnetSvc := NewDevnetServiceWithAnte(b.server.store, b.server.manager, anteHandler, b.server.subnetAllocator, b.devnetProv) + devnetSvc.SetLogger(b.server.logger) v1.RegisterDevnetServiceServer(grpcServer, devnetSvc) - nodeSvc := NewNodeServiceWithAnte(st, mgr, nodeRuntime, anteHandler, shutdownCtx) - nodeSvc.SetLogger(logger) + nodeSvc := NewNodeServiceWithAnte(b.server.store, b.server.manager, b.server.nodeRuntime, anteHandler, shutdownCtx) + nodeSvc.SetLogger(b.server.logger) v1.RegisterNodeServiceServer(grpcServer, nodeSvc) - upgradeSvc := NewUpgradeServiceWithAnte(st, mgr, anteHandler) - upgradeSvc.SetLogger(logger) + upgradeSvc := NewUpgradeServiceWithAnte(b.server.store, b.server.manager, anteHandler) + upgradeSvc.SetLogger(b.server.logger) v1.RegisterUpgradeServiceServer(grpcServer, upgradeSvc) - txSvc := NewTransactionService(st, mgr) - txSvc.SetLogger(logger) + txSvc := NewTransactionService(b.server.store, b.server.manager) + txSvc.SetLogger(b.server.logger) v1.RegisterTransactionServiceServer(grpcServer, txSvc) v1.RegisterNetworkServiceServer(grpcServer, networkSvc) + v1.RegisterAuthServiceServer(grpcServer, NewAuthService()) +} - // Register auth service for ping/whoami - authSvc := NewAuthService() - v1.RegisterAuthServiceServer(grpcServer, authSvc) - - return &Server{ - config: config, - store: st, - manager: mgr, - healthCtrl: healthCtrl, - pluginManager: pluginMgr, - subnetAllocator: subnetAlloc, - nodeRuntime: nodeRuntime, - grpcServer: grpcServer, - logger: logger, - logFile: logFile, - shutdownCtx: shutdownCtx, - shutdownCancel: shutdownCancel, - }, nil +// Build finalizes server construction. +func (b *ServerBuilder) Build() (*Server, error) { + if b.err != nil { + return nil, b.cleanupOnError(b.err) + } + b.cleanupStack = nil + return b.server, nil +} + +func (b *ServerBuilder) pushCleanup(fn func() error) { + b.cleanupStack = append(b.cleanupStack, fn) +} + +func (b *ServerBuilder) cleanupOnError(buildErr error) error { + var cleanupErr error + for i := len(b.cleanupStack) - 1; i >= 0; i-- { + cleanupErr = multierr.Append(cleanupErr, b.cleanupStack[i]()) + } + b.cleanupStack = nil + return multierr.Append(buildErr, cleanupErr) } // Run starts the server and blocks until shutdown. 
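As a standalone illustration of the pattern introduced in server.go above (this sketch is not part of the diff): ServerBuilder is an error-carrying builder — each With* step becomes a no-op once b.err is set, successful steps push a cleanup function, and Build either hands the opened resources to the Server or unwinds the cleanup stack in LIFO order, folding every failure together with go.uber.org/multierr. A minimal version of that shape, using hypothetical names (resourceBuilder, step, build) that do not exist in the project, might look like this:

package main

import (
	"fmt"
	"os"

	"go.uber.org/multierr"
)

// resourceBuilder mirrors the shape of ServerBuilder: it carries the first
// error and a stack of cleanup functions for resources opened so far.
type resourceBuilder struct {
	cleanup []func() error
	err     error
}

// step runs open unless a previous step already failed; on success the
// returned close function is pushed onto the cleanup stack.
func (b *resourceBuilder) step(open func() (func() error, error)) *resourceBuilder {
	if b.err != nil {
		return b
	}
	closeFn, err := open()
	if err != nil {
		b.err = err
		return b
	}
	b.cleanup = append(b.cleanup, closeFn)
	return b
}

// build hands ownership to the caller on success, or unwinds the cleanup
// stack in LIFO order and aggregates all errors on failure.
func (b *resourceBuilder) build() error {
	if b.err == nil {
		b.cleanup = nil
		return nil
	}
	err := b.err
	for i := len(b.cleanup) - 1; i >= 0; i-- {
		err = multierr.Append(err, b.cleanup[i]())
	}
	b.cleanup = nil
	return err
}

func main() {
	b := &resourceBuilder{}
	err := b.
		step(func() (func() error, error) {
			f, ferr := os.CreateTemp("", "demo")
			if ferr != nil {
				return nil, ferr
			}
			return f.Close, nil // cleanup for the first resource
		}).
		step(func() (func() error, error) {
			return nil, fmt.Errorf("second resource failed") // triggers unwinding
		}).
		build()
	fmt.Println(err)
}

The point of clearing the stack on success is that resources opened during construction are closed exactly once: either by the unwinding above when a later step fails, or later by the owning object's own shutdown path.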
diff --git a/internal/daemon/server/server_builder_test.go b/internal/daemon/server/server_builder_test.go new file mode 100644 index 00000000..3684ed83 --- /dev/null +++ b/internal/daemon/server/server_builder_test.go @@ -0,0 +1,78 @@ +package server + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestServerBuilderCleanupOnError_LIFOAndMultierr(t *testing.T) { + builder := NewServerBuilder(&Config{DataDir: t.TempDir()}) + order := make([]string, 0, 2) + + builder.pushCleanup(func() error { + order = append(order, "first") + return errors.New("cleanup-first") + }) + builder.pushCleanup(func() error { + order = append(order, "second") + return errors.New("cleanup-second") + }) + builder.err = errors.New("build-failed") + + _, err := builder.Build() + if err == nil { + t.Fatal("expected build error") + } + + if len(order) != 2 || order[0] != "second" || order[1] != "first" { + t.Fatalf("cleanup order mismatch: %+v", order) + } + + msg := err.Error() + if !strings.Contains(msg, "build-failed") { + t.Fatalf("expected build error in aggregated message: %v", err) + } + if !strings.Contains(msg, "cleanup-first") || !strings.Contains(msg, "cleanup-second") { + t.Fatalf("expected cleanup errors in aggregated message: %v", err) + } +} + +func TestServerBuilderBuildSuccess_DoesNotRunCleanup(t *testing.T) { + builder := NewServerBuilder(&Config{DataDir: t.TempDir()}) + cleanupCalled := false + builder.pushCleanup(func() error { + cleanupCalled = true + return nil + }) + + server, err := builder.Build() + if err != nil { + t.Fatalf("Build() failed: %v", err) + } + if server == nil { + t.Fatal("expected server instance") + } + if cleanupCalled { + t.Fatal("cleanup must not run on successful build") + } + if len(builder.cleanupStack) != 0 { + t.Fatalf("cleanup stack should be cleared, got %d", len(builder.cleanupStack)) + } +} + +func TestServerBuilderWithDataDir(t *testing.T) { + baseDir := t.TempDir() + dataDir := filepath.Join(baseDir, "daemon-data") + + builder := NewServerBuilder(&Config{DataDir: dataDir}).WithDataDir() + if builder.err != nil { + t.Fatalf("WithDataDir failed: %v", builder.err) + } + + if _, err := os.Stat(dataDir); err != nil { + t.Fatalf("expected data dir to exist: %v", err) + } +}
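The assertions in server_builder_test.go lean on how go.uber.org/multierr (imported by server.go above) combines errors: Append folds the build error and every cleanup error into a single value whose Error() string contains each message, and the individual errors remain recoverable. A small, self-contained illustration of that behaviour, not taken from the repository:

package main

import (
	"errors"
	"fmt"

	"go.uber.org/multierr"
)

func main() {
	buildErr := errors.New("build-failed")
	cleanupErr := errors.New("cleanup-first")

	// Append returns a combined error whose message includes both parts,
	// which is what the strings.Contains checks in the test rely on.
	combined := multierr.Append(buildErr, cleanupErr)
	fmt.Println(combined.Error()) // e.g. "build-failed; cleanup-first"

	// The wrapped errors can still be retrieved individually.
	for _, e := range multierr.Errors(combined) {
		fmt.Println(" -", e)
	}
}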