[RomApp] Adding a database with SQLite #12313

Open
wants to merge 64 commits into
base: master
from

Member

@Rbravo555 Rbravo555 commented Apr 25, 2024

📝 Description
This PR adds a database for the ROM simulations using Python's sqlite3 library.

The RomDatabase is a property of the RomManager. The main change in behaviour is that all snapshots are now added to the database. Whenever a quantity has to be computed (the error, the SVD of the solutions, the SVD of the projected residuals, etc.), the database retrieves the required .npy files from the corresponding npy_files folder; once a file is loaded into memory, the calculation proceeds with the previously existing logic.

The files generated by the database are

  • a .db file that is read by the sqlite3 library
  • a folder containing a human-readable summary in Excel format; the complete database can also be exported there by calling an optional method
  • a folder containing the .npy files. The database itself stores a hash for each file, i.e. a key generated from the parameters used to produce it, so every snapshot, right basis, left basis and projected-residuals matrix has a unique hash. The RomDatabase class provides methods to retrieve each of them.
  • a folder containing the neural networks trained for the ANN-Enhanced PROM
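The hash-keyed lookup described in the list above can be sketched roughly as follows. This is only an illustration under stated assumptions: the helper names (`hash_parameters`, `add_entry`, `get_file_name`), the `entries` table, and the choice of SHA-256 over a JSON serialization are hypothetical and are not taken from rom_database.py.

```python
import hashlib
import json
import sqlite3


def hash_parameters(mu, table_name):
    """Build a deterministic key from the parameters that define an entry."""
    # Assumption: serialize the inputs to JSON and hash the resulting string.
    serialized = json.dumps({"mu": list(mu), "table": table_name}, sort_keys=True)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()


def add_entry(conn, mu, table_name):
    """Register an entry and return the hash that would name its .npy file."""
    file_name = hash_parameters(mu, table_name)
    with conn:  # commits on success, rolls back if the insert raises
        conn.execute(
            "INSERT INTO entries (table_name, parameters, file_name) VALUES (?, ?, ?)",
            (table_name, json.dumps(list(mu)), file_name),
        )
    return file_name


def get_file_name(conn, mu, table_name):
    """Return the stored hash for (mu, table_name), or None if it is absent."""
    row = conn.execute(
        "SELECT file_name FROM entries WHERE file_name = ?",
        (hash_parameters(mu, table_name),),
    ).fetchone()
    return row[0] if row else None


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE entries (table_name TEXT, parameters TEXT, file_name TEXT UNIQUE)"
)
stored = add_entry(conn, [1.0, 2.5], "FOM")
found = get_file_name(conn, [1.0, 2.5], "FOM")
absent = get_file_name(conn, [9.0, 9.0], "FOM")
```

With a key like this, the array itself would presumably live in the npy_files folder under the hash as its file name and be loaded with np.load once the lookup succeeds.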

🆕 Changelog

  • Added the rom_database.py script containing the RomDatabase class
  • Modified the methods Fit and Test in the RomManager to work with the new database
  • Added a test

if not have_tensorflow:
    err_msg = f'Tensorflow module not found. Please install Tensorflow in to use the "ann_enhanced" option.'
    raise Exception(err_msg)

rom_nn_trainer = RomNeuralNetworkTrainer(self.general_rom_manager_parameters)
model_name=self.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["online"]["model_name"].GetString()
Contributor

I think this system for testing the NN won't work very well anymore, as the networks are saved by their hash now. Maybe we could integrate this testing routine into the general Test() method and just evaluate the network that coincides with the given parameters, if it exists.

Member Author

It was changed. However, to know the name of the model one needs both mu_train and mu_test.

Currently this method is not called in the RomManager when doing Fit or Test. Where would it be called?

analysis_stage_class = self._GetAnalysisStageClass(parameters_copy)
simulation = self.CustomizeSimulation(analysis_stage_class,model,parameters_copy)
simulation.Run()
self.QoI_Fit_FOM.append(simulation.GetFinalData())
Contributor

I'm not sure how this QoI_Fit_FOM is used, but here it only collects results for the snapshots that were not already in the database.

Member Author

The QoIs are now stored in the database as well.

phi = np.load(database_settings['phi_matrix'].GetString())
sigma_vec = np.load(database_settings['sigma_vector'].GetString())/np.sqrt(S_train.shape[1])
S_train = self.data_base.get_snapshots_matrix_from_database(self.mu_train, table_name=f'FOM_Fit')
S_val = self.data_base.get_snapshots_matrix_from_database(self.mu_validation, table_name=f'FOM_Fit')
Contributor

@NicolasSR NicolasSR May 14, 2024

Maybe it would be good to throw an exception if not all the snapshots for the datasets were found. For the training set we later check for the RightBasis, so that will fail anyway, but we don't have an equivalent check for the validation set.

else:
print(f"No entry found for hash {hash_mu}")
conn.close()
return np.block(SnapshotsMatrix) if SnapshotsMatrix else None
Contributor

Maybe it should also return a bool stating whether any mu values weren't found in the database, so that one may alert the user that the matrix isn't complete and ask them to simulate the remaining ones.

Member Author

Added a warning, but I could not trigger that behaviour. The cases are checked to exist in both Fit() and Test() in the RomManager. It will serve for debugging when adding more tables.
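For reference, the idea discussed in this thread (telling the caller which mu values are missing, and optionally failing hard) might look like the sketch below. The function names `collect_snapshot_files` and `require_complete`, the `FOM` table layout, and the use of `str(mu)` as the lookup key are assumptions made for illustration; they are not the PR's actual implementation.

```python
import sqlite3
import warnings


def collect_snapshot_files(conn, mu_list):
    """Return (found_file_names, missing_mu) for the requested parameters."""
    found, missing = [], []
    for mu in mu_list:
        row = conn.execute(
            "SELECT file_name FROM FOM WHERE parameters = ?", (str(mu),)
        ).fetchone()
        if row:
            found.append(row[0])
        else:
            missing.append(mu)
    if missing:
        # Warn instead of failing, so the caller decides what to do next.
        warnings.warn(f"{len(missing)} parameter set(s) not in the database: {missing}")
    return found, missing


def require_complete(conn, mu_list):
    """Stricter variant: raise if any requested snapshot is missing."""
    found, missing = collect_snapshot_files(conn, mu_list)
    if missing:
        raise ValueError(f"Snapshots missing for mu values: {missing}")
    return found


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE FOM (parameters TEXT, file_name TEXT)")
conn.execute("INSERT INTO FOM VALUES (?, ?)", (str([1.0, 2.0]), "hash_a"))

found, missing = collect_snapshot_files(conn, [[1.0, 2.0], [3.0, 4.0]])
```

Returning the missing values rather than a bare flag lets the caller either simulate the remaining cases or abort, for example before building the validation matrix.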

Member

@roigcarlo roigcarlo left a comment

Just some comments.

Also, if possible, add a short docstring (a couple of lines) to the functions that are missing one (especially in the RomDatabase class).

Comment on lines +81 to +89
self.table_names = table_definitions.keys()
for table_name, table_sql in table_definitions.items():
    try:
        cursor.execute(table_sql)
        conn.commit()
        print(f"Table {table_name} created successfully.")
    except sqlite3.OperationalError as e:
        print(f"Error creating table {table_name}: {e}")
conn.close()
Member

@roigcarlo roigcarlo May 29, 2024

A general comment, not specifically about this piece of code.

Would it be possible to make the commit after all the queries have been executed? If so, it may be a good idea to use a context manager to automatically commit the queries. Something like:

try:
    with conn:
        for table_name, table_sql in table_definitions.items():
            cursor.execute(table_sql)
            print(f"Table {table_name} created successfully.")
    # Transactions are automatically committed here at the end of the connection context.
except sqlite3.OperationalError as e:
    print(f"Error creating tables: {e}")
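A self-contained version of this pattern is sketched below, using an in-memory database and a made-up `FOM` table so that it can be run on its own. It illustrates two documented properties of the sqlite3 connection context manager: the transaction is committed if the block finishes and rolled back if it raises, and the connection itself is not closed by `with conn:` (hence the `contextlib.closing` wrapper).

```python
import sqlite3
from contextlib import closing

insert_sql = "INSERT INTO FOM (parameters, file_name) VALUES (?, ?)"

# closing() closes the connection at the end; `with conn:` only ends a transaction.
with closing(sqlite3.connect(":memory:")) as conn:
    conn.execute("CREATE TABLE FOM (parameters TEXT, file_name TEXT)")

    try:
        with conn:
            conn.execute(insert_sql, ("[1.0]", "hash_a"))
            # This statement fails, so the insert above is rolled back too.
            conn.execute("INSERT INTO no_such_table VALUES (1)")
    except sqlite3.OperationalError as e:
        print(f"Transaction rolled back: {e}")
    rows_after_failure = conn.execute("SELECT COUNT(*) FROM FOM").fetchone()[0]

    with conn:
        conn.execute(insert_sql, ("[1.0]", "hash_a"))
        conn.execute(insert_sql, ("[2.0]", "hash_b"))
    # Both inserts are committed together when the block exits normally.
    rows_after_success = conn.execute("SELECT COUNT(*) FROM FOM").fetchone()[0]
```

Grouping the statements this way means a failure part-way through leaves the table unchanged rather than half-populated.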

cursor = conn.cursor()

if table_name == 'FOM':
    cursor.execute(f'INSERT INTO {table_name} (parameters, file_name) VALUES (?, ?)',
Member

It would be better to pass the table name as an extra parameter instead of mixing placeholders and an f-string.
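One caveat on this suggestion: sqlite3 `?` placeholders bind values only, so an identifier such as a table name cannot be passed that way. A possible middle ground, sketched here with a hypothetical `ALLOWED_TABLES` set and `insert_entry` helper (neither is part of the PR), is to validate the name against a fixed list before interpolating it, while still binding the values.

```python
import sqlite3

# Hypothetical whitelist; the real set of tables is defined by the RomDatabase.
ALLOWED_TABLES = {"FOM", "ROM", "HROM"}


def insert_entry(conn, table_name, parameters, file_name):
    """Insert a row, interpolating only a table name that passed validation."""
    if table_name not in ALLOWED_TABLES:
        raise ValueError(f"Unknown table name: {table_name!r}")
    with conn:
        conn.execute(
            # The identifier is interpolated; the values are still bound with ?.
            f'INSERT INTO "{table_name}" (parameters, file_name) VALUES (?, ?)',
            (parameters, file_name),
        )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE FOM (parameters TEXT, file_name TEXT)")
insert_entry(conn, "FOM", "[1.0, 2.0]", "hash_a")
rows = conn.execute("SELECT parameters, file_name FROM FOM").fetchall()
```

This keeps the f-string but removes the risk of an arbitrary string reaching the SQL text.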

conn.commit()
conn.close()

if table_name == "Neural_Network" or table_name == 'QoI_FOM' or table_name == 'QoI_ROM' or table_name == 'QoI_HROM':
Member

Suggested change
if table_name == "Neural_Network" or table_name == 'QoI_FOM' or table_name == 'QoI_ROM' or table_name == 'QoI_HROM':
if table_name in ['Neural_Network', 'QoI_FOM', 'QoI_ROM', 'QoI_HROM']:

@@ -2,6 +2,8 @@
import importlib
import json
from pathlib import Path
import types
from KratosMultiphysics.RomApplication.rom_database import RomDatabase
Member

Place the `from` imports after importing Kratos (or remove line 8).

@@ -0,0 +1,394 @@
import KratosMultiphysics
Member

Not sure whether this import is overkill if it is only for the logger, as the normal prints are not using it anyway.

Comment on lines +253 to +254
setattr(self, f'ROMvsHROM_{case}', error_rom_hrom)
setattr(self, f'FOMvsHROM_{case}', error_fom_hrom)
Member

Not a big fan of using setattr here; why not use a dictionary instead?

self.ROMvsHROM['Fit']

for example?
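A minimal sketch of the dictionary-based alternative follows. The `ErrorSummary` class and its `store` method are hypothetical stand-ins for the relevant part of the RomManager; only the attribute names `ROMvsHROM` and `FOMvsHROM` come from the diff above.

```python
class ErrorSummary:
    """Hypothetical stand-in for the part of RomManager that stores errors."""

    def __init__(self):
        # One dictionary per comparison, keyed by case name ("Fit", "Test", ...).
        self.ROMvsHROM = {}
        self.FOMvsHROM = {}

    def store(self, case, error_rom_hrom, error_fom_hrom):
        """Record both errors for a case instead of creating new attributes."""
        self.ROMvsHROM[case] = error_rom_hrom
        self.FOMvsHROM[case] = error_fom_hrom


summary = ErrorSummary()
summary.store("Fit", 1e-6, 2e-4)
summary.store("Test", 3e-6, 5e-4)
fit_error = summary.ROMvsHROM["Fit"]
cases = sorted(summary.FOMvsHROM)
```

Compared with setattr, the available cases can be listed by iterating over the dictionary, and a missing case raises a KeyError naming the key.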

Projects
ROM Application Project
  
In progress
7 participants