Inference Workflows#
These Python functions operate on the core objects of earth2mip. Most are passed an earth2mip.initial_conditions.base.DataSource and an earth2mip.time_loop.TimeLoop object, and they can be used from within your own parallel scripts.
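For example, a TimeLoop is typically obtained from earth2mip.networks.get_model and paired with a data source. The sketch below is illustrative only: the model name and the CDS-backed data source are assumptions, so substitute the model and initial-condition source available in your environment.

```python
# A minimal sketch of the two core objects. The model name and the CDS-backed
# data source are assumptions; substitute your own model and initial-condition
# source.
import datetime

from earth2mip import networks
from earth2mip.initial_conditions import cds

time_loop = networks.get_model("e2mip://fcn", device="cuda:0")  # hypothetical model name
data_source = cds.DataSource(time_loop.in_channel_names)        # assumes CDS credentials are configured

# The data source behaves like a Mapping: indexing by datetime returns an
# xarray.Dataset holding the initial condition.
initial_condition = data_source[datetime.datetime(2018, 1, 1)]
```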
General Inference#
These routines combine many of the features of earth2mip, including ensemble initialization and post-processing.
- earth2mip.inference_ensemble.run_inference(model, config, perturb=None, group=None, progress=True, data_source=None)#
Run an ensemble inference for a given config and perturbation function. See the sketch after this entry for example usage.
- Parameters:
model (TimeLoop) – the model to run, as an earth2mip TimeLoop
config (EnsembleRun) – the ensemble run configuration
perturb (Any | None) – the perturbation function applied to the ensemble initial conditions
group (Any | None) – the torch distributed group to use for the calculation
progress (bool) – if True, use tqdm to show a progress bar
data_source (Any | None) – a Mapping object indexed by datetime and returning an xarray.Dataset object
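A minimal usage sketch, assuming the configuration is an earth2mip.schema.EnsembleRun loaded from a JSON file with pydantic-style parse_file, and that the model and data source are built as in the snippet above.

```python
# Sketch of an ensemble run. The config file name and the pydantic-style
# parse_file loading are assumptions; build the EnsembleRun however suits
# your workflow.
from earth2mip import inference_ensemble, networks, schema
from earth2mip.initial_conditions import cds

time_loop = networks.get_model("e2mip://fcn", device="cuda:0")  # hypothetical model name
data_source = cds.DataSource(time_loop.in_channel_names)
config = schema.EnsembleRun.parse_file("ensemble_config.json")  # assumed pydantic loading

inference_ensemble.run_inference(
    time_loop,
    config,
    data_source=data_source,  # initial-condition source
    progress=True,
)
```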
- earth2mip.time_collection.run_over_initial_times(*, time_loop, data_source, initial_times, config, output_path, time_averaging_window='', score=False, save_ensemble=False, shard=0, n_shards=1, n_post_processing_workers=32)#
Perform a set of forecasts across many initial conditions in parallel, with post-processing.
Once complete, the data at output_path can be opened as an xarray object using earth2mip.datasets.hindcast.open_forecast() (see the sketch after this entry). The work is parallelized across the available GPUs using MPI, and can be parallelized further across multiple MPI jobs using the shard/n_shards flags. It can be resumed after interruption.
- Parameters:
time_loop (TimeLoop) – the earth2mip TimeLoop to be evaluated. Often returned by earth2mip.networks.get_model
data_source (DataSource | None) – the data source used to initialize the time_loop; overrides any data source specified in config
initial_times (list[datetime]) – the initial times to evaluate over
config (EnsembleRun) – the ensemble run configuration
output_path (str) – the path where output is saved
time_averaging_window (str) – if provided, average the output over this interval. Same syntax as pandas.Timedelta (e.g. “2w”). Default is no time averaging.
score (bool) – if True, score the forecasts during post-processing
save_ensemble (bool) – if True, save all the ensemble members in addition to the mean
shard (int) – index of the shard. Useful for SLURM array jobs.
n_shards (int) – total number of shards to split the initial times into
n_post_processing_workers (int) – the number of dask distributed workers to devote to ensemble post-processing
- Return type:
None
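A hedged sketch of a hindcast driver. It assumes the script is launched under MPI with one rank per GPU (e.g. via mpirun), that the configuration is again an earth2mip.schema.EnsembleRun, and that open_forecast accepts just the output path; adjust to the actual call signature in your version.

```python
# Sketch of a sharded hindcast driver. Launching under MPI (one rank per GPU)
# is assumed, e.g. `mpirun -n 8 python hindcast.py`. Config loading and the
# model name mirror the earlier snippets.
import datetime

from earth2mip import networks, schema, time_collection
from earth2mip.datasets import hindcast
from earth2mip.initial_conditions import cds

time_loop = networks.get_model("e2mip://fcn", device="cuda:0")  # hypothetical model name
data_source = cds.DataSource(time_loop.in_channel_names)
config = schema.EnsembleRun.parse_file("ensemble_config.json")  # assumed pydantic loading

# One initial condition per week over four weeks.
initial_times = [
    datetime.datetime(2018, 1, 1) + datetime.timedelta(weeks=i) for i in range(4)
]

time_collection.run_over_initial_times(
    time_loop=time_loop,
    data_source=data_source,
    initial_times=initial_times,
    config=config,
    output_path="hindcast_output",
    score=True,
    shard=0,      # for SLURM array jobs, set this from the array task index
    n_shards=1,
)

# Once complete, open the output directory as an xarray object.
ds = hindcast.open_forecast("hindcast_output")  # assumed call signature
```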
Scoring#
- earth2mip.inference_medium_range.score_deterministic(model, n, initial_times, data_source, time_mean)#
Compute deterministic ACCs and RMSEs. See the sketch after this entry for example usage.
- Parameters:
model (TimeLoop) – the inference TimeLoop to evaluate
n (int) – the number of lead times
initial_times – the initial times to compute over
data_source – a mapping from time to dataset, used for the initial condition and the scoring
time_mean – a (channel, lat, lon) numpy array containing the time mean, used for computing the ACC
- Returns:
an xarray dataset with this structure:
netcdf dlwp.baseline {
dimensions:
    lead_time = 57 ;
    channel = 7 ;
    initial_time = 1 ;
variables:
    int64 lead_time(lead_time) ;
        lead_time:units = "hours" ;
    string channel(channel) ;
    double acc(lead_time, channel) ;
        acc:_FillValue = NaN ;
    double rmse(lead_time, channel) ;
        rmse:_FillValue = NaN ;
    int64 initial_times(initial_time) ;
        initial_times:units = "days since 2018-11-30 12:00:00" ;
        initial_times:calendar = "proleptic_gregorian" ;
}
- Return type:
xarray.Dataset
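A hedged sketch of deterministic scoring. The model name, the CDS data source, and the zero-valued time mean are placeholders: in practice supply a precomputed (channel, lat, lon) climatological mean on the model grid, and note that the channel names are assumed to be exposed as time_loop.in_channel_names.

```python
# Sketch of deterministic scoring. The data source and the time-mean array are
# placeholders; supply values matching your model's channels and grid.
import datetime

import numpy as np

from earth2mip import networks
from earth2mip.inference_medium_range import score_deterministic
from earth2mip.initial_conditions import cds

time_loop = networks.get_model("e2mip://fcn", device="cuda:0")  # hypothetical model name
data_source = cds.DataSource(time_loop.in_channel_names)

initial_times = [
    datetime.datetime(2018, 1, 1, 12),
    datetime.datetime(2018, 1, 2, 12),
]

# Placeholder climatological mean with shape (channel, lat, lon); in practice
# load a precomputed time mean on the model grid (0.25 degree shown here).
time_mean = np.zeros((len(time_loop.in_channel_names), 721, 1440))

metrics = score_deterministic(
    time_loop,
    n=28,  # number of lead times to evaluate
    initial_times=initial_times,
    data_source=data_source,
    time_mean=time_mean,
)
metrics.to_netcdf("scores.nc")  # the returned xarray dataset
```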