diff --git a/msnoise/api.py b/msnoise/api.py
index aba1f49..75e1736 100644
--- a/msnoise/api.py
+++ b/msnoise/api.py
@@ -1511,7 +1511,8 @@ def get_mwcs(session, station1, station2, filterid, components, date,
         return pd.DataFrame()
 
 
-def get_results_all(session, station1, station2, filterid, components, dates):
+def get_results_all(session, station1, station2, filterid, components, dates,
+                    format="dataframe"):
     """
     :type session: :class:`sqlalchemy.orm.session.Session`
     :param session: A :class:`~sqlalchemy.orm.session.Session` object, as
@@ -1544,7 +1545,15 @@ def get_results_all(session, station1, station2, filterid, components, dates):
     if len(results):
         result = pd.concat(results)
         del results
-        return result
+        if format == "dataframe":
+            return result
+        elif format == "xarray":
+            taxis = get_t_axis(session)
+            times = result.index
+            dr = xr.DataArray(result, coords=[times, taxis],
+                              dims=["times", "taxis"]).dropna("times", how="all")
+            dr.name = "CCF"
+            return dr.to_dataset()
     else:
         return pd.DataFrame()
 
@@ -2292,6 +2301,19 @@ def xr_get_mwcs(station1, station2, components, filterid, mov_stack):
 
 
 def xr_save_dtt(station1, station2, components, filterid, mov_stack, dataframe):
+    """
+    :param station1: string, name of station 1
+    :param station2: string, name of station 2
+    :param components: string, name of the components
+    :param filterid: int, filter id
+    :param mov_stack: int, number of days in the moving stack
+    :param dataframe: pandas DataFrame containing the data
+    :return: None
+
+    This method saves the given DataFrame to a NetCDF file whose path is built from the station names, components, filter id and moving stack length.
+    The DataFrame is stacked, its index levels are named "times" and "keys", and its single column is renamed "DTT".
+    The stacked data is then inserted into (or updated in) a new or existing NetCDF file, which is saved and closed.
+    """
     fn = os.path.join("DTT2", "%02i" % filterid,
                       "%03i_DAYS" % mov_stack,
                       "%s" % components,
@@ -2308,6 +2330,18 @@
 
 
 def xr_get_dtt(station1, station2, components, filterid, mov_stack):
+    """
+    :param station1: The first station name
+    :param station2: The second station name
+    :param components: The components to be used
+    :param filterid: The filter ID
+    :param mov_stack: The number of days in the moving stack
+    :return: The extracted DTT data as a pandas DataFrame
+
+    This method retrieves the DTT data from a NetCDF file based on the given inputs. It constructs the file path
+    from the provided parameters and checks that the file exists; if it does not, a FileNotFoundError is raised.
+    Otherwise, the NetCDF file is opened, the DTT variable is extracted as a DataFrame, rearranged and returned.
+    """
     fn = os.path.join("DTT2", "%02i" % filterid,
                       "%03i_DAYS" % mov_stack,
                       "%s" % components,
@@ -2360,6 +2394,15 @@ def xr_get_dvv(components, filterid, mov_stack):
 
 
 def wavg(group, dttname, errname):
+    """
+    Calculate the weighted average of a given group using the provided parameters.
+
+    :param group: A pandas DataFrame or Series representing the group of data.
+    :param dttname: The name of the column containing the data to be averaged.
+    :param errname: The name of the column containing the error for each data point.
+    :return: The weighted average of the data.
+
+    """
     d = group[dttname]
     group[errname][group[errname] == 0] = 1e-6
     w = 1. / group[errname]
@@ -2371,6 +2414,25 @@
 
 
 def wstd(group, dttname, errname):
+    """
+    :param group: A pandas DataFrame containing the data for one group.
+    :param dttname: The name of the column in `group` containing the data array.
+    :param errname: The name of the column in `group` containing the error array.
+    :return: The weighted standard deviation of the data array.
+
+    This method calculates the weighted standard deviation of the data column specified by `dttname` in `group`.
+    The weights are derived from the error column specified by `errname` in `group`.
+
+    The weighted standard deviation is computed using the following formula:
+    wstd = sqrt(sum(w * (d - wavg) ** 2) / ((N - 1) * sum(w) / N))
+    where:
+    - d is the data column specified by `dttname`.
+    - w is the weight array derived from the error column specified by `errname` (w = 1 / err).
+    - wavg is the weighted average of the data.
+    - N is the number of non-zero weights.
+
+    Note: errors equal to zero are replaced by 1e-6 to avoid division by zero.
+    """
     d = group[dttname]
     group[errname][group[errname] == 0] = 1e-6
     w = 1. / group[errname]
@@ -2381,13 +2443,36 @@
 
 
 def get_wavgwstd(data, dttname, errname):
+    """
+    Calculate the weighted average and weighted standard deviation of the data, grouped by the first index level.
+
+    :param data: The data to calculate the weighted average and weighted standard deviation of.
+    :type data: pandas.DataFrame
+
+    :param dttname: The name of the column containing the data to be averaged.
+    :type dttname: str
+
+    :param errname: The name of the column containing the errors on the data, used to derive the weights.
+    :type errname: str
+
+    :return: A tuple containing the weighted average and the weighted standard deviation.
+    :rtype: tuple
+    """
     grouped = data.groupby(level=0)
     g = grouped.apply(wavg, dttname=dttname, errname=errname)
    h = grouped.apply(wstd, dttname=dttname, errname=errname)
     return g, h
 
 
-def trim(data, dttname, errname, limits=0.1):
+def trim(data, dttname, limits=0.1):
+    """
+    Trimmed mean and trimmed standard deviation calculation, grouped by the first index level.
+
+    :param data: DataFrame containing the data.
+    :param dttname: Name of the column containing the data to be trimmed.
+    :param limits: Trimming proportion to cut on each tail (default is 0.1).
+    :return: Tuple containing the trimmed mean and trimmed standard deviation.
+    """
     from scipy.stats.mstats import trimmed_mean, trimmed_std
     grouped = data[dttname].groupby(level=0)
     if limits == 0: