Skip to content

Commit a438b27

Browse files
committed
chore(api): migrate mail task and OAuth data source to use Session(db.engine)
1 parent 7de92c5 commit a438b27

File tree

2 files changed

+127
-123
lines changed

2 files changed

+127
-123
lines changed

api/libs/oauth_data_source.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pydantic import TypeAdapter
77
from sqlalchemy import select
88

9+
from core.db.session_factory import session_factory
910
from core.helper.http_client_pooling import get_pooled_http_client
10-
from extensions.ext_database import db
1111
from libs.datetime_utils import naive_utc_now
1212
from models.source import DataSourceOauthBinding
1313

@@ -95,27 +95,28 @@ def get_access_token(self, code: str) -> None:
9595
pages=pages,
9696
)
9797
# save data source binding
98-
data_source_binding = db.session.scalar(
99-
select(DataSourceOauthBinding).where(
100-
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
101-
DataSourceOauthBinding.provider == "notion",
102-
DataSourceOauthBinding.access_token == access_token,
98+
with session_factory.create_session() as session:
99+
data_source_binding = session.scalar(
100+
select(DataSourceOauthBinding).where(
101+
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
102+
DataSourceOauthBinding.provider == "notion",
103+
DataSourceOauthBinding.access_token == access_token,
104+
)
103105
)
104-
)
105-
if data_source_binding:
106-
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
107-
data_source_binding.disabled = False
108-
data_source_binding.updated_at = naive_utc_now()
109-
db.session.commit()
110-
else:
111-
new_data_source_binding = DataSourceOauthBinding(
112-
tenant_id=current_user.current_tenant_id,
113-
access_token=access_token,
114-
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
115-
provider="notion",
116-
)
117-
db.session.add(new_data_source_binding)
118-
db.session.commit()
106+
if data_source_binding:
107+
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
108+
data_source_binding.disabled = False
109+
data_source_binding.updated_at = naive_utc_now()
110+
session.commit()
111+
else:
112+
new_data_source_binding = DataSourceOauthBinding(
113+
tenant_id=current_user.current_tenant_id,
114+
access_token=access_token,
115+
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
116+
provider="notion",
117+
)
118+
session.add(new_data_source_binding)
119+
session.commit()
119120

120121
def save_internal_access_token(self, access_token: str) -> None:
121122
workspace_name = self.notion_workspace_name(access_token)
@@ -130,55 +131,57 @@ def save_internal_access_token(self, access_token: str) -> None:
130131
pages=pages,
131132
)
132133
# save data source binding
133-
data_source_binding = db.session.scalar(
134-
select(DataSourceOauthBinding).where(
135-
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
136-
DataSourceOauthBinding.provider == "notion",
137-
DataSourceOauthBinding.access_token == access_token,
138-
)
139-
)
140-
if data_source_binding:
141-
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
142-
data_source_binding.disabled = False
143-
data_source_binding.updated_at = naive_utc_now()
144-
db.session.commit()
145-
else:
146-
new_data_source_binding = DataSourceOauthBinding(
147-
tenant_id=current_user.current_tenant_id,
148-
access_token=access_token,
149-
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
150-
provider="notion",
134+
with session_factory.create_session() as session:
135+
data_source_binding = session.scalar(
136+
select(DataSourceOauthBinding).where(
137+
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
138+
DataSourceOauthBinding.provider == "notion",
139+
DataSourceOauthBinding.access_token == access_token,
140+
)
151141
)
152-
db.session.add(new_data_source_binding)
153-
db.session.commit()
142+
if data_source_binding:
143+
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
144+
data_source_binding.disabled = False
145+
data_source_binding.updated_at = naive_utc_now()
146+
session.commit()
147+
else:
148+
new_data_source_binding = DataSourceOauthBinding(
149+
tenant_id=current_user.current_tenant_id,
150+
access_token=access_token,
151+
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
152+
provider="notion",
153+
)
154+
session.add(new_data_source_binding)
155+
session.commit()
154156

155157
def sync_data_source(self, binding_id: str) -> None:
156158
# save data source binding
157-
data_source_binding = db.session.scalar(
158-
select(DataSourceOauthBinding).where(
159-
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
160-
DataSourceOauthBinding.provider == "notion",
161-
DataSourceOauthBinding.id == binding_id,
162-
DataSourceOauthBinding.disabled == False,
159+
with session_factory.create_session() as session:
160+
data_source_binding = session.scalar(
161+
select(DataSourceOauthBinding).where(
162+
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
163+
DataSourceOauthBinding.provider == "notion",
164+
DataSourceOauthBinding.id == binding_id,
165+
DataSourceOauthBinding.disabled == False,
166+
)
163167
)
164-
)
165168

166-
if data_source_binding:
167-
# get all authorized pages
168-
pages = self.get_authorized_pages(data_source_binding.access_token)
169-
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
170-
new_source_info = self._build_source_info(
171-
workspace_name=source_info["workspace_name"],
172-
workspace_icon=source_info["workspace_icon"],
173-
workspace_id=source_info["workspace_id"],
174-
pages=pages,
175-
)
176-
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
177-
data_source_binding.disabled = False
178-
data_source_binding.updated_at = naive_utc_now()
179-
db.session.commit()
180-
else:
181-
raise ValueError("Data source binding not found")
169+
if data_source_binding:
170+
# get all authorized pages
171+
pages = self.get_authorized_pages(data_source_binding.access_token)
172+
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
173+
new_source_info = self._build_source_info(
174+
workspace_name=source_info["workspace_name"],
175+
workspace_icon=source_info["workspace_icon"],
176+
workspace_id=source_info["workspace_id"],
177+
pages=pages,
178+
)
179+
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
180+
data_source_binding.disabled = False
181+
data_source_binding.updated_at = naive_utc_now()
182+
session.commit()
183+
else:
184+
raise ValueError("Data source binding not found")
182185

183186
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
184187
pages: list[NotionPageSummary] = []

api/schedule/mail_clean_document_notify_task.py

Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import app
99
from configs import dify_config
10+
from core.db.session_factory import session_factory
1011
from enums.cloud_plan import CloudPlan
11-
from extensions.ext_database import db
1212
from extensions.ext_mail import mail
1313
from libs.email_i18n import EmailType, get_email_i18n_service
1414
from models import Account, Tenant, TenantAccountJoin
@@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
3333

3434
# send document clean notify mail
3535
try:
36-
dataset_auto_disable_logs = db.session.scalars(
37-
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
38-
).all()
39-
# group by tenant_id
40-
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
41-
for dataset_auto_disable_log in dataset_auto_disable_logs:
42-
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
43-
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
44-
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
45-
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
46-
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
47-
features = FeatureService.get_features(tenant_id)
48-
plan = features.billing.subscription.plan
49-
if plan != CloudPlan.SANDBOX:
50-
knowledge_details = []
51-
# check tenant
52-
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
53-
if not tenant:
54-
continue
55-
# check current owner
56-
current_owner_join = db.session.scalar(
57-
select(TenantAccountJoin)
58-
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
59-
.limit(1)
60-
)
61-
if not current_owner_join:
62-
continue
63-
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
64-
if not account:
65-
continue
66-
67-
dataset_auto_dataset_map = {} # type: ignore
68-
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
69-
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
70-
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
71-
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
72-
dataset_auto_disable_log.document_id
36+
with session_factory.create_session() as session:
37+
dataset_auto_disable_logs = session.scalars(
38+
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
39+
).all()
40+
# group by tenant_id
41+
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
42+
for dataset_auto_disable_log in dataset_auto_disable_logs:
43+
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
44+
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
45+
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
46+
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
47+
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
48+
features = FeatureService.get_features(tenant_id)
49+
plan = features.billing.subscription.plan
50+
if plan != CloudPlan.SANDBOX:
51+
knowledge_details = []
52+
# check tenant
53+
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
54+
if not tenant:
55+
continue
56+
# check current owner
57+
current_owner_join = session.scalar(
58+
select(TenantAccountJoin)
59+
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
60+
.limit(1)
7361
)
62+
if not current_owner_join:
63+
continue
64+
account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
65+
if not account:
66+
continue
7467

75-
for dataset_id, document_ids in dataset_auto_dataset_map.items():
76-
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
77-
if dataset:
78-
document_count = len(document_ids)
79-
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
80-
if knowledge_details:
81-
email_service = get_email_i18n_service()
82-
email_service.send_email(
83-
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
84-
language_code="en-US",
85-
to=account.email,
86-
template_context={
87-
"userName": account.email,
88-
"knowledge_details": knowledge_details,
89-
"url": url,
90-
},
91-
)
68+
dataset_auto_dataset_map = {} # type: ignore
69+
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
70+
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
71+
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
72+
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
73+
dataset_auto_disable_log.document_id
74+
)
9275

93-
# update notified to True
94-
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
95-
dataset_auto_disable_log.notified = True
96-
db.session.commit()
76+
for dataset_id, document_ids in dataset_auto_dataset_map.items():
77+
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
78+
if dataset:
79+
document_count = len(document_ids)
80+
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
81+
if knowledge_details:
82+
email_service = get_email_i18n_service()
83+
email_service.send_email(
84+
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
85+
language_code="en-US",
86+
to=account.email,
87+
template_context={
88+
"userName": account.email,
89+
"knowledge_details": knowledge_details,
90+
"url": url,
91+
},
92+
)
93+
94+
# update notified to True
95+
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
96+
dataset_auto_disable_log.notified = True
97+
session.commit()
9798
end_at = time.perf_counter()
9899
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
99100
except Exception:

0 commit comments

Comments
 (0)