Inference Workflows

These Python functions operate on the core objects of earth2mip. Most take an earth2mip.initial_conditions.base.DataSource and an earth2mip.time_loop.TimeLoop object, and they can be called from within your own parallel scripts.

General Inference

These routines combine many of the features of earth2mip, including ensemble initialization and post-processing.

earth2mip.inference_ensemble.run_inference(model, config, perturb=None, group=None, progress=True, data_source=None)

Run an ensemble inference for a given config and perturb function.

Parameters:
  • group (Any | None) – the torch distributed group to use for the calculation

  • progress (bool) – if True use tqdm to show a progress bar

  • data_source (Any | None) – a Mapping object indexed by datetime and returning an xarray.Dataset object.

  • model (TimeLoop) –

  • config (EnsembleRun) –

  • perturb (Any | None) –
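The data_source argument is described above only as a Mapping indexed by datetime. A minimal sketch of such an object follows. LazyDataSource and fake_loader are illustrative names, not part of earth2mip, and the loader returns a plain dict so the sketch has no dependencies; in real use it would return an xarray.Dataset.

```python
from collections.abc import Mapping
from datetime import datetime, timedelta


class LazyDataSource(Mapping):
    """Hypothetical read-only mapping from datetime to a dataset-like object."""

    def __init__(self, times, loader):
        self._times = list(times)        # datetimes for which data is available
        self._known = frozenset(times)   # fast membership checks
        self._loader = loader            # callable: datetime -> dataset
        self._cache = {}                 # datasets loaded so far

    def __getitem__(self, time):
        if time not in self._known:
            raise KeyError(time)
        if time not in self._cache:
            self._cache[time] = self._loader(time)
        return self._cache[time]

    def __iter__(self):
        return iter(self._times)

    def __len__(self):
        return len(self._times)


def fake_loader(time):
    # Stand-in for opening a file with xarray; returns a plain dict here.
    return {"time": time, "t2m": [280.0, 281.5]}


start = datetime(2018, 1, 1)
times = [start + timedelta(hours=6 * i) for i in range(4)]
source = LazyDataSource(times, fake_loader)

# Look up the initial condition for one of the available times.
initial_condition = source[datetime(2018, 1, 1, 6)]
```

Subclassing collections.abc.Mapping provides `in`, `.keys()`, and `.items()` for free, and the cache means each time is loaded at most once.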

earth2mip.time_collection.run_over_initial_times(*, time_loop, data_source, initial_times, config, output_path, time_averaging_window='', score=False, save_ensemble=False, shard=0, n_shards=1, n_post_processing_workers=32)

Perform a set of forecasts across many initial conditions in parallel, with post-processing.

Once complete, the data at output_path can be opened as an xarray object using earth2mip.datasets.hindcast.open_forecast().

Parallelizes across the available GPUs using MPI, and can be further parallelized across multiple MPI jobs using the shard and n_shards arguments. A run can be resumed after interruption.

Parameters:
  • time_loop (TimeLoop) – the earth2mip TimeLoop to be evaluated. Often returned by earth2mip.networks.get_model

  • data_source (DataSource | None) – the data source used to initialize the time_loop, overrides any data source specified in config

  • initial_times (list[datetime]) – the initial times to evaluate over

  • n_shards (int) – the total number of shards; the initial times are split into this many shards

  • time_averaging_window (str) – if provided, average the output over this interval. Same syntax as pandas.Timedelta (e.g. "2w"). Default is no time averaging.

  • score (bool) – if true, score the times during the post processing

  • save_ensemble (bool) – if true, then save all the ensemble members in addition to the mean

  • shard (int) – index of the shard to process. Useful for SLURM array jobs.

  • n_post_processing_workers (int) – The number of dask distributed workers to devote to ensemble post processing.

  • config (EnsembleRun) –

  • output_path (str) –

Return type:

None
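The shard and n_shards parameters can be illustrated with a small dependency-free sketch. It builds a list of initial times and shows one plausible way to partition it. The strided split is an assumption made for illustration, since how earth2mip divides the times internally is not stated here. make_initial_times, select_shard, and shard_from_env are hypothetical helpers; SLURM_ARRAY_TASK_ID is the environment variable SLURM sets for each task of an array job.

```python
from datetime import datetime, timedelta


def make_initial_times(start, end, step):
    """Return datetimes from start (inclusive) to end (exclusive) every step."""
    times = []
    current = start
    while current < end:
        times.append(current)
        current += step
    return times


def select_shard(items, shard, n_shards):
    """Strided split: shard k takes items k, k + n_shards, k + 2 * n_shards, ..."""
    if not 0 <= shard < n_shards:
        raise ValueError(f"shard must be in [0, {n_shards}), got {shard}")
    return items[shard::n_shards]


def shard_from_env(env):
    """Read the shard index from a SLURM-style environment mapping (default 0)."""
    return int(env.get("SLURM_ARRAY_TASK_ID", "0"))


initial_times = make_initial_times(
    datetime(2018, 1, 1), datetime(2018, 1, 11), timedelta(days=1)
)

# In a real array job, pass os.environ instead of this literal mapping.
shard = shard_from_env({"SLURM_ARRAY_TASK_ID": "1"})
n_shards = 4
my_times = select_shard(initial_times, shard, n_shards)
```

With this scheme every initial time belongs to exactly one shard, so an array job with n_shards tasks would cover the full list. In practice the full initial_times list would be passed to run_over_initial_times together with shard and n_shards, and the function would do the partitioning itself.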

Scoring

earth2mip.inference_medium_range.score_deterministic(model, n, initial_times, data_source, time_mean)

Compute deterministic anomaly correlation coefficients (ACC) and root-mean-square errors (RMSE).

Parameters:
  • model (TimeLoop) – the inference class

  • n (int) – the number of lead times

  • initial_times – the initial_times to compute over

  • data_source – a mapping from time to dataset, used for the initial condition and the scoring

  • time_mean – a (channel, lat, lon) numpy array containing the time_mean. Used for ACC.

Returns:

an xarray dataset with this structure::

    netcdf dlwp.baseline {
    dimensions:
        lead_time = 57 ;
        channel = 7 ;
        initial_time = 1 ;
    variables:
        int64 lead_time(lead_time) ;
            lead_time:units = "hours" ;
        string channel(channel) ;
        double acc(lead_time, channel) ;
            acc:_FillValue = NaN ;
        double rmse(lead_time, channel) ;
            rmse:_FillValue = NaN ;
        int64 initial_times(initial_time) ;
            initial_times:units = "days since 2018-11-30 12:00:00" ;
            initial_times:calendar = "proleptic_gregorian" ;
    }

Return type:

metrics
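To make the two scores concrete, here is a dependency-free sketch of RMSE and ACC for a single channel at a single lead time. It is not earth2mip's implementation. It assumes flattened fields held in plain lists, equal weight for every grid point, and the uncentered form of ACC; a production implementation would typically use numpy arrays and weight by grid-cell area. The numbers are made up for illustration.

```python
import math


def rmse(forecast, truth):
    """Root-mean-square error over equally weighted grid points."""
    n = len(forecast)
    return math.sqrt(sum((f - t) ** 2 for f, t in zip(forecast, truth)) / n)


def acc(forecast, truth, time_mean):
    """Correlate forecast and truth anomalies relative to time_mean."""
    forecast_anom = [f - m for f, m in zip(forecast, time_mean)]
    truth_anom = [t - m for t, m in zip(truth, time_mean)]
    numerator = sum(a * b for a, b in zip(forecast_anom, truth_anom))
    denominator = math.sqrt(
        sum(a * a for a in forecast_anom) * sum(b * b for b in truth_anom)
    )
    return numerator / denominator


# Illustrative values on a four-point grid (not real data).
time_mean = [280.0, 285.0, 290.0, 295.0]
truth = [281.0, 284.0, 292.0, 294.0]
forecast = [282.0, 283.0, 291.0, 296.0]

rmse_value = rmse(forecast, truth)
acc_value = acc(forecast, truth, time_mean)
```

This is why time_mean is required: ACC is computed on anomalies, so a forecast that only reproduces the climatological mean earns no credit, whereas RMSE compares the raw fields directly.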