@@ -19,6 +19,7 @@ import (
1919 "github.com/chaitin/MonkeyCode/backend/db/model"
2020 "github.com/chaitin/MonkeyCode/backend/db/projecttask"
2121 "github.com/chaitin/MonkeyCode/backend/db/task"
22+ "github.com/chaitin/MonkeyCode/backend/db/taskusagestat"
2223 "github.com/chaitin/MonkeyCode/backend/db/user"
2324 "github.com/chaitin/MonkeyCode/backend/db/virtualmachine"
2425 "github.com/chaitin/MonkeyCode/backend/domain"
@@ -44,16 +45,63 @@ func NewTaskRepo(i *do.Injector) (domain.TaskRepo, error) {
4445 }, nil
4546}
4647
48+ type statsById struct {
49+ ID uuid.UUID `json:"task_id"`
50+ InputTokens int64 `json:"input_tokens"`
51+ OutputTokens int64 `json:"output_tokens"`
52+ TotalTokens int64 `json:"total_tokens"`
53+ LLMRequests int64 `json:"llm_requests"`
54+ }
55+
4756// Stat implements domain.TaskRepo.
48- // 开源版本无 TaskUsageStat 表,返回空结果
49- func (t * TaskRepo ) Stat (_ context.Context , _ uuid.UUID ) (* domain.TaskStats , error ) {
50- return & domain.TaskStats {}, nil
57+ func (t * TaskRepo ) Stat (ctx context.Context , id uuid.UUID ) (* domain.TaskStats , error ) {
58+ var results []* domain.TaskStats
59+ err := t .db .TaskUsageStat .Query ().
60+ Where (taskusagestat .TaskIDEQ (id )).
61+ Aggregate (
62+ db .As (db .Sum (taskusagestat .FieldInputTokens ), "input_tokens" ),
63+ db .As (db .Sum (taskusagestat .FieldOutputTokens ), "output_tokens" ),
64+ db .As (db .Sum (taskusagestat .FieldTotalTokens ), "total_tokens" ),
65+ db .As (db .Count (), "llm_requests" ),
66+ ).
67+ Scan (ctx , & results )
68+ if err != nil {
69+ return nil , err
70+ }
71+ if len (results ) > 0 {
72+ return results [0 ], nil
73+ }
74+ return nil , nil
5175}
5276
5377// StatByIDs implements domain.TaskRepo.
54- // 开源版本无 TaskUsageStat 表,返回空 map
55- func (t * TaskRepo ) StatByIDs (_ context.Context , _ []uuid.UUID ) (map [uuid.UUID ]* domain.TaskStats , error ) {
56- return make (map [uuid.UUID ]* domain.TaskStats ), nil
78+ func (t * TaskRepo ) StatByIDs (ctx context.Context , ids []uuid.UUID ) (map [uuid.UUID ]* domain.TaskStats , error ) {
79+ var results []* statsById
80+ err := t .db .TaskUsageStat .Query ().
81+ Where (taskusagestat .TaskIDIn (ids ... )).
82+ Modify (func (s * sql.Selector ) {
83+ s .Select (
84+ "task_id" ,
85+ sql .As (sql .Sum (s .C (taskusagestat .FieldInputTokens )), "input_tokens" ),
86+ sql .As (sql .Sum (s .C (taskusagestat .FieldOutputTokens )), "output_tokens" ),
87+ sql .As (sql .Sum (s .C (taskusagestat .FieldTotalTokens )), "total_tokens" ),
88+ sql .As (sql .Count ("*" ), "llm_requests" ),
89+ ).
90+ GroupBy (s .C (taskusagestat .FieldTaskID ))
91+ }).
92+ Scan (ctx , & results )
93+ if err != nil {
94+ return nil , err
95+ }
96+
97+ return cvt .IterToMap (results , func (_ int , s * statsById ) (uuid.UUID , * domain.TaskStats ) {
98+ return s .ID , & domain.TaskStats {
99+ InputTokens : s .InputTokens ,
100+ OutputTokens : s .OutputTokens ,
101+ TotalTokens : s .TotalTokens ,
102+ LLMRequests : s .LLMRequests ,
103+ }
104+ }), nil
57105}
58106
59107// GetByID implements domain.TaskRepo.
0 commit comments