11import sys
22
3- from sqlalchemy import Table
3+ from sqlalchemy import func , Table , select , delete
4+ from sqlalchemy .dialects .sqlite import Insert # used for upsert
45from sqlalchemy .exc import (
56 NoInspectionAvailable ,
67 NoSuchTableError ,
78)
89
9- from nxc .database import BaseDB
10+ from nxc .database import BaseDB , format_host_query
11+ from nxc .logger import nxc_logger
1012
1113
1214class database (BaseDB ):
1315 def __init__ (self , db_engine ):
14- self .CredentialsTable = None
16+ self .UsersTable = None
1517 self .HostsTable = None
1618
1719 super ().__init__ (db_engine )
1820
1921 @staticmethod
2022 def db_schema (db_conn ):
2123 db_conn .execute (
22- """CREATE TABLE "credentials " (
24+ """CREATE TABLE "users " (
2325 "id" integer PRIMARY KEY,
26+ "domain" text,
2427 "username" text,
25- "password" text
28+ "password" text,
29+ "credtype" text,
30+ "pillaged_from_hostid" integer,
31+ FOREIGN KEY(pillaged_from_hostid) REFERENCES hosts(id)
2632 )"""
2733 )
2834
@@ -31,14 +37,15 @@ def db_schema(db_conn):
3137 "id" integer PRIMARY KEY,
3238 "ip" text,
3339 "hostname" text,
34- "port" integer
40+ "domain" text,
41+ "os" text
3542 )"""
3643 )
3744
3845 def reflect_tables (self ):
3946 with self .db_engine .connect ():
4047 try :
41- self .CredentialsTable = Table ("credentials " , self .metadata , autoload_with = self .db_engine )
48+ self .UsersTable = Table ("users " , self .metadata , autoload_with = self .db_engine )
4249 self .HostsTable = Table ("hosts" , self .metadata , autoload_with = self .db_engine )
4350 except (NoInspectionAvailable , NoSuchTableError ):
4451 print (
@@ -49,3 +56,179 @@ def reflect_tables(self):
4956 [-] Then remove the nxc { self .protocol } DB (`rm -f { self .db_path } `) and run nxc to initialize the new DB"""
5057 )
5158 sys .exit ()
59+
60+ def add_host (self , ip , hostname , domain , os ):
61+ """Check if this host has already been added to the database, if not, add it in."""
62+ hosts = []
63+ updated_ids = []
64+
65+ q = select (self .HostsTable ).filter (self .HostsTable .c .ip == ip )
66+ results = self .db_execute (q ).all ()
67+
68+ # create new host
69+ if not results :
70+ new_host = {
71+ "ip" : ip ,
72+ "hostname" : hostname ,
73+ "domain" : domain ,
74+ "os" : os
75+ }
76+ hosts = [new_host ]
77+ # update existing hosts data
78+ else :
79+ for host in results :
80+ host_data = host ._asdict ()
81+ # only update column if it is being passed in
82+ if ip is not None :
83+ host_data ["ip" ] = ip
84+ if hostname is not None :
85+ host_data ["hostname" ] = hostname
86+ if domain is not None :
87+ host_data ["domain" ] = domain
88+ # only add host to be updated if it has changed
89+ if host_data not in hosts :
90+ hosts .append (host_data )
91+ updated_ids .append (host_data ["id" ])
92+ nxc_logger .debug (f"Update Hosts: { hosts } " )
93+
94+ # TODO: find a way to abstract this away to a single Upsert call
95+ q = Insert (self .HostsTable ) # .returning(self.HostsTable.c.id)
96+ update_columns = {col .name : col for col in q .excluded if col .name not in "id" }
97+ q = q .on_conflict_do_update (index_elements = self .HostsTable .primary_key , set_ = update_columns )
98+
99+ self .db_execute (q , hosts ) # .scalar()
100+ # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
101+ if updated_ids :
102+ nxc_logger .debug (f"add_host() - Host IDs Updated: { updated_ids } " )
103+ return updated_ids
104+
105+ def add_credential (self , credtype , domain , username , password , pillaged_from = None ):
106+ """Check if this credential has already been added to the database, if not add it in."""
107+ credentials = []
108+ groups = []
109+
110+ if pillaged_from and not self .is_host_valid (pillaged_from ):
111+ nxc_logger .debug ("Invalid host" )
112+ return
113+
114+ q = select (self .UsersTable ).filter (
115+ func .lower (self .UsersTable .c .domain ) == func .lower (domain ),
116+ func .lower (self .UsersTable .c .username ) == func .lower (username ),
117+ func .lower (self .UsersTable .c .credtype ) == func .lower (credtype ),
118+ )
119+ results = self .db_execute (q ).all ()
120+
121+ # add new credential
122+ if not results :
123+ new_cred = {
124+ "credtype" : credtype ,
125+ "domain" : domain ,
126+ "username" : username ,
127+ "password" : password ,
128+ "pillaged_from" : pillaged_from ,
129+ }
130+ credentials = [new_cred ]
131+ # update existing cred data
132+ else :
133+ for creds in results :
134+ # this will include the id, so we don't touch it
135+ cred_data = creds ._asdict ()
136+ # only update column if it is being passed in
137+ if credtype is not None :
138+ cred_data ["credtype" ] = credtype
139+ if domain is not None :
140+ cred_data ["domain" ] = domain
141+ if username is not None :
142+ cred_data ["username" ] = username
143+ if password is not None :
144+ cred_data ["password" ] = password
145+ if pillaged_from is not None :
146+ cred_data ["pillaged_from" ] = pillaged_from
147+ # only add cred to be updated if it has changed
148+ if cred_data not in credentials :
149+ credentials .append (cred_data )
150+
151+ # TODO: find a way to abstract this away to a single Upsert call
152+ q_users = Insert (self .UsersTable ) # .returning(self.UsersTable.c.id)
153+ update_columns_users = {col .name : col for col in q_users .excluded if col .name not in "id" }
154+ q_users = q_users .on_conflict_do_update (index_elements = self .UsersTable .primary_key , set_ = update_columns_users )
155+ nxc_logger .debug (f"Adding credentials: { credentials } " )
156+
157+ self .db_execute (q_users , credentials ) # .scalar()
158+
159+ if groups :
160+ q_groups = Insert (self .GroupRelationsTable )
161+
162+ self .db_execute (q_groups , groups )
163+
164+ def remove_credentials (self , creds_id ):
165+ """Removes a credential ID from the database"""
166+ del_hosts = []
167+ for cred_id in creds_id :
168+ q = delete (self .UsersTable ).filter (self .UsersTable .c .id == cred_id )
169+ del_hosts .append (q )
170+ self .db_execute (q )
171+
172+ def is_credential_valid (self , credential_id ):
173+ """Check if this credential ID is valid."""
174+ q = select (self .UsersTable ).filter (
175+ self .UsersTable .c .id == credential_id ,
176+ self .UsersTable .c .password is not None ,
177+ )
178+ results = self .db_execute (q ).all ()
179+ return len (results ) > 0
180+
181+ def get_credentials (self , filter_term = None , cred_type = None ):
182+ """Return credentials from the database."""
183+ # if we're returning a single credential by ID
184+ if self .is_credential_valid (filter_term ):
185+ q = select (self .UsersTable ).filter (self .UsersTable .c .id == filter_term )
186+ elif cred_type :
187+ q = select (self .UsersTable ).filter (self .UsersTable .c .credtype == cred_type )
188+ # if we're filtering by username
189+ elif filter_term and filter_term != "" :
190+ like_term = func .lower (f"%{ filter_term } %" )
191+ q = select (self .UsersTable ).filter (func .lower (self .UsersTable .c .username ).like (like_term ))
192+ # otherwise return all credentials
193+ else :
194+ q = select (self .UsersTable )
195+
196+ return self .db_execute (q ).all ()
197+
198+ def get_credential (self , cred_type , domain , username , password ):
199+ q = select (self .UsersTable ).filter (
200+ self .UsersTable .c .domain == domain ,
201+ self .UsersTable .c .username == username ,
202+ self .UsersTable .c .password == password ,
203+ self .UsersTable .c .credtype == cred_type ,
204+ )
205+ results = self .db_execute (q ).first ()
206+ return results .id
207+
208+ def get_hosts (self , filter_term = None , domain = None ):
209+ """Return hosts from the database."""
210+ q = select (self .HostsTable )
211+
212+ # if we're returning a single host by ID
213+ if self .is_host_valid (filter_term ):
214+ q = q .filter (self .HostsTable .c .id == filter_term )
215+ results = self .db_execute (q ).first ()
216+ # all() returns a list, so we keep the return format the same so consumers don't have to guess
217+ return [results ]
218+ elif filter_term is not None and filter_term .startswith ("domain" ):
219+ domain = filter_term .split ()[1 ]
220+ like_term = func .lower (f"%{ domain } %" )
221+ q = q .filter (self .HostsTable .c .domain .like (like_term ))
222+ # if we're filtering by ip/hostname
223+ elif filter_term and filter_term != "" :
224+ q = format_host_query (q , filter_term , self .HostsTable )
225+
226+ results = self .db_execute (q ).all ()
227+ nxc_logger .debug (f"ldap hosts() - results: { results } " )
228+ return results
229+
230+ def is_host_valid (self , host_id ):
231+ """Check if this host ID is valid."""
232+ q = select (self .HostsTable ).filter (self .HostsTable .c .id == host_id )
233+ results = self .db_execute (q ).all ()
234+ return len (results ) > 0
0 commit comments