66from pydantic import TypeAdapter
77from sqlalchemy import select
88
9+ from core .db .session_factory import session_factory
910from core .helper .http_client_pooling import get_pooled_http_client
10- from extensions .ext_database import db
1111from libs .datetime_utils import naive_utc_now
1212from models .source import DataSourceOauthBinding
1313
@@ -95,27 +95,28 @@ def get_access_token(self, code: str) -> None:
9595 pages = pages ,
9696 )
9797 # save data source binding
98- data_source_binding = db .session .scalar (
99- select (DataSourceOauthBinding ).where (
100- DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
101- DataSourceOauthBinding .provider == "notion" ,
102- DataSourceOauthBinding .access_token == access_token ,
98+ with session_factory .create_session () as session :
99+ data_source_binding = session .scalar (
100+ select (DataSourceOauthBinding ).where (
101+ DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
102+ DataSourceOauthBinding .provider == "notion" ,
103+ DataSourceOauthBinding .access_token == access_token ,
104+ )
103105 )
104- )
105- if data_source_binding :
106- data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info )
107- data_source_binding .disabled = False
108- data_source_binding .updated_at = naive_utc_now ()
109- db .session .commit ()
110- else :
111- new_data_source_binding = DataSourceOauthBinding (
112- tenant_id = current_user .current_tenant_id ,
113- access_token = access_token ,
114- source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info ),
115- provider = "notion" ,
116- )
117- db .session .add (new_data_source_binding )
118- db .session .commit ()
106+ if data_source_binding :
107+ data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info )
108+ data_source_binding .disabled = False
109+ data_source_binding .updated_at = naive_utc_now ()
110+ session .commit ()
111+ else :
112+ new_data_source_binding = DataSourceOauthBinding (
113+ tenant_id = current_user .current_tenant_id ,
114+ access_token = access_token ,
115+ source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info ),
116+ provider = "notion" ,
117+ )
118+ session .add (new_data_source_binding )
119+ session .commit ()
119120
120121 def save_internal_access_token (self , access_token : str ) -> None :
121122 workspace_name = self .notion_workspace_name (access_token )
@@ -130,55 +131,57 @@ def save_internal_access_token(self, access_token: str) -> None:
130131 pages = pages ,
131132 )
132133 # save data source binding
133- data_source_binding = db .session .scalar (
134- select (DataSourceOauthBinding ).where (
135- DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
136- DataSourceOauthBinding .provider == "notion" ,
137- DataSourceOauthBinding .access_token == access_token ,
138- )
139- )
140- if data_source_binding :
141- data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info )
142- data_source_binding .disabled = False
143- data_source_binding .updated_at = naive_utc_now ()
144- db .session .commit ()
145- else :
146- new_data_source_binding = DataSourceOauthBinding (
147- tenant_id = current_user .current_tenant_id ,
148- access_token = access_token ,
149- source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info ),
150- provider = "notion" ,
134+ with session_factory .create_session () as session :
135+ data_source_binding = session .scalar (
136+ select (DataSourceOauthBinding ).where (
137+ DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
138+ DataSourceOauthBinding .provider == "notion" ,
139+ DataSourceOauthBinding .access_token == access_token ,
140+ )
151141 )
152- db .session .add (new_data_source_binding )
153- db .session .commit ()
142+ if data_source_binding :
143+ data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info )
144+ data_source_binding .disabled = False
145+ data_source_binding .updated_at = naive_utc_now ()
146+ session .commit ()
147+ else :
148+ new_data_source_binding = DataSourceOauthBinding (
149+ tenant_id = current_user .current_tenant_id ,
150+ access_token = access_token ,
151+ source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (source_info ),
152+ provider = "notion" ,
153+ )
154+ session .add (new_data_source_binding )
155+ session .commit ()
154156
155157 def sync_data_source (self , binding_id : str ) -> None :
156158 # save data source binding
157- data_source_binding = db .session .scalar (
158- select (DataSourceOauthBinding ).where (
159- DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
160- DataSourceOauthBinding .provider == "notion" ,
161- DataSourceOauthBinding .id == binding_id ,
162- DataSourceOauthBinding .disabled == False ,
159+ with session_factory .create_session () as session :
160+ data_source_binding = session .scalar (
161+ select (DataSourceOauthBinding ).where (
162+ DataSourceOauthBinding .tenant_id == current_user .current_tenant_id ,
163+ DataSourceOauthBinding .provider == "notion" ,
164+ DataSourceOauthBinding .id == binding_id ,
165+ DataSourceOauthBinding .disabled == False ,
166+ )
163167 )
164- )
165168
166- if data_source_binding :
167- # get all authorized pages
168- pages = self .get_authorized_pages (data_source_binding .access_token )
169- source_info = NOTION_SOURCE_INFO_ADAPTER .validate_python (data_source_binding .source_info )
170- new_source_info = self ._build_source_info (
171- workspace_name = source_info ["workspace_name" ],
172- workspace_icon = source_info ["workspace_icon" ],
173- workspace_id = source_info ["workspace_id" ],
174- pages = pages ,
175- )
176- data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (new_source_info )
177- data_source_binding .disabled = False
178- data_source_binding .updated_at = naive_utc_now ()
179- db . session .commit ()
180- else :
181- raise ValueError ("Data source binding not found" )
169+ if data_source_binding :
170+ # get all authorized pages
171+ pages = self .get_authorized_pages (data_source_binding .access_token )
172+ source_info = NOTION_SOURCE_INFO_ADAPTER .validate_python (data_source_binding .source_info )
173+ new_source_info = self ._build_source_info (
174+ workspace_name = source_info ["workspace_name" ],
175+ workspace_icon = source_info ["workspace_icon" ],
176+ workspace_id = source_info ["workspace_id" ],
177+ pages = pages ,
178+ )
179+ data_source_binding .source_info = SOURCE_INFO_STORAGE_ADAPTER .validate_python (new_source_info )
180+ data_source_binding .disabled = False
181+ data_source_binding .updated_at = naive_utc_now ()
182+ session .commit ()
183+ else :
184+ raise ValueError ("Data source binding not found" )
182185
183186 def get_authorized_pages (self , access_token : str ) -> list [NotionPageSummary ]:
184187 pages : list [NotionPageSummary ] = []
0 commit comments