@@ -109,7 +109,6 @@ import (
109109func (h * TaskHandler ) Control (c * web.Context , req domain.TaskControlReq ) error {
110110 user := middleware .GetUser (c )
111111
112- // 验证 task 归属(必须在 ws.Accept 之前完成)
113112 task , _ , err := h .usecase .Info (c .Request ().Context (), user , req .ID )
114113 if err != nil {
115114 return err
@@ -123,24 +122,63 @@ func (h *TaskHandler) Control(c *web.Context, req domain.TaskControlReq) error {
123122 defer wsConn .Close ()
124123
125124 logger := h .logger .With ("task_id" , task .ID , "fn" , "task.control" )
125+ taskID := task .ID .String ()
126+
127+ // 连接建立:刷新空闲计时器
128+ if vm := task .VirtualMachine ; vm != nil {
129+ if err := h .idleRefresher .Refresh (c .Request ().Context (), vm .ID ); err != nil {
130+ logger .WarnContext (c .Request ().Context (), "failed to refresh idle timers on connect" , "error" , err )
131+ }
126132
127- h .controlConns .Add (task .ID .String (), wsConn )
128- defer h .controlConns .Remove (task .ID .String (), wsConn )
133+ // VM 处于休眠状态时自动恢复
134+ if vm .Status == taskflow .VirtualMachineStatusHibernated {
135+ go func () {
136+ if err := h .taskflow .VirtualMachiner ().Resume (c .Request ().Context (), & taskflow.ResumeVirtualMachineReq {
137+ HostID : vm .Host .InternalID ,
138+ UserID : task .UserID .String (),
139+ ID : vm .ID ,
140+ EnvironmentID : vm .EnvironmentID ,
141+ }); err != nil {
142+ logger .WarnContext (context .Background (), "failed to resume vm on control connect" , "error" , err )
143+ }
144+ }()
145+ }
146+ }
147+
148+ h .controlConns .Add (taskID , wsConn )
149+ defer func () {
150+ h .controlConns .Remove (taskID , wsConn )
151+ // 最后一个连接断开:刷新计时器(开始空闲倒计时)
152+ if vm := task .VirtualMachine ; vm != nil && ! h .controlConns .Has (taskID ) {
153+ ctx , cancel := context .WithTimeout (context .Background (), 15 * time .Second )
154+ defer cancel ()
155+ if err := h .idleRefresher .Refresh (ctx , vm .ID ); err != nil {
156+ logger .WarnContext (ctx , "failed to refresh idle timers on disconnect" , "error" , err )
157+ }
158+ }
159+ }()
129160
130161 g , ctx := errgroup .WithContext (c .Request ().Context ())
131162
132163 g .Go (func () error {
133- return h .controlPing (ctx , wsConn , task . ID . String () )
164+ return h .controlPing (ctx , wsConn , taskID )
134165 })
135166
136167 g .Go (func () error {
137168 return h .controlReadMessages (ctx , wsConn , logger , task )
138169 })
139170
140171 g .Go (func () error {
141- return h .controlSubscribeTaskEvents (ctx , wsConn , logger , task . ID . String () )
172+ return h .controlSubscribeTaskEvents (ctx , wsConn , logger , taskID )
142173 })
143174
175+ // 定期刷新空闲计时器,保持 VM 活跃
176+ if vm := task .VirtualMachine ; vm != nil {
177+ g .Go (func () error {
178+ return h .controlKeepAlive (ctx , vm .ID )
179+ })
180+ }
181+
144182 if err := g .Wait (); err != nil {
145183 logger .DebugContext (c .Request ().Context (), "control websocket closed" , "reason" , err )
146184 }
@@ -165,6 +203,22 @@ func (h *TaskHandler) controlPing(ctx context.Context, wsConn *ws.WebsocketManag
165203 }
166204}
167205
206+ // controlKeepAlive 定期刷新空闲计时器,防止 VM 被误判空闲
207+ func (h * TaskHandler ) controlKeepAlive (ctx context.Context , vmID string ) error {
208+ ticker := time .NewTicker (5 * time .Minute )
209+ defer ticker .Stop ()
210+ for {
211+ select {
212+ case <- ctx .Done ():
213+ return ctx .Err ()
214+ case <- ticker .C :
215+ if err := h .idleRefresher .Refresh (ctx , vmID ); err != nil {
216+ h .logger .WarnContext (ctx , "keepalive refresh failed" , "vmID" , vmID , "error" , err )
217+ }
218+ }
219+ }
220+ }
221+
168222// controlReadMessages 读取客户端消息并分发处理
169223func (h * TaskHandler ) controlReadMessages (ctx context.Context , wsConn * ws.WebsocketManager , logger * slog.Logger , task * domain.Task ) error {
170224 for {
0 commit comments