Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 67 additions & 64 deletions api/libs/oauth_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pydantic import TypeAdapter
from sqlalchemy import select

from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding

Expand Down Expand Up @@ -95,27 +95,28 @@ def get_access_token(self, code: str) -> None:
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()

def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
Expand All @@ -130,55 +131,57 @@ def save_internal_access_token(self, access_token: str) -> None:
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()

def sync_data_source(self, binding_id: str) -> None:
    """Refresh the stored page list of an enabled Notion binding.

    Looks up the current tenant's enabled ``DataSourceOauthBinding`` for the
    "notion" provider with the given id, re-fetches the pages authorized for
    the binding's access token, rebuilds the source info from the stored
    workspace fields plus the fresh page list, and commits the update.

    Args:
        binding_id: Id of the binding to refresh.

    Raises:
        ValueError: If no enabled Notion binding with this id exists for the
            current tenant.
    """
    # The lookup, the mutation and the commit all go through the same
    # session so the loaded binding stays attached while it is updated.
    with session_factory.create_session() as session:
        data_source_binding = session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.id == binding_id,
                # Builds a SQL expression; this is not a Python truth test.
                DataSourceOauthBinding.disabled == False,
            )
        )
        if not data_source_binding:
            raise ValueError("Data source binding not found")

        # get all authorized pages
        pages = self.get_authorized_pages(data_source_binding.access_token)
        source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
        # Workspace identity is carried over from the stored info; only the
        # page list is replaced.
        new_source_info = self._build_source_info(
            workspace_name=source_info["workspace_name"],
            workspace_icon=source_info["workspace_icon"],
            workspace_id=source_info["workspace_id"],
            pages=pages,
        )
        data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
        # The query already filtered on disabled == False, so this keeps the
        # existing value; retained to match the other save paths.
        data_source_binding.disabled = False
        data_source_binding.updated_at = naive_utc_now()
        session.commit()

def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
Expand Down
119 changes: 60 additions & 59 deletions api/schedule/mail_clean_document_notify_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import app
from configs import dify_config
from core.db.session_factory import session_factory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin
Expand All @@ -33,67 +33,68 @@ def mail_clean_document_notify_task():

# send document clean notify mail
try:
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue

dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
with session_factory.create_session() as session:
dataset_auto_disable_logs = session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue

for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)

# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
Comment thread
asukaminato0721 marked this conversation as resolved.
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)

# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
session.commit()
Comment thread
asukaminato0721 marked this conversation as resolved.
end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
Expand Down
Loading