1- from pathlib import Path
1+ import sys
2+
3+ from sqlalchemy import Table , select , delete , func
24from sqlalchemy .dialects .sqlite import Insert
3- from sqlalchemy .orm import sessionmaker , scoped_session
4- from sqlalchemy import MetaData , Table , select , delete , func
55from sqlalchemy .exc import (
6- IllegalStateChangeError ,
76 NoInspectionAvailable ,
87 NoSuchTableError ,
98)
9+
10+ from nxc .database import BaseDB
1011from nxc .logger import nxc_logger
11- import sys
1212
1313
14- class database :
14+ class database ( BaseDB ) :
1515 def __init__ (self , db_engine ):
1616 self .CredentialsTable = None
1717 self .HostsTable = None
1818 self .LoggedinRelationsTable = None
1919
20- self .db_engine = db_engine
21- self .db_path = self .db_engine .url .database
22- self .protocol = Path (self .db_path ).stem .upper ()
23- self .metadata = MetaData ()
24- self .reflect_tables ()
25-
26- session_factory = sessionmaker (bind = self .db_engine , expire_on_commit = True )
27- Session = scoped_session (session_factory )
28- self .sess = Session ()
20+ super ().__init__ (db_engine )
2921
3022 @staticmethod
3123 def db_schema (db_conn ):
@@ -80,26 +72,13 @@ def reflect_tables(self):
8072 )
8173 sys .exit ()
8274
83- def shutdown_db (self ):
84- try :
85- self .sess .close ()
86- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
87- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
88- # this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
89- except IllegalStateChangeError as e :
90- nxc_logger .debug (f"Error while closing session db object: { e } " )
91-
92- def clear_database (self ):
93- for table in self .metadata .sorted_tables :
94- self .sess .execute (table .delete ())
95-
9675 def add_host (self , host , port , banner ):
9776 """Check if this host is already in the DB, if not add it"""
9877 hosts = []
9978 updated_ids = []
10079
10180 q = select (self .HostsTable ).filter (self .HostsTable .c .host == host )
102- results = self .sess . execute (q ).all ()
81+ results = self .db_execute (q ).all ()
10382
10483 # create new host
10584 if not results :
@@ -133,7 +112,7 @@ def add_host(self, host, port, banner):
133112 update_columns = {col .name : col for col in q .excluded if col .name not in "id" }
134113 q = q .on_conflict_do_update (index_elements = self .HostsTable .primary_key , set_ = update_columns )
135114
136- self .sess . execute (q , hosts ) # .scalar()
115+ self .db_execute (q , hosts ) # .scalar()
137116 # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
138117 if updated_ids :
139118 nxc_logger .debug (f"add_host() - Host IDs Updated: { updated_ids } " )
@@ -143,8 +122,9 @@ def add_credential(self, username, password):
143122 """Check if this credential has already been added to the database, if not add it in."""
144123 credentials = []
145124
146- q = select (self .CredentialsTable ).filter (func .lower (self .CredentialsTable .c .username ) == func .lower (username ), func .lower (self .CredentialsTable .c .password ) == func .lower (password ))
147- results = self .sess .execute (q ).all ()
125+ q = select (self .CredentialsTable ).filter (func .lower (self .CredentialsTable .c .username ) == func .lower (username ),
126+ func .lower (self .CredentialsTable .c .password ) == func .lower (password ))
127+ results = self .db_execute (q ).all ()
148128
149129 # add new credential
150130 if not results :
@@ -170,10 +150,11 @@ def add_credential(self, username, password):
170150 # TODO: find a way to abstract this away to a single Upsert call
171151 q_users = Insert (self .CredentialsTable ) # .returning(self.CredentialsTable.c.id)
172152 update_columns_users = {col .name : col for col in q_users .excluded if col .name not in "id" }
173- q_users = q_users .on_conflict_do_update (index_elements = self .CredentialsTable .primary_key , set_ = update_columns_users )
153+ q_users = q_users .on_conflict_do_update (index_elements = self .CredentialsTable .primary_key ,
154+ set_ = update_columns_users )
174155 nxc_logger .debug (f"Adding credentials: { credentials } " )
175156
176- self .sess . execute (q_users , credentials ) # .scalar()
157+ self .db_execute (q_users , credentials ) # .scalar()
177158
178159 # hacky way to get cred_id since we can't use returning() yet
179160 if len (credentials ) == 1 :
@@ -187,23 +168,23 @@ def remove_credentials(self, creds_id):
187168 for cred_id in creds_id :
188169 q = delete (self .CredentialsTable ).filter (self .CredentialsTable .c .id == cred_id )
189170 del_hosts .append (q )
190- self .sess . execute (q )
171+ self .db_execute (q )
191172
192173 def is_credential_valid (self , credential_id ):
193174 """Check if this credential ID is valid."""
194175 q = select (self .CredentialsTable ).filter (
195176 self .CredentialsTable .c .id == credential_id ,
196177 self .CredentialsTable .c .password is not None ,
197178 )
198- results = self .sess . execute (q ).all ()
179+ results = self .db_execute (q ).all ()
199180 return len (results ) > 0
200181
201182 def get_credential (self , username , password ):
202183 q = select (self .CredentialsTable ).filter (
203184 self .CredentialsTable .c .username == username ,
204185 self .CredentialsTable .c .password == password ,
205186 )
206- results = self .sess . execute (q ).first ()
187+ results = self .db_execute (q ).first ()
207188 if results is not None :
208189 return results .id
209190
@@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None):
220201 else :
221202 q = select (self .CredentialsTable )
222203
223- return self .sess . execute (q ).all ()
204+ return self .db_execute (q ).all ()
224205
225206 def is_host_valid (self , host_id ):
226207 """Check if this host ID is valid."""
227208 q = select (self .HostsTable ).filter (self .HostsTable .c .id == host_id )
228- results = self .sess . execute (q ).all ()
209+ results = self .db_execute (q ).all ()
229210 return len (results ) > 0
230211
231212 def get_hosts (self , filter_term = None ):
@@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None):
235216 # if we're returning a single host by ID
236217 if self .is_host_valid (filter_term ):
237218 q = q .filter (self .HostsTable .c .id == filter_term )
238- results = self .sess . execute (q ).first ()
219+ results = self .db_execute (q ).first ()
239220 # all() returns a list, so we keep the return format the same so consumers don't have to guess
240221 return [results ]
241222 # if we're filtering by host
242223 elif filter_term and filter_term != "" :
243224 like_term = func .lower (f"%{ filter_term } %" )
244225 q = q .filter (self .HostsTable .c .host .like (like_term ))
245- results = self .sess . execute (q ).all ()
226+ results = self .db_execute (q ).all ()
246227 nxc_logger .debug (f"FTP get_hosts() - results: { results } " )
247228 return results
248229
249230 def is_user_valid (self , cred_id ):
250231 """Check if this User ID is valid."""
251232 q = select (self .CredentialsTable ).filter (self .CredentialsTable .c .id == cred_id )
252- results = self .sess . execute (q ).all ()
233+ results = self .db_execute (q ).all ()
253234 return len (results ) > 0
254235
255236 def get_user (self , username ):
256237 q = select (self .CredentialsTable ).filter (func .lower (self .CredentialsTable .c .username ) == func .lower (username ))
257- return self .sess . execute (q ).all ()
238+ return self .db_execute (q ).all ()
258239
259240 def get_users (self , filter_term = None ):
260241 q = select (self .CredentialsTable )
@@ -265,14 +246,14 @@ def get_users(self, filter_term=None):
265246 elif filter_term and filter_term != "" :
266247 like_term = func .lower (f"%{ filter_term } %" )
267248 q = q .filter (func .lower (self .CredentialsTable .c .username ).like (like_term ))
268- return self .sess . execute (q ).all ()
249+ return self .db_execute (q ).all ()
269250
270251 def add_loggedin_relation (self , cred_id , host_id ):
271252 relation_query = select (self .LoggedinRelationsTable ).filter (
272253 self .LoggedinRelationsTable .c .credid == cred_id ,
273254 self .LoggedinRelationsTable .c .hostid == host_id ,
274255 )
275- results = self .sess . execute (relation_query ).all ()
256+ results = self .db_execute (relation_query ).all ()
276257
277258 # only add one if one doesn't already exist
278259 if not results :
@@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id):
282263 # TODO: find a way to abstract this away to a single Upsert call
283264 q = Insert (self .LoggedinRelationsTable ) # .returning(self.LoggedinRelationsTable.c.id)
284265
285- self .sess . execute (q , [relation ]) # .scalar()
266+ self .db_execute (q , [relation ]) # .scalar()
286267 inserted_id_results = self .get_loggedin_relations (cred_id , host_id )
287268 nxc_logger .debug (f"Checking if relation was added: { inserted_id_results } " )
288269 return inserted_id_results [0 ].id
@@ -295,15 +276,15 @@ def get_loggedin_relations(self, cred_id=None, host_id=None):
295276 q = q .filter (self .LoggedinRelationsTable .c .credid == cred_id )
296277 if host_id :
297278 q = q .filter (self .LoggedinRelationsTable .c .hostid == host_id )
298- return self .sess . execute (q ).all ()
279+ return self .db_execute (q ).all ()
299280
300281 def remove_loggedin_relations (self , cred_id = None , host_id = None ):
301282 q = delete (self .LoggedinRelationsTable )
302283 if cred_id :
303284 q = q .filter (self .LoggedinRelationsTable .c .credid == cred_id )
304285 elif host_id :
305286 q = q .filter (self .LoggedinRelationsTable .c .hostid == host_id )
306- self .sess . execute (q )
287+ self .db_execute (q )
307288
308289 def add_directory_listing (self , lir_id , data ):
309290 pass
0 commit comments