massive XArray workflow transform
/!\ TESTS WON'T PASS !
ThomasLecocq committed May 25, 2023
1 parent 450e4e6 commit 5ffd050
Showing 9 changed files with 398 additions and 380 deletions.
244 changes: 241 additions & 3 deletions msnoise/api.py
@@ -846,8 +846,9 @@ def massive_update_job(session, jobs, flag="D"):
Routine that uses a low-level function to update a list of
:class:`~msnoise.msnoise_table_def.declare_tables.Job` objects much faster.
This method uses the Job.ref, which is unique.
:type jobs: list
:type session: Session
:param session: the database connection object
:type jobs: list or tuple
:param jobs: a list of :class:`~msnoise.msnoise_table_def.declare_tables.Job` to update.
:type flag: str
:param flag: The destination flag.
@@ -927,6 +928,23 @@ def get_next_job(session, flag='T', jobtype='CC', limit=99999):
return tmp


def get_dvv_jobs(session, flag='T', jobtype='DVV', limit=99999):
    """Poll the database until at least one `jobtype` Job with `flag` exists,
    claim the fetched jobs by flagging them "I"n progress and return them."""
    from sqlalchemy import update
    tmp = []
    while not len(tmp):
        # SELECT ... FOR UPDATE locks the selected rows so that concurrent
        # workers cannot claim the same jobs
        jobs = session.query(Job).filter(Job.jobtype == jobtype). \
            filter(Job.flag == flag). \
            limit(limit).with_for_update()
        tmp = jobs.all()
        refs = [job.ref for job in tmp]
        # single low-level UPDATE on Job.ref (much faster than per-row updates)
        q = update(Job).values({"flag": "I"}).where(Job.ref.in_(refs))
        session.execute(q)
        session.commit()
    return tmp
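
For orientation, a minimal worker sketch built on this helper (not part of the commit;
process_dvv_pair is a hypothetical placeholder, while connect and massive_update_job
are the existing api helpers):

    from msnoise.api import connect, massive_update_job

    db = connect()
    # blocks until at least one 'T'odo job exists, then claims them as 'I'n progress
    jobs = get_dvv_jobs(db, flag='T', jobtype='DVV', limit=100)
    for job in jobs:
        process_dvv_pair(job.pair)  # hypothetical per-pair processing
    massive_update_job(db, jobs, flag='D')  # flag the claimed jobs 'D'one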


def is_dtt_next_job(session, flag='T', jobtype='DTT', ref=False):
"""
Are there any DTT :class:`~msnoise.msnoise_table_def.declare_tables.Job` in the database,
@@ -1444,7 +1462,7 @@ def get_results(session, station1, station2, filterid, components, dates,
if format == "matrix":
return i, stack_data

if format == "dataframe":
elif format == "dataframe":
taxis = get_t_axis(session)
return pd.DataFrame(stack_data, index=pd.DatetimeIndex(dates),
columns=taxis).loc[:lastday]
@@ -1467,6 +1485,8 @@ def get_mwcs(session, station1, station2, filterid, components, date,
else:
return 0, None



def get_mwcs(session, station1, station2, filterid, components, date,
mov_stack=1):
"""
@@ -2090,6 +2110,11 @@ def xr_create_or_open(fn, taxis=[], name="CCF"):
data = np.random.random((len(times), len(taxis)))
dr = xr.DataArray(data, coords=[times, taxis], dims=["times", "taxis"])
dr.name = name
elif name == "REF":
times = pd.date_range("2000-01-01", freq="H", periods=0)
data = np.random.random(len(taxis))
dr = xr.DataArray(data, coords=[taxis], dims=["taxis"])
dr.name = name
elif name == "MWCS":
times = pd.date_range("2000-01-01", freq="H", periods=0)
keys = ["M", "EM", "MCOH"]
@@ -2104,13 +2129,24 @@
dr = xr.DataArray(data, coords=[times, keys],
dims=["times", "keys"])
dr.name = name
elif name == "DVV":
times = pd.date_range("2000-01-01", freq="H", periods=0)
level0 = ["m", "em", "a", "ea", "m0", "em0"]
level1 = ['10%', '25%', '5%', '50%', '75%', '90%', '95%', 'count', 'max', 'mean',
'min', 'std', 'trimmed_mean', 'trimmed_std', 'weighted_mean', 'weighted_std']
data = np.random.random((len(times), len(level0), len(level1)))
dr = xr.DataArray(data, coords=[times, level0, level1],
dims=["times", "level0", "level1"])
dr.name = name
else:
print("Not implemented, name=%s invalid." % name)
sys.exit(1)
return dr.to_dataset()
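
A quick sketch of how the empty templates behave (illustration only, assuming the file
does not exist so one of the template branches above is taken; the lag axis is made up):

    taxis = np.linspace(-120, 120, 241)  # illustrative lag-time axis
    ds = xr_create_or_open("does_not_exist.nc", taxis=taxis, name="CCF")
    print(ds.CCF.shape)  # (0, 241): empty along "times" until data is inserted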


def xr_insert_or_update(dataset, new):
print("dataset", type(dataset))
print("new", type(new))
tt = new.merge(dataset, compat='override', combine_attrs="drop_conflicts")
return tt.combine_first(dataset)
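
A self-contained toy illustration of the insert-or-update semantics (not from the
commit): values present in `new` win, values only present in the stored dataset survive.

    import pandas as pd
    import xarray as xr

    old = xr.Dataset({"CCF": ("times", [1.0, 2.0])},
                     coords={"times": pd.to_datetime(["2023-01-01", "2023-01-02"])})
    new = xr.Dataset({"CCF": ("times", [9.0, 3.0])},
                     coords={"times": pd.to_datetime(["2023-01-02", "2023-01-03"])})
    print(xr_insert_or_update(old, new).CCF.values)  # [1. 9. 3.]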

Expand All @@ -2121,3 +2157,205 @@ def xr_save_and_close(dataset, fn):
dataset.to_netcdf(fn, mode="w")
dataset.close()
del dataset


def xr_get_ccf(station1, station2, components, filterid, mov_stack, taxis):
path = os.path.join("STACKS2", "%02i" % filterid,
"%03i_DAYS" % mov_stack, "%s" % components)
fn = "%s_%s.nc" % (station1, station2)

fullpath = os.path.join(path, fn)
if not os.path.isfile(fullpath):
print("FILE DOES NOT EXIST: %s, skipping" % fullpath)
raise FileNotFoundError
data = xr_create_or_open(fullpath, taxis)
return data.CCF.to_dataframe().unstack().droplevel(0, axis=1)
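
The result is a pandas DataFrame with one row per stack time and one column per
lag-time sample; a hypothetical call (placeholder station codes, run from a project
folder with an open session `db`):

    taxis = get_t_axis(db)
    ccf = xr_get_ccf("BE.MEM.--", "BE.UCC.--", "ZZ", filterid=1, mov_stack=1,
                     taxis=taxis)
    print(ccf.shape)  # (n_dates, len(taxis))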


def xr_get_ref(station1, station2, components, filterid, taxis):
path = os.path.join("STACKS2", "%02i" % filterid,
"REF", "%s" % components)
fn = "%s_%s.nc" % (station1, station2)

fullpath = os.path.join(path, fn)
if not os.path.isfile(fullpath):
print("FILE DOES NOT EXIST: %s, skipping" % fullpath)
raise FileNotFoundError
data = xr_create_or_open(fullpath, taxis, name="REF")
return data.CCF.to_dataframe()


def xr_save_mwcs(station1, station2, components, filterid, mov_stack, taxis, dataframe):
fn = os.path.join("MWCS2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
"%s_%s.nc" % (station1, station2))
if not os.path.isdir(os.path.split(fn)[0]):
os.makedirs(os.path.split(fn)[0])
d = dataframe.stack().stack()
d.index = d.index.set_names(["times", "keys", "taxis"])
d = d.reorder_levels(["times", "taxis", "keys"])
d.columns = ["MWCS"]
taxis = np.unique(d.index.get_level_values('taxis'))
dr = xr_create_or_open(fn, taxis=taxis, name="MWCS")
rr = d.to_xarray().to_dataset(name="MWCS")
rr = xr_insert_or_update(dr, rr)
xr_save_and_close(rr, fn)
del dr, rr, d


def xr_get_mwcs(station1, station2, components, filterid, mov_stack):
fn = os.path.join("MWCS2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
"%s_%s.nc" % (station1, station2))
if not os.path.isfile(fn):
print("FILE DOES NOT EXIST: %s, skipping" % fn)
raise FileNotFoundError
data = xr_create_or_open(fn)
data = data.MWCS.to_dataframe().reorder_levels(['times', 'taxis', 'keys']).unstack().droplevel(0, axis=1).unstack()
return data
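
A hypothetical read (placeholder station codes); after the final unstack() the frame
has two column levels, the M/EM/MCOH keys on the outer level and the lag times on the
inner one:

    mwcs = xr_get_mwcs("BE.MEM.--", "BE.UCC.--", "ZZ", filterid=1, mov_stack=1)
    print(mwcs.columns.nlevels)  # 2
    m = mwcs["M"]  # dt measurements, one column per lag-time sample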


def xr_save_dtt(station1, station2, components, filterid, mov_stack, dataframe):
fn = os.path.join("DTT2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
"%s_%s.nc" % (station1, station2))
if not os.path.isdir(os.path.split(fn)[0]):
os.makedirs(os.path.split(fn)[0])
d = dataframe.stack()
print("OUTPUT:")
print(d.head())
d.index = d.index.set_names(["times", "keys"])
d.columns = ["DTT"]
dr = xr_create_or_open(fn, taxis=[], name="DTT")
rr = d.to_xarray().to_dataset(name="DTT")
rr = xr_insert_or_update(dr, rr)
xr_save_and_close(rr, fn)


def xr_get_dtt(station1, station2, components, filterid, mov_stack):
fn = os.path.join("DTT2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s" % components,
"%s_%s.nc" % (station1, station2))
if not os.path.isfile(fn):
print("FILE DOES NOT EXIST: %s, skipping" % fn)
raise FileNotFoundError
dr = xr_create_or_open(fn, taxis=[], name="DTT")
data = dr.DTT.to_dataframe().reorder_levels(['times', 'keys']).unstack().droplevel(0, axis=1)
return data


def xr_save_dvv(components, filterid, mov_stack, dataframe):
fn = os.path.join("DVV2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s.nc" % components)
if not os.path.isdir(os.path.split(fn)[0]):
os.makedirs(os.path.split(fn)[0])
d = dataframe.stack().stack()
d.index = d.index.set_names(["times", "level1", "level0"])
d = d.reorder_levels(["times", "level0", "level1"])
d.columns = ["DVV"]
dr = xr_create_or_open(fn, taxis=[], name="DVV")
rr = d.to_xarray().to_dataset(name="DVV")
rr = xr_insert_or_update(dr, rr)
xr_save_and_close(rr, fn)
del dr, rr, d


def xr_get_dvv(components, filterid, mov_stack):
fn = os.path.join("DVV2", "%02i" % filterid,
"%03i_DAYS" % mov_stack,
"%s.nc" % components)
if not os.path.isfile(fn):
print("FILE DOES NOT EXIST: %s, skipping" % fn)
raise FileNotFoundError
data = xr_create_or_open(fn)
data = data.DVV.to_dataframe().reorder_levels(['times', 'level1', 'level0']).unstack().droplevel(0, axis=1).unstack()
return data


def wavg(group, dttname, errname):
    """Return the error-weighted mean of `group[dttname]` (weights = 1/error)."""
    d = group[dttname]
    # avoid infinite weights: clip zero errors to a tiny value
    group.loc[group[errname] == 0, errname] = 1e-6
    w = 1. / group[errname]
    try:
        wavg = (d * w).sum() / w.sum()
    except Exception:
        wavg = d.mean()
    return wavg


def wstd(group, dttname, errname):
    """Return the error-weighted standard deviation of `group[dttname]`."""
    d = group[dttname]
    # avoid infinite weights: clip zero errors to a tiny value
    group.loc[group[errname] == 0, errname] = 1e-6
    w = 1. / group[errname]
    wavg = (d * w).sum() / w.sum()
    N = len(np.nonzero(w)[0])
    wstd = np.sqrt(np.sum(w * (d - wavg) ** 2) / ((N - 1) * np.sum(w) / N))
    return wstd
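
Written out (a reading aid, not part of the commit), with weights w_i = 1/e_i and zero
errors clipped to 1e-6, wavg and wstd implement

    \bar{d}_w = \frac{\sum_i w_i d_i}{\sum_i w_i}, \qquad
    \sigma_w = \sqrt{ \frac{\sum_i w_i (d_i - \bar{d}_w)^2}{\frac{N-1}{N} \sum_i w_i} },
    \qquad N = \#\{ i : w_i \neq 0 \},

i.e. the error-weighted mean and the corresponding weighted sample standard deviation.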


def get_wavgwstd(data, dttname, errname):
grouped = data.groupby(level=0)
g = grouped.apply(wavg, dttname=dttname, errname=errname)
h = grouped.apply(wstd, dttname=dttname, errname=errname)
return g, h


def trim(data, dttname, errname, limits=0.1):
from scipy.stats.mstats import trimmed_mean, trimmed_std
grouped = data[dttname].groupby(level=0)
if limits == 0:
g = grouped.mean()
h = grouped.std()
else:
g = grouped.apply(trimmed_mean, limits=limits)
h = grouped.apply(trimmed_std, limits=limits)
return g, h


def compute_dvv(session, filterid, mov_stack, pairs=None, components=None, params=None, method=None, **kwargs):
    if pairs is None:
        # no pairs given: build the list of all station-pair codes
        pairs = []
        for sta1, sta2 in get_station_pairs(session):
            for loc1 in sta1.locs():
                s1 = "%s.%s.%s" % (sta1.net, sta1.sta, loc1)
                for loc2 in sta2.locs():
                    s2 = "%s.%s.%s" % (sta2.net, sta2.sta, loc2)
                    pairs.append((s1, s2))
    all_dtt = []
    for (s1, s2) in pairs:
        if components is None:
            if s1 == s2:
                # if not provided, we'll load all:
                comps = params.components_to_compute_single_station
            else:
                comps = params.components_to_compute
        else:
            if components.count(',') == 0:
                comps = [components, ]
            else:
                comps = components.split(',')

        for comp in comps:
            try:
                dtt = xr_get_dtt(s1, s2, comp, filterid, mov_stack)
                all_dtt.append(dtt)
            except FileNotFoundError:
                continue
    if not len(all_dtt):
        raise ValueError("No DTT results found for the requested "
                         "pairs/components/filter/mov_stack")
    if len(all_dtt) == 1:
        return all_dtt[0]
    all_dtt = pd.concat(all_dtt)
    percentiles = kwargs.get("percentiles", [.05, .10, .25, .5, .75, .90, .95])
    stats = all_dtt.groupby(level=0).describe(percentiles=percentiles)
    # add error-weighted and trimmed statistics for m, m0 and a
    for c in ["m", "m0", "a"]:
        stats[(c, "weighted_mean")], stats[(c, "weighted_std")] = get_wavgwstd(all_dtt, c, 'e'+c)
        stats[(c, "trimmed_mean")], stats[(c, "trimmed_std")] = trim(all_dtt, c, 'e'+c, kwargs.get("limits", None))

    return stats.sort_index(axis=1)
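
A hypothetical end-to-end call (assuming an initialized project; connect and get_params
are the existing api helpers):

    db = connect()
    params = get_params(db)
    stats = compute_dvv(db, filterid=1, mov_stack=5, params=params)
    # per-date describe() statistics plus the weighted/trimmed columns:
    print(stats["m"][["50%", "weighted_mean", "trimmed_mean"]].head())
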
18 changes: 11 additions & 7 deletions msnoise/plots/ccftime.py
@@ -58,6 +58,7 @@ def main(sta1, sta2, filterid, components, mov_stack=1, ampli=5, seismic=False,
sta1 = sta1 #.replace('.', '_')
sta2 = sta2 #.replace('.', '_')
t = np.arange(samples)/cc_sampling_rate - maxlag
taxis = get_t_axis(db)

if refilter:
freqmin, freqmax = refilter.split(':')
@@ -75,16 +76,19 @@

print("Fetching CCF data for %s-%s-%i-%i" % (pair, components, filterid,
mov_stack))
nstack, stack_total = get_results(db, sta1, sta2, filterid, components,
datelist, mov_stack, format="matrix")
if nstack == 0:
stack_total = xr_get_ccf(sta1, sta2, components, filterid, mov_stack, taxis)

# convert index to mdates
stack_total.index = mdates.date2num(stack_total.index.to_pydatetime())

if len(stack_total) == 0:
print("No CCF found for this request")
return

if normalize == "common":
stack_total /= np.nanmax(stack_total)
ax = plt.subplot(111)
for i, line in enumerate(stack_total):
for i, line in stack_total.iterrows():
if np.all(np.isnan(line)):
continue
if refilter:
@@ -94,10 +98,10 @@
line = obspy_envelope(line)
if normalize == "individual":
line /= line.max()
plt.plot(t, line * ampli + i + base, c='k', lw=0.5)
plt.plot(t, line * ampli + i, c='k', lw=0.5)
if seismic:
y1 = np.ones(len(line)) * i + base
y2 = line*ampli + i + base
y1 = np.ones(len(line)) * i
y2 = line*ampli + i
plt.fill_between(t, y1, y2, where=y2 >= y1, facecolor='k',
interpolate=True)
low = high = 0.0
