Skip to content

Commit b59d823

Browse files
authored
Merge pull request Pennyw0rth#400 from dazzgt/bugfix/db-interface-error
refactoring to fix InterfaceError of DB
2 parents 0da4cf8 + dadeff9 commit b59d823

10 files changed

Lines changed: 260 additions & 422 deletions

File tree

nxc/database.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
import sys
21
import configparser
32
import shutil
4-
from sqlalchemy import create_engine
5-
from sqlite3 import connect
3+
import sys
64
from os import mkdir
75
from os.path import exists
86
from os.path import join as path_join
7+
from pathlib import Path
8+
from sqlite3 import connect
9+
from threading import Lock
10+
11+
from sqlalchemy import create_engine, MetaData
12+
from sqlalchemy.exc import IllegalStateChangeError
13+
from sqlalchemy.orm import sessionmaker, scoped_session
914

1015
from nxc.loaders.protocolloader import ProtocolLoader
16+
from nxc.logger import nxc_logger
1117
from nxc.paths import WORKSPACE_DIR
1218

1319

@@ -103,3 +109,39 @@ def initialize_db():
103109

104110
# Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol)
105111
init_protocol_dbs("default")
112+
113+
114+
class BaseDB:
115+
def __init__(self, db_engine):
116+
self.db_engine = db_engine
117+
self.db_path = self.db_engine.url.database
118+
self.protocol = Path(self.db_path).stem.upper()
119+
self.metadata = MetaData()
120+
self.reflect_tables()
121+
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
122+
123+
session = scoped_session(session_factory)
124+
self.sess = session()
125+
self.lock = Lock()
126+
127+
def reflect_tables(self):
128+
raise NotImplementedError("Reflect tables not implemented")
129+
130+
def shutdown_db(self):
131+
try:
132+
self.sess.close()
133+
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
134+
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
135+
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
136+
except IllegalStateChangeError as e:
137+
nxc_logger.debug(f"Error while closing session db object: {e}")
138+
139+
def clear_database(self):
140+
for table in self.metadata.sorted_tables:
141+
self.db_execute(table.delete())
142+
143+
def db_execute(self, *args):
144+
self.lock.acquire()
145+
res = self.sess.execute(*args)
146+
self.lock.release()
147+
return res

nxc/protocols/ftp/database.py

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,23 @@
1-
from pathlib import Path
1+
import sys
2+
3+
from sqlalchemy import Table, select, delete, func
24
from sqlalchemy.dialects.sqlite import Insert
3-
from sqlalchemy.orm import sessionmaker, scoped_session
4-
from sqlalchemy import MetaData, Table, select, delete, func
55
from sqlalchemy.exc import (
6-
IllegalStateChangeError,
76
NoInspectionAvailable,
87
NoSuchTableError,
98
)
9+
10+
from nxc.database import BaseDB
1011
from nxc.logger import nxc_logger
11-
import sys
1212

1313

14-
class database:
14+
class database(BaseDB):
1515
def __init__(self, db_engine):
1616
self.CredentialsTable = None
1717
self.HostsTable = None
1818
self.LoggedinRelationsTable = None
1919

20-
self.db_engine = db_engine
21-
self.db_path = self.db_engine.url.database
22-
self.protocol = Path(self.db_path).stem.upper()
23-
self.metadata = MetaData()
24-
self.reflect_tables()
25-
26-
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
27-
Session = scoped_session(session_factory)
28-
self.sess = Session()
20+
super().__init__(db_engine)
2921

3022
@staticmethod
3123
def db_schema(db_conn):
@@ -80,26 +72,13 @@ def reflect_tables(self):
8072
)
8173
sys.exit()
8274

83-
def shutdown_db(self):
84-
try:
85-
self.sess.close()
86-
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
87-
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
88-
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
89-
except IllegalStateChangeError as e:
90-
nxc_logger.debug(f"Error while closing session db object: {e}")
91-
92-
def clear_database(self):
93-
for table in self.metadata.sorted_tables:
94-
self.sess.execute(table.delete())
95-
9675
def add_host(self, host, port, banner):
9776
"""Check if this host is already in the DB, if not add it"""
9877
hosts = []
9978
updated_ids = []
10079

10180
q = select(self.HostsTable).filter(self.HostsTable.c.host == host)
102-
results = self.sess.execute(q).all()
81+
results = self.db_execute(q).all()
10382

10483
# create new host
10584
if not results:
@@ -133,7 +112,7 @@ def add_host(self, host, port, banner):
133112
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
134113
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
135114

136-
self.sess.execute(q, hosts) # .scalar()
115+
self.db_execute(q, hosts) # .scalar()
137116
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
138117
if updated_ids:
139118
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
@@ -143,8 +122,9 @@ def add_credential(self, username, password):
143122
"""Check if this credential has already been added to the database, if not add it in."""
144123
credentials = []
145124

146-
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.password) == func.lower(password))
147-
results = self.sess.execute(q).all()
125+
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username),
126+
func.lower(self.CredentialsTable.c.password) == func.lower(password))
127+
results = self.db_execute(q).all()
148128

149129
# add new credential
150130
if not results:
@@ -170,10 +150,11 @@ def add_credential(self, username, password):
170150
# TODO: find a way to abstract this away to a single Upsert call
171151
q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id)
172152
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
173-
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users)
153+
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key,
154+
set_=update_columns_users)
174155
nxc_logger.debug(f"Adding credentials: {credentials}")
175156

176-
self.sess.execute(q_users, credentials) # .scalar()
157+
self.db_execute(q_users, credentials) # .scalar()
177158

178159
# hacky way to get cred_id since we can't use returning() yet
179160
if len(credentials) == 1:
@@ -187,23 +168,23 @@ def remove_credentials(self, creds_id):
187168
for cred_id in creds_id:
188169
q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
189170
del_hosts.append(q)
190-
self.sess.execute(q)
171+
self.db_execute(q)
191172

192173
def is_credential_valid(self, credential_id):
193174
"""Check if this credential ID is valid."""
194175
q = select(self.CredentialsTable).filter(
195176
self.CredentialsTable.c.id == credential_id,
196177
self.CredentialsTable.c.password is not None,
197178
)
198-
results = self.sess.execute(q).all()
179+
results = self.db_execute(q).all()
199180
return len(results) > 0
200181

201182
def get_credential(self, username, password):
202183
q = select(self.CredentialsTable).filter(
203184
self.CredentialsTable.c.username == username,
204185
self.CredentialsTable.c.password == password,
205186
)
206-
results = self.sess.execute(q).first()
187+
results = self.db_execute(q).first()
207188
if results is not None:
208189
return results.id
209190

@@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None):
220201
else:
221202
q = select(self.CredentialsTable)
222203

223-
return self.sess.execute(q).all()
204+
return self.db_execute(q).all()
224205

225206
def is_host_valid(self, host_id):
226207
"""Check if this host ID is valid."""
227208
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
228-
results = self.sess.execute(q).all()
209+
results = self.db_execute(q).all()
229210
return len(results) > 0
230211

231212
def get_hosts(self, filter_term=None):
@@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None):
235216
# if we're returning a single host by ID
236217
if self.is_host_valid(filter_term):
237218
q = q.filter(self.HostsTable.c.id == filter_term)
238-
results = self.sess.execute(q).first()
219+
results = self.db_execute(q).first()
239220
# all() returns a list, so we keep the return format the same so consumers don't have to guess
240221
return [results]
241222
# if we're filtering by host
242223
elif filter_term and filter_term != "":
243224
like_term = func.lower(f"%{filter_term}%")
244225
q = q.filter(self.HostsTable.c.host.like(like_term))
245-
results = self.sess.execute(q).all()
226+
results = self.db_execute(q).all()
246227
nxc_logger.debug(f"FTP get_hosts() - results: {results}")
247228
return results
248229

249230
def is_user_valid(self, cred_id):
250231
"""Check if this User ID is valid."""
251232
q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
252-
results = self.sess.execute(q).all()
233+
results = self.db_execute(q).all()
253234
return len(results) > 0
254235

255236
def get_user(self, username):
256237
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username))
257-
return self.sess.execute(q).all()
238+
return self.db_execute(q).all()
258239

259240
def get_users(self, filter_term=None):
260241
q = select(self.CredentialsTable)
@@ -265,14 +246,14 @@ def get_users(self, filter_term=None):
265246
elif filter_term and filter_term != "":
266247
like_term = func.lower(f"%{filter_term}%")
267248
q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term))
268-
return self.sess.execute(q).all()
249+
return self.db_execute(q).all()
269250

270251
def add_loggedin_relation(self, cred_id, host_id):
271252
relation_query = select(self.LoggedinRelationsTable).filter(
272253
self.LoggedinRelationsTable.c.credid == cred_id,
273254
self.LoggedinRelationsTable.c.hostid == host_id,
274255
)
275-
results = self.sess.execute(relation_query).all()
256+
results = self.db_execute(relation_query).all()
276257

277258
# only add one if one doesn't already exist
278259
if not results:
@@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id):
282263
# TODO: find a way to abstract this away to a single Upsert call
283264
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
284265

285-
self.sess.execute(q, [relation]) # .scalar()
266+
self.db_execute(q, [relation]) # .scalar()
286267
inserted_id_results = self.get_loggedin_relations(cred_id, host_id)
287268
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
288269
return inserted_id_results[0].id
@@ -295,15 +276,15 @@ def get_loggedin_relations(self, cred_id=None, host_id=None):
295276
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
296277
if host_id:
297278
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
298-
return self.sess.execute(q).all()
279+
return self.db_execute(q).all()
299280

300281
def remove_loggedin_relations(self, cred_id=None, host_id=None):
301282
q = delete(self.LoggedinRelationsTable)
302283
if cred_id:
303284
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
304285
elif host_id:
305286
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
306-
self.sess.execute(q)
287+
self.db_execute(q)
307288

308289
def add_directory_listing(self, lir_id, data):
309290
pass

nxc/protocols/ldap/database.py

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
1-
from pathlib import Path
2-
from sqlalchemy.orm import sessionmaker, scoped_session
3-
from sqlalchemy import MetaData, Table
1+
import sys
2+
3+
from sqlalchemy import Table
44
from sqlalchemy.exc import (
5-
IllegalStateChangeError,
65
NoInspectionAvailable,
76
NoSuchTableError,
87
)
9-
from nxc.logger import nxc_logger
10-
import sys
8+
9+
from nxc.database import BaseDB
1110

1211

13-
class database:
12+
class database(BaseDB):
1413
def __init__(self, db_engine):
1514
self.CredentialsTable = None
1615
self.HostsTable = None
1716

18-
self.db_engine = db_engine
19-
self.db_path = self.db_engine.url.database
20-
self.protocol = Path(self.db_path).stem.upper()
21-
self.metadata = MetaData()
22-
self.reflect_tables()
23-
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
24-
25-
Session = scoped_session(session_factory)
26-
# this is still named "conn" when it is the session object; TODO: rename
27-
self.conn = Session()
17+
super().__init__(db_engine)
2818

2919
@staticmethod
3020
def db_schema(db_conn):
@@ -59,16 +49,3 @@ def reflect_tables(self):
5949
[-] Then remove the nxc {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
6050
)
6151
sys.exit()
62-
63-
def shutdown_db(self):
64-
try:
65-
self.conn.close()
66-
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
67-
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
68-
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
69-
except IllegalStateChangeError as e:
70-
nxc_logger.debug(f"Error while closing session db object: {e}")
71-
72-
def clear_database(self):
73-
for table in self.metadata.sorted_tables:
74-
self.conn.execute(table.delete())

0 commit comments

Comments
 (0)