[RomApp] Adding a database with SQLite #12313
base: master
Conversation
We need to update the RomParameter.json without expensive computations if the data is already in the database.
This is equivalent to running a single simulation as described by the ProjectParameters.json.
if not have_tensorflow:
    err_msg = 'TensorFlow module not found. Please install TensorFlow to use the "ann_enhanced" option.'
    raise Exception(err_msg)

rom_nn_trainer = RomNeuralNetworkTrainer(self.general_rom_manager_parameters)
model_name = self.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["online"]["model_name"].GetString()
I think this system for testing the NN won't work very well anymore, as the networks are saved by their hash now. Maybe we could integrate this testing routine into the general Test() method and just evaluate the network that coincides with the given parameters, if it exists.
It was changed; however, to know the name of the model one needs both mu_train and mu_test.
Currently this method is not called in the RomManager when doing Fit or Test. Where would it be called?
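The naming scheme under discussion can be illustrated with a small sketch. This is a hypothetical helper, not the actual RomApp API: it only shows how a deterministic model name could be derived from both mu_train and mu_test, which is why both sets are needed to locate a saved network.

```python
import hashlib
import json

def get_model_name(mu_train, mu_test):
    # Hypothetical helper: build a deterministic model name from a hash
    # of the training and testing parameter sets, mirroring the
    # hash-based naming described in the thread.
    payload = json.dumps({"train": mu_train, "test": mu_test}, sort_keys=True)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]
    return f"rom_nn_model_{digest}"
```

Because the digest depends on both parameter sets, a Test() routine could recompute the name from the given parameters and evaluate the matching network only if it exists on disk.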
analysis_stage_class = self._GetAnalysisStageClass(parameters_copy)
simulation = self.CustomizeSimulation(analysis_stage_class, model, parameters_copy)
simulation.Run()
self.QoI_Fit_FOM.append(simulation.GetFinalData())
I'm not sure how this QoI_Fit_FOM is used, but here it only gets results for the snapshots that were not already in the database.
Now the QoIs are stored in the database as well.
phi = np.load(database_settings['phi_matrix'].GetString())
sigma_vec = np.load(database_settings['sigma_vector'].GetString())/np.sqrt(S_train.shape[1])
S_train = self.data_base.get_snapshots_matrix_from_database(self.mu_train, table_name=f'FOM_Fit')
S_val = self.data_base.get_snapshots_matrix_from_database(self.mu_validation, table_name=f'FOM_Fit')
Maybe it would be good to throw an exception if not all the snapshots for the datasets were found. In the case of the training set we later check for the RightBasis, so that will fail anyway. But we don't have such a check for the validation set.
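A minimal sketch of the guard being suggested. The function name and signature are hypothetical, not part of the RomApp; it just shows failing loudly when the database returns fewer snapshot columns than requested parameter points, so an incomplete validation set doesn't slip through silently.

```python
import numpy as np

def check_snapshots_complete(snapshots, mu_list, set_name):
    # Hypothetical guard: raise if fewer snapshot columns were retrieved
    # from the database than there are requested parameter points.
    if snapshots is None or snapshots.shape[1] < len(mu_list):
        found = 0 if snapshots is None else snapshots.shape[1]
        raise ValueError(
            f"Only {found} of {len(mu_list)} snapshots found in the "
            f"database for the {set_name} set."
        )
```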
else:
    print(f"No entry found for hash {hash_mu}")
conn.close()
return np.block(SnapshotsMatrix) if SnapshotsMatrix else None
Maybe it should also return the mu values that weren't found in the database, so that one may alert the user that the matrix isn't complete and ask them to simulate the remaining ones.
Added a warning, but could not trigger that behaviour. The cases are checked to exist in both Fit() and Test() in the RomManager. It will serve for debugging when adding more tables.
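The reporting pattern being discussed could look roughly like this. Everything here is an illustrative sketch: `database.load_snapshot` is an assumed interface, not the real RomDatabase method, and only the shape of the idea (assemble what was found, return what was missing) is taken from the thread.

```python
import numpy as np

def collect_snapshots(database, mu_list):
    # Hypothetical sketch: gather per-parameter snapshots and report
    # which mu values were missing, so the caller can warn the user and
    # launch the remaining simulations. `database.load_snapshot` is an
    # assumed interface for illustration only.
    blocks, missing = [], []
    for mu in mu_list:
        snapshot = database.load_snapshot(mu)
        if snapshot is None:
            missing.append(mu)
        else:
            blocks.append(snapshot)
    matrix = np.block(blocks) if blocks else None
    return matrix, missing
```

Returning the list of missing mu values (instead of a bare bool) also makes the debugging use case easier when new tables are added.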
Co-authored-by: NicolasSR <54904691+NicolasSR@users.noreply.github.com>
Just some comments.
Also, if possible, write a short doc (a couple of lines) for the functions that are missing one (especially in the RomDatabase class).
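As an example of the short docstrings being requested, a sketch for one of the RomDatabase methods that appears in the diff above. The wording of the docstring is hypothetical; only the function name and `table_name` keyword are taken from the snippets in this thread.

```python
def get_snapshots_matrix_from_database(self, mu_list, table_name='FOM'):
    """Return the snapshots matrix assembled from the stored .npy files.

    mu_list    -- list of parameter points whose snapshots are requested.
    table_name -- database table to query (e.g. 'FOM_Fit').

    Returns None if no snapshot was found for any of the points.
    """
```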
self.table_names = table_definitions.keys()
for table_name, table_sql in table_definitions.items():
    try:
        cursor.execute(table_sql)
        conn.commit()
        print(f"Table {table_name} created successfully.")
    except sqlite3.OperationalError as e:
        print(f"Error creating table {table_name}: {e}")
conn.close()
Comment in general, not specifically about this piece of code.
Would it be possible to make the commit after all the queries have been executed? Also, if that is the case, maybe it is a good idea to use a context manager to automatically commit the queries. Something like:

try:
    with conn:
        for table_name, table_sql in table_definitions.items():
            cursor.execute(table_sql)
            print(f"Table {table_name} created successfully.")
    # Transactions are automatically committed here at the end of the connection context.
except sqlite3.OperationalError as e:
    print(f"Error creating tables: {e}")
cursor = conn.cursor()

if table_name == 'FOM':
    cursor.execute(f'INSERT INTO {table_name} (parameters, file_name) VALUES (?, ?)',
It would be better if you could pass the table name as an extra token instead of mixing tokens and f-strings.
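One caveat worth noting: the `sqlite3` module can only bind values with `?` tokens, not identifiers such as table names. A common compromise is to validate the name against a fixed allowlist and only then interpolate it, keeping the values themselves parameterized. The allowlist contents below are an assumption for illustration.

```python
import sqlite3

ALLOWED_TABLES = {"FOM", "ROM", "HROM", "Neural_Network"}  # assumed set of table names

def insert_entry(conn, table_name, parameters, file_name):
    # SQLite cannot bind identifiers (table names) as ? tokens, so
    # validate the name against an allowlist before interpolating it;
    # the row values stay parameterized as usual.
    if table_name not in ALLOWED_TABLES:
        raise ValueError(f"Unknown table: {table_name}")
    conn.execute(
        f"INSERT INTO {table_name} (parameters, file_name) VALUES (?, ?)",
        (parameters, file_name),
    )
```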
conn.commit()
conn.close()

if table_name == "Neural_Network" or table_name == 'QoI_FOM' or table_name == 'QoI_ROM' or table_name == 'QoI_HROM':
if table_name == "Neural_Network" or table_name == 'QoI_FOM' or table_name == 'QoI_ROM' or table_name == 'QoI_HROM':
if table_name in ['Neural_Network', 'QoI_FOM', 'QoI_ROM', 'QoI_HROM']:
@@ -2,6 +2,8 @@
import importlib
import json
from pathlib import Path
import types
from KratosMultiphysics.RomApplication.rom_database import RomDatabase
Move the `from` imports after the Kratos import (or remove line 8).
@@ -0,0 +1,394 @@
import KratosMultiphysics
Not sure if this is overkill just for the logger, since the normal prints are not using it anyway.
setattr(self, f'ROMvsHROM_{case}', error_rom_hrom)
setattr(self, f'FOMvsHROM_{case}', error_fom_hrom)
Not a super fan of using setattr here; why not use a dictionary instead?
self.ROMvsHROM['Fit']
for example?
📝 Description
This PR adds a database for the ROM simulations using the sqlite3 Python library.
The RomDatabase is a property of the RomManager. The main change in behaviour is that all snapshots are now added to the database. Whenever the error is to be computed, the SVD of the solutions is to be calculated, the projected SVD residuals are to be calculated, etc., the database retrieves the required .npy files from the respective npy_files folder, and once an .npy file is loaded into fast memory the calculation proceeds with the previously existing logic.
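The retrieval flow described above can be sketched in a few lines. This is an illustration under stated assumptions, not the RomDatabase implementation: the function name, the exact hashing of the parameters, and the one-file-per-hash layout are hypothetical; only the overall idea (hash the parameters, locate the stored .npy file, load it lazily) comes from the description.

```python
import hashlib
import json
from pathlib import Path

import numpy as np

def load_snapshot_from_disk(npy_dir, mu):
    # Hypothetical illustration of the described flow: each simulation is
    # keyed by a hash of its parameters, and only the path to an .npy
    # file is stored; the matrix itself is loaded into memory on demand.
    hash_mu = hashlib.sha256(json.dumps(mu, sort_keys=True).encode()).hexdigest()
    path = Path(npy_dir) / f"{hash_mu}.npy"
    return np.load(path) if path.exists() else None
```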
The files generated by the database are
🆕 Changelog