|
1 | | -import sys |
2 | 1 | import configparser |
3 | 2 | import shutil |
4 | | -from sqlalchemy import create_engine |
5 | | -from sqlite3 import connect |
| 3 | +import sys |
6 | 4 | from os import mkdir |
7 | 5 | from os.path import exists |
8 | 6 | from os.path import join as path_join |
| 7 | +from pathlib import Path |
| 8 | +from sqlite3 import connect |
| 9 | +from threading import Lock |
| 10 | + |
| 11 | +from sqlalchemy import create_engine, MetaData |
| 12 | +from sqlalchemy.exc import IllegalStateChangeError |
| 13 | +from sqlalchemy.orm import sessionmaker, scoped_session |
9 | 14 |
|
10 | 15 | from nxc.loaders.protocolloader import ProtocolLoader |
| 16 | +from nxc.logger import nxc_logger |
11 | 17 | from nxc.paths import WORKSPACE_DIR |
12 | 18 |
|
13 | 19 |
|
@@ -103,3 +109,39 @@ def initialize_db(): |
103 | 109 |
|
104 | 110 | # Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol) |
105 | 111 | init_protocol_dbs("default") |
| 112 | + |
| 113 | + |
| 114 | +class BaseDB: |
| 115 | + def __init__(self, db_engine): |
| 116 | + self.db_engine = db_engine |
| 117 | + self.db_path = self.db_engine.url.database |
| 118 | + self.protocol = Path(self.db_path).stem.upper() |
| 119 | + self.metadata = MetaData() |
| 120 | + self.reflect_tables() |
| 121 | + session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) |
| 122 | + |
| 123 | + session = scoped_session(session_factory) |
| 124 | + self.sess = session() |
| 125 | + self.lock = Lock() |
| 126 | + |
| 127 | + def reflect_tables(self): |
| 128 | + raise NotImplementedError("Reflect tables not implemented") |
| 129 | + |
| 130 | + def shutdown_db(self): |
| 131 | + try: |
| 132 | + self.sess.close() |
| 133 | + # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: |
| 134 | + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and |
| 135 | + # this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5> |
| 136 | + except IllegalStateChangeError as e: |
| 137 | + nxc_logger.debug(f"Error while closing session db object: {e}") |
| 138 | + |
| 139 | + def clear_database(self): |
| 140 | + for table in self.metadata.sorted_tables: |
| 141 | + self.db_execute(table.delete()) |
| 142 | + |
| 143 | + def db_execute(self, *args): |
| 144 | + self.lock.acquire() |
| 145 | + res = self.sess.execute(*args) |
| 146 | + self.lock.release() |
| 147 | + return res |
0 commit comments