Skip to content

Commit da0dfb4

Browse files
yokowuclaude
andcommitted
feat: control 流集成 VMIdleRefresher,连接时自动 Refresh/Resume
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6ecb67c commit da0dfb4

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

backend/biz/task/handler/v1/task.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/samber/do"
1919

2020
"github.com/chaitin/MonkeyCode/backend/biz/task/service"
21+
vmidle "github.com/chaitin/MonkeyCode/backend/biz/vmidle/usecase"
2122
"github.com/chaitin/MonkeyCode/backend/config"
2223
"github.com/chaitin/MonkeyCode/backend/consts"
2324
"github.com/chaitin/MonkeyCode/backend/domain"
@@ -41,7 +42,8 @@ type TaskHandler struct {
4142
nls *nls.NLS
4243
taskConns *ws.TaskConn
4344
controlConns *ws.ControlConn
44-
taskSummary *service.TaskSummaryService
45+
taskSummary *service.TaskSummaryService
46+
idleRefresher vmidle.VMIdleRefresher
4547
}
4648

4749
// NewTaskHandler 创建任务处理器
@@ -57,6 +59,7 @@ func NewTaskHandler(i *do.Injector) (*TaskHandler, error) {
5759
tc := do.MustInvoke[*ws.TaskConn](i)
5860
cc := do.MustInvoke[*ws.ControlConn](i)
5961
ts := do.MustInvoke[*service.TaskSummaryService](i)
62+
ir := do.MustInvoke[vmidle.VMIdleRefresher](i)
6063

6164
// Optional deps
6265
var pubhost domain.PublicHostUsecase
@@ -81,6 +84,7 @@ func NewTaskHandler(i *do.Injector) (*TaskHandler, error) {
8184
taskConns: tc,
8285
controlConns: cc,
8386
taskSummary: ts,
87+
idleRefresher: ir,
8488
}
8589

8690
// 注册路由

backend/biz/task/handler/v1/task_control.go

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ import (
109109
func (h *TaskHandler) Control(c *web.Context, req domain.TaskControlReq) error {
110110
user := middleware.GetUser(c)
111111

112-
// 验证 task 归属(必须在 ws.Accept 之前完成)
113112
task, _, err := h.usecase.Info(c.Request().Context(), user, req.ID)
114113
if err != nil {
115114
return err
@@ -123,24 +122,63 @@ func (h *TaskHandler) Control(c *web.Context, req domain.TaskControlReq) error {
123122
defer wsConn.Close()
124123

125124
logger := h.logger.With("task_id", task.ID, "fn", "task.control")
125+
taskID := task.ID.String()
126+
127+
// 连接建立:刷新空闲计时器
128+
if vm := task.VirtualMachine; vm != nil {
129+
if err := h.idleRefresher.Refresh(c.Request().Context(), vm.ID); err != nil {
130+
logger.WarnContext(c.Request().Context(), "failed to refresh idle timers on connect", "error", err)
131+
}
126132

127-
h.controlConns.Add(task.ID.String(), wsConn)
128-
defer h.controlConns.Remove(task.ID.String(), wsConn)
133+
// VM 处于休眠状态时自动恢复
134+
if vm.Status == taskflow.VirtualMachineStatusHibernated {
135+
go func() {
136+
if err := h.taskflow.VirtualMachiner().Resume(c.Request().Context(), &taskflow.ResumeVirtualMachineReq{
137+
HostID: vm.Host.InternalID,
138+
UserID: task.UserID.String(),
139+
ID: vm.ID,
140+
EnvironmentID: vm.EnvironmentID,
141+
}); err != nil {
142+
logger.WarnContext(context.Background(), "failed to resume vm on control connect", "error", err)
143+
}
144+
}()
145+
}
146+
}
147+
148+
h.controlConns.Add(taskID, wsConn)
149+
defer func() {
150+
h.controlConns.Remove(taskID, wsConn)
151+
// 最后一个连接断开:刷新计时器(开始空闲倒计时)
152+
if vm := task.VirtualMachine; vm != nil && !h.controlConns.Has(taskID) {
153+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
154+
defer cancel()
155+
if err := h.idleRefresher.Refresh(ctx, vm.ID); err != nil {
156+
logger.WarnContext(ctx, "failed to refresh idle timers on disconnect", "error", err)
157+
}
158+
}
159+
}()
129160

130161
g, ctx := errgroup.WithContext(c.Request().Context())
131162

132163
g.Go(func() error {
133-
return h.controlPing(ctx, wsConn, task.ID.String())
164+
return h.controlPing(ctx, wsConn, taskID)
134165
})
135166

136167
g.Go(func() error {
137168
return h.controlReadMessages(ctx, wsConn, logger, task)
138169
})
139170

140171
g.Go(func() error {
141-
return h.controlSubscribeTaskEvents(ctx, wsConn, logger, task.ID.String())
172+
return h.controlSubscribeTaskEvents(ctx, wsConn, logger, taskID)
142173
})
143174

175+
// 定期刷新空闲计时器,保持 VM 活跃
176+
if vm := task.VirtualMachine; vm != nil {
177+
g.Go(func() error {
178+
return h.controlKeepAlive(ctx, vm.ID)
179+
})
180+
}
181+
144182
if err := g.Wait(); err != nil {
145183
logger.DebugContext(c.Request().Context(), "control websocket closed", "reason", err)
146184
}
@@ -165,6 +203,22 @@ func (h *TaskHandler) controlPing(ctx context.Context, wsConn *ws.WebsocketManag
165203
}
166204
}
167205

206+
// controlKeepAlive 定期刷新空闲计时器,防止 VM 被误判空闲
207+
func (h *TaskHandler) controlKeepAlive(ctx context.Context, vmID string) error {
208+
ticker := time.NewTicker(5 * time.Minute)
209+
defer ticker.Stop()
210+
for {
211+
select {
212+
case <-ctx.Done():
213+
return ctx.Err()
214+
case <-ticker.C:
215+
if err := h.idleRefresher.Refresh(ctx, vmID); err != nil {
216+
h.logger.WarnContext(ctx, "keepalive refresh failed", "vmID", vmID, "error", err)
217+
}
218+
}
219+
}
220+
}
221+
168222
// controlReadMessages 读取客户端消息并分发处理
169223
func (h *TaskHandler) controlReadMessages(ctx context.Context, wsConn *ws.WebsocketManager, logger *slog.Logger, task *domain.Task) error {
170224
for {

0 commit comments

Comments
 (0)