preparing the way to make use of subdaily CCFs in the standard workflow
ThomasLecocq committed Feb 23, 2024
1 parent 2bffce5 commit 149a781
Showing 1 changed file with 88 additions and 3 deletions.
91 changes: 88 additions & 3 deletions msnoise/api.py
@@ -1511,7 +1511,8 @@ def get_mwcs(session, station1, station2, filterid, components, date,
return pd.DataFrame()


def get_results_all(session, station1, station2, filterid, components, dates):
def get_results_all(session, station1, station2, filterid, components, dates,
format="dataframe"):
"""
:type session: :class:`sqlalchemy.orm.session.Session`
:param session: A :class:`~sqlalchemy.orm.session.Session` object, as
@@ -1544,7 +1545,15 @@ def get_results_all(session, station1, station2, filterid, components, dates):
if len(results):
result = pd.concat(results)
del results
return result
if format == "dataframe":
return result
elif format == "xarray":
taxis = get_t_axis(session)
times = result.index
dr = xr.DataArray(result, coords=[times, taxis],
dims=["times", "taxis"]).dropna("times", how="all")
dr.name = "CCF"
return dr.to_dataset()
else:
return pd.DataFrame()
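
A minimal usage sketch of the new keyword (the station codes, filter id and date list below are placeholders; the default keeps the previous DataFrame behaviour):

from msnoise.api import connect, get_results_all

db = connect()
dates = ["2024-02-01", "2024-02-02"]  # hypothetical date list
# default: a pandas DataFrame, as before
df = get_results_all(db, "BE.MEM", "BE.UCC", 1, "ZZ", dates)
# new: an xarray Dataset holding a "CCF" variable over ("times", "taxis")
ds = get_results_all(db, "BE.MEM", "BE.UCC", 1, "ZZ", dates, format="xarray")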

@@ -2292,6 +2301,19 @@ def xr_get_mwcs(station1, station2, components, filterid, mov_stack):


def xr_save_dtt(station1, station2, components, filterid, mov_stack, dataframe):
"""
:param station1: string, name of station 1
:param station2: string, name of station 2
:param components: string, name of the components
:param filterid: int, filter id
:param mov_stack: int, number of days in the moving stack
:param dataframe: pandas DataFrame containing the data
:return: None
This method saves the given data in a NetCDF file using the specified parameters. The file path is constructed based on the station names, components, filter id, and moving stack number
*. The data in the DataFrame is stacked, and the index is set to include "times" and "keys" as names. The column in the DataFrame is renamed to "DTT". A new or existing NetCDF file is
* opened using the given file path, and the stacked data is inserted or updated in the file. The resulting dataset is then saved and the file is closed.
"""
fn = os.path.join("DTT2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
@@ -2308,6 +2330,18 @@ def xr_save_dtt(station1, station2, components, filterid, mov_stack, dataframe):


def xr_get_dtt(station1, station2, components, filterid, mov_stack):
"""
:param station1: The first station name
:param station2: The second station name
:param components: The components to be used
:param filterid: The filter ID
:param mov_stack: The movement stack
:return: The extracted data
This method retrieves the DTT data from a NetCDF file based on the given inputs. It constructs the file path using the provided parameters and checks if the file exists. If the file
* does not exist, it raises a FileNotFoundError. Otherwise, it opens the NetCDF file and extracts the DTT variable as a dataframe. The dataframe is then rearranged and returned as the
* result.
"""
fn = os.path.join("DTT2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
@@ -2360,6 +2394,15 @@ def xr_get_dvv(components, filterid, mov_stack):


def wavg(group, dttname, errname):
"""
Calculate the error-weighted average of a column within a group.

:param group: A pandas DataFrame (or group thereof) containing the data.
:param dttname: The name of the column containing the data to be averaged.
:param errname: The name of the column containing the error for each data point
    (weights are 1/error, with zero errors floored at 1e-6).
:return: The weighted average of the data.
"""
d = group[dttname]
group[errname][group[errname] == 0] = 1e-6
w = 1. / group[errname]
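
The rest of wavg is collapsed in this view; a stand-alone sketch of the same weighting (zero errors floored at 1e-6, weights = 1/error, then a standard weighted mean), with hypothetical column names:

import numpy as np
import pandas as pd

g = pd.DataFrame({"m": [1.0, 2.0, 3.0], "em": [0.1, 0.2, 0.0]})
g.loc[g["em"] == 0, "em"] = 1e-6      # same floor as above, written with .loc
w = 1.0 / g["em"]
print(np.average(g["m"], weights=w))  # dominated by the third, zero-error point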
@@ -2371,6 +2414,25 @@ def wavg(group, dttname, errname):


def wstd(group, dttname, errname):
"""
:param group: A dictionary containing data for different groups.
:param dttname: The key in the `group` dictionary that corresponds to the data array.
:param errname: The key in the `group` dictionary that corresponds to the error array.
:return: The weighted standard deviation of the data array.
This method calculates the weighted standard deviation of the data array specified by `dttname` in the `group` dictionary.
The weights are derived from the error array specified by `errname` in the `group` dictionary.
The weighted standard deviation is computed using the following formula:
wstd = sqrt(sum(w * (d - wavg) ** 2) / ((N - 1) * sum(w) / N))
where:
- d is the data array specified by `dttname` in the `group` dictionary.
- w is the weight array derived from the error array specified by `errname` in the `group` dictionary.
- wavg is the weighted average of the data array.
- N is the number of non-zero weights.
Note: This method uses the `np` module from NumPy.
"""
d = group[dttname]
group[errname][group[errname] == 0] = 1e-6
w = 1. / group[errname]
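
The remainder of wstd is likewise collapsed; for reference, a direct transcription of the docstring formula:

import numpy as np

d = np.array([1.0, 2.0, 3.0])
err = np.array([0.1, 0.2, 0.5])
w = 1.0 / err
N = np.count_nonzero(w)
wavg = np.sum(d * w) / np.sum(w)
wstd = np.sqrt(np.sum(w * (d - wavg) ** 2) / ((N - 1) * np.sum(w) / N))
print(wavg, wstd)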
@@ -2381,13 +2443,36 @@ def wstd(group, dttname, errname):


def get_wavgwstd(data, dttname, errname):
"""
Calculate the weighted average and weighted standard deviation for a given data.
:param data: The data to calculate the weighted average and weighted standard deviation.
:type data: pandas.DataFrame
:param dttname: The name of the column in the data frame containing the weights for the weighted average and weighted standard deviation calculation.
:type dttname: str
:param errname: The name of the column in the data frame containing the errors on the data.
:type errname: str
:return: A tuple containing the calculated weighted average and weighted standard deviation.
:rtype: tuple
"""
grouped = data.groupby(level=0)
g = grouped.apply(wavg, dttname=dttname, errname=errname)
h = grouped.apply(wstd, dttname=dttname, errname=errname)
return g, h
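
A small sketch of the expected input shape (the first index level is the grouping key, e.g. the date; the column names are hypothetical):

import pandas as pd
from msnoise.api import get_wavgwstd

idx = pd.MultiIndex.from_product([["2024-02-01", "2024-02-02"], ["f1", "f2"]])
data = pd.DataFrame({"m":  [1.0, 2.0, 3.0, 4.0],
                     "em": [0.1, 0.2, 0.1, 0.2]}, index=idx)
means, stds = get_wavgwstd(data, dttname="m", errname="em")
# means and stds each hold one value per date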


def trim(data, dttname, errname, limits=0.1):
def trim(data, dttname, limits=0.1):
"""
Trimmed mean and standard deviation calculation.
:param data: DataFrame containing the data.
:param dttname: Name of the column used for grouping.
:param limits: Trimming limits (default is 0.1).
:return: Tuple containing the trimmed mean and trimmed standard deviation.
"""
from scipy.stats.mstats import trimmed_mean, trimmed_std
grouped = data[dttname].groupby(level=0)
if limits == 0:
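The body of trim is collapsed above; the underlying scipy helpers behave as sketched below (the exact limit handling inside trim is not shown here):

import numpy as np
from scipy.stats.mstats import trimmed_mean, trimmed_std

values = np.array([0.9, 1.0, 1.0, 1.1, 5.0])    # one group's values, with an outlier
print(trimmed_mean(values, limits=(0.1, 0.1)))  # cut 10% from each tail
print(trimmed_std(values, limits=(0.1, 0.1)))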
