Skip to content

Commit 52dd6c0

Browse files
committed
feat: 创建任务通知补齐信息
1 parent c11551d commit 52dd6c0

File tree

3 files changed

+63
-22
lines changed

3 files changed

+63
-22
lines changed

backend/biz/task/repo/task.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func (t *TaskRepo) StatByIDs(ctx context.Context, ids []uuid.UUID) (map[uuid.UUI
107107
// GetByID implements domain.TaskRepo.
108108
func (t *TaskRepo) GetByID(ctx context.Context, id uuid.UUID) (*db.Task, error) {
109109
return t.db.Task.Query().
110+
WithUser().
110111
WithProjectTasks(func(ptq *db.ProjectTaskQuery) {
111112
ptq.
112113
WithModel().

backend/biz/task/usecase/task.go

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
342342

343343
ctx = entx.WithTaskConcurrencyLimit(ctx, limit)
344344

345+
var createdVm *taskflow.VirtualMachine
345346
pt, err := a.repo.Create(ctx, user, req, token, func(pt *db.ProjectTask, m *db.Model, i *db.Image) (*taskflow.VirtualMachine, error) {
346347
t := pt.Edges.Task
347348
if t == nil {
@@ -387,6 +388,7 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
387388
if vm == nil {
388389
return nil, fmt.Errorf("vm is nil")
389390
}
391+
createdVm = vm
390392

391393
mcps := []taskflow.McpServerConfig{
392394
{
@@ -407,23 +409,6 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
407409
},
408410
}
409411

410-
taskMeta := lifecycle.TaskMetadata{
411-
TaskID: t.ID,
412-
UserID: user.ID,
413-
}
414-
if err := a.taskLifecycle.Transition(ctx, t.ID, consts.TaskStatusPending, taskMeta); err != nil {
415-
a.logger.WarnContext(ctx, "task lifecycle transition failed", "error", err)
416-
}
417-
418-
vmMeta := lifecycle.VMMetadata{
419-
VMID: vm.ID,
420-
TaskID: &t.ID,
421-
UserID: user.ID,
422-
}
423-
if err := a.vmLifecycle.Transition(ctx, vm.ID, lifecycle.VMStatePending, vmMeta); err != nil {
424-
a.logger.WarnContext(ctx, "vm lifecycle transition failed", "error", err)
425-
}
426-
427412
// 存储 CreateTaskReq 到 Redis(10 分钟过期),供 Lifecycle Manager 消费
428413
createTaskReq := &taskflow.CreateTaskReq{
429414
ID: t.ID,
@@ -456,6 +441,24 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
456441
return nil, err
457442
}
458443
a.logger.With("req", req).InfoContext(ctx, "task created")
444+
taskMeta := lifecycle.TaskMetadata{
445+
TaskID: pt.TaskID,
446+
UserID: user.ID,
447+
}
448+
if err := a.taskLifecycle.Transition(ctx, pt.TaskID, consts.TaskStatusPending, taskMeta); err != nil {
449+
a.logger.WarnContext(ctx, "task lifecycle transition failed", "error", err)
450+
}
451+
452+
if createdVm != nil {
453+
vmMeta := lifecycle.VMMetadata{
454+
VMID: createdVm.ID,
455+
TaskID: &pt.TaskID,
456+
UserID: user.ID,
457+
}
458+
if err := a.vmLifecycle.Transition(ctx, createdVm.ID, lifecycle.VMStatePending, vmMeta); err != nil {
459+
a.logger.WarnContext(ctx, "vm lifecycle transition failed", "error", err)
460+
}
461+
}
459462

460463
result := cvt.From(pt, &domain.ProjectTask{})
461464

backend/pkg/lifecycle/tasknotifyhook.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package lifecycle
33
import (
44
"context"
55
"log/slog"
6+
"time"
67

78
"github.com/google/uuid"
89
"github.com/samber/do"
@@ -16,13 +17,15 @@ import (
1617
type TaskNotifyHook struct {
1718
notify *dispatcher.Dispatcher
1819
logger *slog.Logger
20+
repo domain.TaskRepo
1921
}
2022

2123
// NewTaskNotifyHook 创建任务通知 Hook
2224
func NewTaskNotifyHook(i *do.Injector) *TaskNotifyHook {
2325
return &TaskNotifyHook{
2426
notify: do.MustInvoke[*dispatcher.Dispatcher](i),
2527
logger: do.MustInvoke[*slog.Logger](i).With("hook", "task-notify-hook"),
28+
repo: do.MustInvoke[domain.TaskRepo](i),
2629
}
2730
}
2831

@@ -38,17 +41,51 @@ func (h *TaskNotifyHook) OnStateChange(ctx context.Context, taskID uuid.UUID, fr
3841
default:
3942
return nil
4043
}
44+
logger := h.logger.With("task_id", taskID, "from", from, "to", to)
45+
46+
task, err := h.repo.GetByID(ctx, taskID)
47+
if err != nil {
48+
logger.With("error", err).ErrorContext(ctx, "failed to get task on state change")
49+
return err
50+
}
51+
52+
payload := domain.NotifyEventPayload{
53+
TaskID: taskID.String(),
54+
TaskStatus: string(to),
55+
TaskContent: task.Content,
56+
VMName: "",
57+
VMArch: "",
58+
VMCores: 0,
59+
VMMemory: 0,
60+
VMOS: "",
61+
}
62+
if u := task.Edges.User; u != nil {
63+
payload.UserName = u.Name
64+
}
65+
if pts := task.Edges.ProjectTasks; len(pts) > 0 {
66+
pt := pts[0]
67+
if m := pt.Edges.Model; m != nil {
68+
payload.ModelName = m.Model
69+
}
70+
}
71+
if vms := task.Edges.Vms; len(vms) > 0 {
72+
vm := vms[0]
73+
payload.VMID = vm.ID
74+
payload.VMName = vm.Name
75+
payload.VMArch = vm.Arch
76+
payload.VMCores = vm.Cores
77+
payload.VMMemory = vm.Memory
78+
payload.VMOS = vm.Os
79+
}
4180

4281
event := &domain.NotifyEvent{
4382
EventType: eventType,
4483
SubjectUserID: metadata.UserID,
4584
RefID: taskID.String(),
46-
Payload: domain.NotifyEventPayload{
47-
TaskID: taskID.String(),
48-
TaskStatus: string(to),
49-
},
85+
Payload: payload,
86+
OccurredAt: time.Now(),
5087
}
5188

52-
h.logger.InfoContext(ctx, "publishing notify event", "event", eventType, "task_id", taskID)
89+
logger.InfoContext(ctx, "publishing notify event", "event", eventType)
5390
return h.notify.Publish(ctx, event)
5491
}

0 commit comments

Comments
 (0)