API#

earth2mip.diagnostics module#

class earth2mip.diagnostics.Diagnostics(group, domain, grid, diagnostic, lat, lon, device)#

Bases: object

get_dimensions()#
get_dtype()#
get_variables()#
update()#
class earth2mip.diagnostics.Raw(group, domain, grid, diagnostic, lat, lon, device)#

Bases: Diagnostics

get_dimensions()#
get_dtype()#
update(output, time_index, batch_id, batch_size)#
Parameters:
  • output (Tensor) –

  • time_index (int) –

  • batch_id (int) –

  • batch_size (int) –

earth2mip.ensemble_utils module#

class earth2mip.ensemble_utils.GaussianRandomFieldS2(nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid='equiangular', dtype=torch.float32)#

Bases: Module

cuda(*args, **kwargs)#

Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Parameters:

device (int, optional) – if specified, all parameters will be copied to that device

Returns:

self

Return type:

Module

forward(N, xi=None)#

Sample random functions from a spherical GRF.

Parameters:
  • N (int) – Number of functions to sample.

  • xi (torch.Tensor, default is None) – Noise is a complex tensor of size (N, nlat, nlat+1). If None, new Gaussian noise is sampled. If xi is provided, N is ignored.

Returns:

u (torch.Tensor) – N random samples from the GRF returned as a tensor of size (N, nlat, 2*nlat) on an equiangular grid.

to(*args, **kwargs)#

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

self

Return type:

Module

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
earth2mip.ensemble_utils.brown_noise(shape, reddening=2)#
earth2mip.ensemble_utils.generate_bred_vector(x, model, noise_amplitude, time=None, integration_steps=40, inflate=False)#
Parameters:
  • x (Tensor) –

  • model (TimeLoop) –

  • noise_amplitude (Tensor) –

  • time (datetime | None) –

  • integration_steps (int) –

Return type:

Tensor

earth2mip.ensemble_utils.generate_noise_correlated(shape, *, reddening, device, noise_amplitude)#
earth2mip.ensemble_utils.generate_noise_grf(shape, grid, alpha, sigma, tau, device=None)#

earth2mip.filesystem module#

earth2mip.filesystem.download_cached(path, recursive=False)#
Parameters:
  • path (str) –

  • recursive (bool) –

Return type:

str

earth2mip.filesystem.glob(pattern, maxdepth=1)#
Parameters:

pattern (str) –

Return type:

List[str]

earth2mip.filesystem.ls(path)#
earth2mip.filesystem.open(path, mode='r')#
earth2mip.filesystem.pipe(dest, value)#

Save string to dest

earth2mip.forecast_metrics_io module#

Routines for reading and writing forecast metrics to a directory of csv files.

The csv files contain the records:

initial_time_iso, lead_time_hours, channel, metric, value
2022-01-01T00:00:00,24,t2m,rmse,25.6
earth2mip.forecast_metrics_io.read_metrics(directory)#

Reads all csv files in the given directory and returns a pandas Series containing all the metric values.

Parameters:

directory (str) –

Return type:

Series

earth2mip.forecast_metrics_io.write_metric(f, initial_time, lead_time, channel, metric, value)#

Writes a single metric value to the given file object in csv format.

Parameters:
  • f (IO[str]) –

  • initial_time (datetime) –

  • lead_time (timedelta) –

  • channel (str) –

  • metric (str) –

  • value (float) –

Return type:

None
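As a minimal sketch of the record layout described above (not the library's implementation), the snippet below writes one row and reads it back using only the standard library. write_metric_row is a hypothetical stand-in for write_metric, and storing lead_time as a whole number of hours is an assumption based on the lead_time_hours column name.

```python
import csv
import io
from datetime import datetime, timedelta

HEADER = ["initial_time_iso", "lead_time_hours", "channel", "metric", "value"]


def write_metric_row(f, initial_time, lead_time, channel, metric, value):
    """Hypothetical stand-in that mirrors the documented csv record layout."""
    # Assumption: lead_time is stored as a whole number of hours.
    lead_time_hours = int(lead_time.total_seconds() // 3600)
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow([initial_time.isoformat(), lead_time_hours, channel, metric, value])


buffer = io.StringIO()
write_metric_row(buffer, datetime(2022, 1, 1), timedelta(hours=24), "t2m", "rmse", 25.6)
row_text = buffer.getvalue()

# Read the row back, attaching the documented column names.
record = dict(zip(HEADER, next(csv.reader(io.StringIO(row_text)))))
```

All values come back as strings when read this way, so numeric columns need an explicit conversion.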

earth2mip.forecasts module#

Forecast abstractions

A forecast is a discrete array of (n_initial_times, n_lead_times). However, because a forecast evolves forward in time and the whole forecast is not necessarily stored, algorithms in fcn-mip should access n_lead_times in sequential order. This is the purpose of the abstractions here.

class earth2mip.forecasts.Forecast(*args, **kwargs)#

Bases: Protocol

property channel_names: List[str]#
property grid: LatLonGrid#
class earth2mip.forecasts.Persistence(observations)#

Bases: object

Persistence forecast. This forecast always returns the initial condition.

Yields (channel, lat, lon)

Parameters:

observations (Any) –

property channel_names#
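The idea of a persistence forecast consumed in sequential lead-time order can be sketched with a plain generator. This is an illustration only: persistence_lead_times is a hypothetical helper, and the nested-list field stands in for a (channel, lat, lon) array.

```python
from typing import Iterator, List

# A (channel, lat, lon) field represented as nested lists for illustration.
Field = List[List[List[float]]]


def persistence_lead_times(initial_condition: Field, n_lead_times: int) -> Iterator[Field]:
    """Yield the initial condition unchanged at every lead time, in order."""
    for _ in range(n_lead_times):
        yield initial_condition


initial = [[[280.0, 281.0], [282.0, 283.0]]]  # 1 channel, 2 lat, 2 lon
steps = list(persistence_lead_times(initial, 3))
```

Because the generator produces lead times one at a time, a consumer never needs the whole forecast in memory.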
class earth2mip.forecasts.TimeLoopForecast(time_loop, times, data_source)#

Bases: Forecast

Wrap an fcn-mip TimeLoop object as a forecast

property channel_names#
property grid#
class earth2mip.forecasts.XarrayForecast(ds, fields, times, device)#

Bases: Forecast

Turn an xarray into a forecast-like dataset

Parameters:
  • ds (Dataset) –

  • times (Sequence[datetime]) –

property channel_names#
property grid#
class earth2mip.forecasts.select_channels(forecast, channel_names)#

Bases: Forecast

Parameters:
  • forecast (Forecast) –

  • channel_names (list[str]) –

property channel_names#
property grid#

earth2mip.geo_operator module#

class earth2mip.geo_operator.GeoOperator(*args, **kwargs)#

Bases: Protocol

Geo Operator

This is the most primitive functional of Earth-2 MIP, which represents an operator on geospatial fields. This implies the following two requirements:

  1. The operation must define in and out channel variables representing the fields in the input/output arrays.

  2. The operation must define the in and out grid schemas.

Many auto-regressive models can be represented as a GeoOperator and can maintain an internal state. Diagnostic models must be a GeoOperator by definition.

Warning

Geo Function is a concept not fully adopted in Earth-2 MIP and is being adopted progressively.

property in_channel_names: list[str]#
property in_grid: LatLonGrid#
property out_channel_names: list[str]#
property out_grid: LatLonGrid#
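The two requirements can be sketched with a structural protocol from the standard library. Everything here is hypothetical: GridSketch and GeoOperatorSketch are stand-ins for the real grid and protocol types, and the channel names u10m, v10m, and ws10m are illustrative.

```python
import math
from dataclasses import dataclass
from typing import List, Protocol, runtime_checkable


@dataclass
class GridSketch:
    """Hypothetical stand-in for a lat-lon grid schema."""

    lat: List[float]
    lon: List[float]


@runtime_checkable
class GeoOperatorSketch(Protocol):
    """Hypothetical protocol exposing the four documented properties."""

    @property
    def in_channel_names(self) -> List[str]: ...
    @property
    def in_grid(self) -> GridSketch: ...
    @property
    def out_channel_names(self) -> List[str]: ...
    @property
    def out_grid(self) -> GridSketch: ...


class WindSpeed:
    """Toy diagnostic mapping (u10m, v10m) to ws10m on an unchanged grid."""

    in_channel_names = ["u10m", "v10m"]
    out_channel_names = ["ws10m"]

    def __init__(self, grid: GridSketch):
        self.in_grid = grid
        self.out_grid = grid

    def __call__(self, x):
        u, v = x  # x is indexed as (channel, lat, lon)
        speed = [
            [math.hypot(a, b) for a, b in zip(row_u, row_v)]
            for row_u, row_v in zip(u, v)
        ]
        return [speed]


op = WindSpeed(GridSketch(lat=[0.0], lon=[0.0]))
result = op([[[3.0]], [[4.0]]])
```

WindSpeed never inherits from the protocol; it satisfies it simply by exposing the four attributes, which is how a Protocol-based interface is typically used.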

earth2mip.geometry module#

Routines for working with geometry

earth2mip.geometry.bilinear(data, dims, source_coords, target_coords)#
Parameters:

data (tensor) –

earth2mip.geometry.get_batch_size(data)#
earth2mip.geometry.get_bounds_window(geom, lat, lon)#
earth2mip.geometry.select_space(data, lat, lon, domain)#

earth2mip.grid module#

class earth2mip.grid.LatLonGrid(lat: List[float], lon: List[float])#

Bases: object

Parameters:
  • lat (List[float]) –

  • lon (List[float]) –

lat: List[float]#
lon: List[float]#
property shape#
earth2mip.grid.equiangular_lat_lon_grid(nlat, nlon, includes_south_pole=True)#

A regular lat-lon grid

Lat is ordered from 90 to -90 and includes -90 if and only if includes_south_pole is True. Lon is ordered from 0 to 360 and includes 0, but not 360.

Parameters:
  • nlat (int) –

  • nlon (int) –

  • includes_south_pole (bool) –

Return type:

LatLonGrid
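The ordering rules above can be sketched in pure Python. This is an illustrative reimplementation under the stated rules, not the library's code; equiangular_lat_lon_sketch is a hypothetical name and it returns plain lists rather than a LatLonGrid.

```python
from typing import List, Tuple


def equiangular_lat_lon_sketch(
    nlat: int, nlon: int, includes_south_pole: bool = True
) -> Tuple[List[float], List[float]]:
    """Build lat (90 down towards -90) and lon (0 up to, but excluding, 360)."""
    if includes_south_pole and nlat < 2:
        raise ValueError("nlat must be at least 2 to include both poles")
    # With the south pole included the last latitude is exactly -90;
    # otherwise the spacing is 180 / nlat and -90 is left out.
    n_intervals = nlat - 1 if includes_south_pole else nlat
    lat = [90.0 - 180.0 * i / n_intervals for i in range(nlat)]
    lon = [360.0 * j / nlon for j in range(nlon)]
    return lat, lon


lat, lon = equiangular_lat_lon_sketch(721, 1440)
lat_no_pole, _ = equiangular_lat_lon_sketch(720, 1440, includes_south_pole=False)
```

Both calls give a 0.25 degree spacing; they differ only in whether the final latitude row at -90 is present.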

earth2mip.grid.from_enum(grid_enum)#
Parameters:

grid_enum (Grid) –

Return type:

LatLonGrid

earth2mip.inference_ensemble module#

earth2mip.inference_ensemble.run_basic_inference(model, n, data_source, time)#

Run a basic inference

Parameters:
  • model (TimeLoop) –

  • n (int) –

  • data_source (Any) –

  • time (datetime) –

earth2mip.inference_ensemble.run_inference(model, config, perturb=None, group=None, progress=True, data_source=None)#

Run an ensemble inference for a given config and a perturb function

Parameters:
  • group (Any | None) – the torch distributed group to use for the calculation

  • progress (bool) – if True use tqdm to show a progress bar

  • data_source (Any | None) – a Mapping object indexed by datetime and returning an xarray.Dataset object.

  • model (TimeLoop) –

  • config (EnsembleRun) –

  • perturb (Any | None) –

earth2mip.inference_medium_range module#

earth2mip.inference_medium_range.score_deterministic(model, n, initial_times, data_source, time_mean)#

Compute deterministic accs and rmses

Parameters:
  • model (TimeLoop) – the inference class

  • n (int) – the number of lead times

  • initial_times – the initial_times to compute over

  • data_source – a mapping from time to dataset, used for the initial condition and the scoring

  • time_mean – a (channel, lat, lon) numpy array containing the time_mean. Used for ACC.

Returns:

an xarray dataset with this structure:

netcdf dlwp.baseline {
dimensions:
    lead_time = 57 ;
    channel = 7 ;
    initial_time = 1 ;
variables:
    int64 lead_time(lead_time) ;
        lead_time:units = "hours" ;
    string channel(channel) ;
    double acc(lead_time, channel) ;
        acc:_FillValue = NaN ;
    double rmse(lead_time, channel) ;
        rmse:_FillValue = NaN ;
    int64 initial_times(initial_time) ;
        initial_times:units = "days since 2018-11-30 12:00:00" ;
        initial_times:calendar = "proleptic_gregorian" ;
}

Return type:

metrics
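For orientation, the two scores can be sketched on flattened lists of grid values. This is a simplified illustration, not the routine used by score_deterministic: the function names are hypothetical, and any area weighting that a real scoring routine may apply is omitted.

```python
import math
from typing import Sequence


def rmse(forecast: Sequence[float], truth: Sequence[float]) -> float:
    """Unweighted root-mean-square error over flattened grid points."""
    n = len(forecast)
    return math.sqrt(sum((f - t) ** 2 for f, t in zip(forecast, truth)) / n)


def acc(forecast: Sequence[float], truth: Sequence[float], time_mean: Sequence[float]) -> float:
    """Unweighted anomaly correlation, with anomalies taken from the time mean."""
    forecast_anom = [f - m for f, m in zip(forecast, time_mean)]
    truth_anom = [t - m for t, m in zip(truth, time_mean)]
    numerator = sum(a * b for a, b in zip(forecast_anom, truth_anom))
    denominator = math.sqrt(
        sum(a * a for a in forecast_anom) * sum(b * b for b in truth_anom)
    )
    return numerator / denominator


time_mean = [280.0, 280.0, 280.0, 280.0]
truth = [281.0, 279.0, 282.0, 278.0]
forecast = [282.0, 278.0, 284.0, 276.0]  # anomalies are twice the true anomalies

rmse_value = rmse(forecast, truth)
acc_value = acc(forecast, truth, time_mean)
```

The example shows why both scores are reported: a forecast whose anomalies have the right pattern but the wrong amplitude has a perfect ACC and a non-zero RMSE.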

earth2mip.loaders module#

class earth2mip.loaders.LoaderProtocol(*args, **kwargs)#

Bases: Protocol

earth2mip.loaders.torchscript(package, pretrained=True)#

load a checkpoint into a model

earth2mip.make_job module#

earth2mip.make_job.get_time(times)#
earth2mip.make_job.get_time_s2s_calibration()#
earth2mip.make_job.get_times_2018()#
earth2mip.make_job.get_times_s2s_test()#
earth2mip.make_job.main(model, config, output)#
Parameters:
  • model (str) –

  • config (str) –

  • output (str) –

earth2mip.model_registry module#

Create-read-update-delete (CRUD) operations for the FCN model registry

The location of the registry is configured using config.MODEL_REGISTRY. Both s3:// and local paths are supported.

The top-level structure of the registry is like this:

afno_26ch_v/
baseline_afno_26/
gfno_26ch_sc3_layers8_tt64/
hafno_baseline_26ch_edim512_mlp2/
modulus_afno_20/
sfno_73ch/
tfno_no-patching_lr5e-4_full_epochs/

The name of the model is the folder name. Each of these folders has the following structure:

sfno_73ch/about.txt            # optional information (e.g. source path)
sfno_73ch/global_means.npy
sfno_73ch/global_stds.npy
sfno_73ch/weights.tar          # model checkpoint
sfno_73ch/metadata.json

The metadata.json file contains data necessary to use the model for forecasts:

{
    "architecture": "sfno_73ch",
    "n_history": 0,
    "grid": "721x1440",
    "in_channels": [
        0,
        1
    ],
    "out_channels": [
        0,
        1
    ]
}

Its schema is provided by the earth2mip.schema.Model.
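A minimal sketch of reading such a file with the standard library follows. It does not use earth2mip.schema.Model; load_metadata and parse_grid are hypothetical helpers, the required-key set is taken from the example above rather than from the schema, and reading the grid string as nlat x nlon is an assumption.

```python
import json

METADATA_TEXT = """
{
    "architecture": "sfno_73ch",
    "n_history": 0,
    "grid": "721x1440",
    "in_channels": [0, 1],
    "out_channels": [0, 1]
}
"""

# Assumption: these are the keys shown in the example; the real schema
# may mark some of them optional or define additional ones.
REQUIRED_KEYS = {"architecture", "n_history", "grid", "in_channels", "out_channels"}


def load_metadata(text: str) -> dict:
    """Parse metadata text and check the keys shown in the example."""
    metadata = json.loads(text)
    missing = REQUIRED_KEYS - metadata.keys()
    if missing:
        raise ValueError(f"metadata is missing keys: {sorted(missing)}")
    return metadata


def parse_grid(grid: str) -> tuple:
    """Assumption: the grid string has the form '<nlat>x<nlon>'."""
    nlat, nlon = (int(part) for part in grid.split("x"))
    return nlat, nlon


metadata = load_metadata(METADATA_TEXT)
grid_shape = parse_grid(metadata["grid"])
```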

The checkpoint file weights.tar should have a dictionary of model weights and parameters in the model_state key. For backwards compatibility with FCN checkpoints produced as of March 1, 2023, the keys should include the module. prefix. This checkpoint format may change in the future.
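The key convention can be illustrated with a plain dictionary. This is only a sketch: add_module_prefix is a hypothetical helper, the key names are made up, and integers stand in for tensors.

```python
def add_module_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of state_dict whose keys all start with the prefix."""
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for tensors; the key names are illustrative.
weights = {"encoder.weight": 1, "module.decoder.weight": 2}
checkpoint = {"model_state": add_module_prefix(weights)}
```

Keys that already carry the prefix are left unchanged, so the helper can be applied more than once without stacking prefixes.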

Scoring FCNs under active development#

One can use fcn-mip to score models not packaged in fcn-mip using a metadata file like this:

{
    "architecture": "pickle",
    ...
}

This will load weights.tar using torch.load. This is not recommended for long-time archival of model checkpoints but does allow scoring models under active development. Once a reasonable skill is achieved the model’s source code can be stabilized and packaged within fcn-mip for long-term archival.

earth2mip.model_registry.DLWPPackage(root, seperator)#
Parameters:
  • root (str) –

  • seperator (str) –

earth2mip.model_registry.FCNPackage(root, seperator)#
Parameters:
  • root (str) –

  • seperator (str) –

earth2mip.model_registry.FCNv2Package(root, seperator)#
Parameters:
  • root (str) –

  • seperator (str) –

class earth2mip.model_registry.ModelRegistry(path)#

Bases: object

Parameters:

path (str) –

SEPERATOR: str = '/'#
get_builtin_model(name)#

Built in models that have globally buildable packages

Parameters:

name (str) –

get_center_path(name)#
Parameters:

name (str) –

get_metadata(name)#
Parameters:

name (str) –

Return type:

Model

get_model(name)#
Parameters:

name (str) –

get_model_path(name)#
Parameters:

name (str) –

get_path(name, *args)#
get_scale_path(name)#
Parameters:

name (str) –

get_weight_path(name)#
Parameters:

name (str) –

list_models()#
put_metadata(name, metadata)#
Parameters:
  • name (str) –

  • metadata (Model) –
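How paths inside the registry layout might be assembled can be sketched as follows. join_registry_path is a hypothetical helper, not a ModelRegistry method, and the bucket and directory names are placeholders; only the '/' separator and the file names come from the layout shown earlier.

```python
SEPARATOR = "/"  # same value as the documented ModelRegistry.SEPERATOR


def join_registry_path(root: str, name: str, *parts: str) -> str:
    """Join the registry root, model name, and file parts with the separator."""
    return SEPARATOR.join([root.rstrip(SEPARATOR), name, *parts])


root = "s3://example-bucket/registry"  # hypothetical registry location
weights_path = join_registry_path(root, "sfno_73ch", "weights.tar")
metadata_path = join_registry_path("/tmp/registry/", "sfno_73ch", "metadata.json")
```

Joining with an explicit separator, rather than os.path.join, keeps the same logic valid for both s3:// and local roots.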

earth2mip.model_registry.NGCDiagnosticPackage(root, seperator, name)#
Parameters:
  • root (str) –

  • seperator (str) –

  • name (Literal['precipitation_afno', 'climatenet']) –

class earth2mip.model_registry.Package(root, seperator)#

Bases: object

A model package

Simple file system operations and quick metadata access

Parameters:
  • root (str) –

  • seperator (str) –

get(path, recursive=False)#
Parameters:

recursive (bool) –

metadata()#
Return type:

Model

earth2mip.model_registry.PanguPackage(root, seperator)#
Parameters:
  • root (str) –

  • seperator (str) –

earth2mip.model_registry.download_ngc_package(root, url, zip_file)#
Parameters:
  • root (str) –

  • url (str) –

  • zip_file (str) –

earth2mip.netcdf module#

Routines to save domains to a netCDF file

earth2mip.netcdf.initialize_netcdf(nc, domains, grid, n_ensemble, device)#
Return type:

List[List[Diagnostics]]

earth2mip.netcdf.update_netcdf(data, total_diagnostics, domains, batch_id, time_count, grid, channel_names_of_data)#

earth2mip.regrid module#

class earth2mip.regrid.Identity(*args, **kwargs)#

Bases: Module

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.regrid.RegridLatLon(src_grid, dest_grid)#

Bases: Module

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.regrid.TempestRegridder(file_path)#

Bases: Module

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

to(device)#

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

self

Return type:

Module

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
earth2mip.regrid.get_regridder(src, dest)#
Return type:

Module

earth2mip.schema module#

class earth2mip.schema.EnsembleRun(*, weather_model, simulation_length, perturbation_strategy=PerturbationStrategy.correlated, perturbation_channels=None, noise_reddening=2.0, noise_amplitude=0.05, output_frequency=1, output_grid=None, ensemble_members=1, seed=1, ensemble_batch_size=1, forecast_name=None, weather_event=None, output_dir=None, output_path=None, restart_frequency=None, grf_noise_alpha=2.0, grf_noise_sigma=5.0, grf_noise_tau=2.0)#

Bases: BaseModel

A configuration for running an ensemble weather forecast

Parameters:
  • weather_model (str) –

  • simulation_length (int) –

  • perturbation_strategy (PerturbationStrategy) –

  • perturbation_channels (List[str] | None) –

  • noise_reddening (float) –

  • noise_amplitude (float) –

  • output_frequency (int) –

  • output_grid (Grid | None) –

  • ensemble_members (int) –

  • seed (int) –

  • ensemble_batch_size (int) –

  • forecast_name (str | None) –

  • weather_event (WeatherEvent | None) –

  • output_dir (str | None) –

  • output_path (str | None) –

  • restart_frequency (int | None) –

  • grf_noise_alpha (float) –

  • grf_noise_sigma (float) –

  • grf_noise_tau (float) –

weather_model#

The name of the fully convolutional neural network (FCN) model to use for the forecast.

Type:

str

ensemble_members#

The number of ensemble members to use in the forecast.

Type:

int

noise_amplitude#

The amplitude of the Gaussian noise to add to the initial conditions.

Type:

float

noise_reddening#

The noise reddening amplitude. 2.0 was the default set in A.G.'s work.

Type:

float

simulation_length#

The length of the simulation in timesteps.

Type:

int

output_frequency#

The frequency at which to write the output to file, in timesteps.

Type:

int

use_cuda_graphs#

Whether to use CUDA graphs to optimize the computation.

seed#

The random seed for the simulation.

Type:

int

ensemble_batch_size#

The batch size to use for the ensemble.

Type:

int

autocast_fp16#

Whether to use automatic mixed precision (AMP) with FP16 data types.

perturbation_strategy#

The strategy to use for perturbing the initial conditions.

Type:

earth2mip.schema.PerturbationStrategy

perturbation_channels#

channel(s) perturbed by the initial condition perturbation strategy, None = all channels

Type:

List[str] | None

forecast_name#

The name of the forecast to use (alternative to weather_event).

Type:

optional

weather_event#

The weather event to use for the forecast (alternative to forecast_name).

Type:

optional

output_dir#

The directory to save the output files in (alternative to output_path).

Type:

optional

output_path#

The path to the output file (alternative to output_dir).

Type:

optional

restart_frequency#

if provided save at end and at the specified frequency. 0 = only save at end.

Type:

int | None

grf_noise_alpha#

tuning parameter of the Gaussian random field, see ensemble_utils.generate_noise_grf for details

Type:

float

grf_noise_sigma#

tuning parameter of the Gaussian random field, see ensemble_utils.generate_noise_grf for details

Type:

float

grf_noise_tau#

tuning parameter of the Gaussian random field, see ensemble_utils.generate_noise_grf for details

Type:

float

ensemble_batch_size: int#
ensemble_members: int#
forecast_name: str | None#
get_weather_event()#
Return type:

WeatherEvent

grf_noise_alpha: float#
grf_noise_sigma: float#
grf_noise_tau: float#
noise_amplitude: float#
noise_reddening: float#
output_dir: str | None#
output_frequency: int#
output_grid: Grid | None#
output_path: str | None#
perturbation_channels: List[str] | None#
perturbation_strategy: PerturbationStrategy#
restart_frequency: int | None#
seed: int#
simulation_length: int#
weather_event: WeatherEvent | None#
weather_model: str#
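To show how the fields fit together, here is a sketch of a configuration object using a standard-library dataclass. EnsembleRunSketch is a hypothetical stand-in carrying only a subset of the documented fields with their documented defaults; the n_batches helper and the chosen values are assumptions for illustration and are not part of EnsembleRun.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class EnsembleRunSketch:
    """Hypothetical stand-in carrying a subset of the documented fields."""

    weather_model: str
    simulation_length: int
    perturbation_strategy: str = "correlated"
    noise_reddening: float = 2.0
    noise_amplitude: float = 0.05
    output_frequency: int = 1
    ensemble_members: int = 1
    seed: int = 1
    ensemble_batch_size: int = 1
    forecast_name: Optional[str] = None
    output_path: Optional[str] = None

    def n_batches(self) -> int:
        # Assumption: members are processed in batches of ensemble_batch_size,
        # so the batch count is the ceiling of members / batch size.
        return -(-self.ensemble_members // self.ensemble_batch_size)


config = EnsembleRunSketch(
    weather_model="sfno_73ch",
    simulation_length=40,
    ensemble_members=10,
    ensemble_batch_size=4,
)
config_json = json.dumps(asdict(config), indent=2)
```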
class earth2mip.schema.InferenceEntrypoint(*, name='', kwargs=None)#

Bases: BaseModel

Attrs:
  • name: an entrypoint string like my_package:model_entrypoint. This points to a function model_entrypoint(package) which returns an Inference object given a package.

  • kwargs: the arguments to pass to the constructor

Parameters:
  • name (str) –

  • kwargs (Mapping[Any, Any]) –

kwargs: Mapping[Any, Any]#
name: str#
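Resolving an entrypoint string of this shape can be sketched with importlib. resolve_entrypoint is a hypothetical helper, not the library's loader, and the example resolves a standard-library function so that it runs without any model package installed.

```python
import importlib


def resolve_entrypoint(name: str):
    """Resolve a 'module:attribute' string to the object it names."""
    module_name, _, attribute = name.partition(":")
    if not module_name or not attribute:
        raise ValueError(f"expected 'module:function', got {name!r}")
    return getattr(importlib.import_module(module_name), attribute)


# A standard-library target keeps the example runnable; a real entrypoint
# would name a function that accepts a package, as described above.
dumps = resolve_entrypoint("json:dumps")
encoded = dumps({"n_history": 0})
```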
class earth2mip.schema.InitialConditionSource(value)#

Bases: Enum

An enumeration.

cds: str = 'cds'#
era5: str = 'era5'#
gfs: str = 'gfs'#
hrmip: str = 'hrmip'#
ifs: str = 'ifs'#
class earth2mip.schema.PerturbationStrategy(value)#

Bases: Enum

An enumeration.

bred_vector = 'bred_vector'#
correlated = 'correlated'#
gaussian = 'gaussian'#
none = 'none'#
spherical_grf = 'spherical_grf'#
class earth2mip.schema.WeatherEvent(*, properties, domains)#

Bases: BaseModel

domains: List[Window | CWBDomain | MultiPoint]#
properties: WeatherEventProperties#

earth2mip.score_ensemble_outputs module#

earth2mip.score_ensemble_outputs.main(input_path, output_path=None, time_averaging_window='', score=True, save_ensemble=False)#
Parameters:
  • input_path (str) –

  • output_path (str | None) –

  • time_averaging_window (str) –

  • score (bool) –

  • save_ensemble (bool) –

Return type:

None

earth2mip.score_ensemble_outputs.open_ensemble(path, group)#
earth2mip.score_ensemble_outputs.open_verification(time)#
earth2mip.score_ensemble_outputs.read_weather_event(dir)#
earth2mip.score_ensemble_outputs.save_dataset(out, path)#

earth2mip.time module#

earth2mip.time.convert_to_datetime(time)#
Return type:

datetime

earth2mip.time.datetime_to_timestamp(time)#
Parameters:

time (datetime) –

Return type:

float

earth2mip.time_collection module#

earth2mip.time_collection.run_over_initial_times(*, time_loop, data_source, initial_times, config, output_path, time_averaging_window='', score=False, save_ensemble=False, shard=0, n_shards=1, n_post_processing_workers=32)#

Perform a set of forecasts across many initial conditions in parallel with post processing

Once complete, the data at output_path can be opened as an xarray object using earth2mip.datasets.hindcast.open_forecast().

Parallelizes across the available GPUs using MPI, and can be further parallelized across multiple MPI jobs using the shard/n_shards flags. It can be resumed after interruption.

Parameters:
  • time_loop (TimeLoop) – the earth2mip TimeLoop to be evaluated. Often returned by earth2mip.networks.get_model

  • data_source (DataSource | None) – the data source used to initialize the time_loop, overrides any data source specified in config

  • initial_times (list[datetime]) – the initial times evaluated over

  • n_shards (int) – split the input times into this many shards

  • time_averaging_window (str) – if provided, average the output over this interval. Same syntax as pandas.Timedelta (e.g. “2w”). Default is no time averaging.

  • score (bool) – if true, score the times during the post processing

  • save_ensemble (bool) – if true, then save all the ensemble members in addition to the mean

  • shard (int) – index of the shard. useful for SLURM array jobs

  • n_shards – number of shards total.

  • n_post_processing_workers (int) – The number of dask distributed workers to devote to ensemble post processing.

  • config (EnsembleRun) –

  • output_path (str) –

Return type:

None
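The shard and n_shards parameters can be pictured as splitting the list of initial times between independent jobs. The sketch below assumes a round-robin split; the actual assignment used by run_over_initial_times may differ, and select_shard is a hypothetical helper.

```python
from datetime import datetime, timedelta
from typing import List, Sequence


def select_shard(times: Sequence[datetime], shard: int, n_shards: int) -> List[datetime]:
    """Assumption: initial times are assigned to shards round-robin."""
    if not 0 <= shard < n_shards:
        raise ValueError("shard must satisfy 0 <= shard < n_shards")
    return list(times[shard::n_shards])


start = datetime(2018, 1, 1)
initial_times = [start + timedelta(days=i) for i in range(5)]
shards = [select_shard(initial_times, i, 2) for i in range(2)]
```

In a SLURM array job, each array task would pass its own task index as shard and the array size as n_shards.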

earth2mip.time_loop module#

class earth2mip.time_loop.GeoTensorInfo(channel_names, grid, n_history_levels=1, history_time_step=datetime.timedelta(0))#

Bases: object

Metadata explaining how tensor maps onto the Earth

Describes a tensor x with shape (batch, history, channel, lat, lon).

Parameters:
  • channel_names (List[str]) –

  • grid (LatLonGrid) –

  • n_history_levels (int) –

  • history_time_step (timedelta) –

channel_names: List[str]#
grid: LatLonGrid#
history_time_step: timedelta = datetime.timedelta(0)#
n_history_levels: int = 1#
class earth2mip.time_loop.TimeLoop(*args, **kwargs)#

Bases: Protocol

Abstract protocol that a custom time loop must follow

This is a callable which yields time and output information. Some attributes are required to define the input and output data required.

The expectation is that this class and the data passed to it are on the same device. While torch modules can be moved between devices easily, this is not true for all frameworks.

in_channel_names#
Type:

List[str]

out_channel_names#
Type:

List[str]

grid#
Type:

earth2mip.grid.LatLonGrid

n_history_levels#
Type:

int

history_time_step#
Type:

datetime.timedelta

time_step#
Type:

datetime.timedelta

device#
Type:

torch.device

device: device#
dtype: dtype = torch.float32#
grid: LatLonGrid#
history_time_step: timedelta = datetime.timedelta(0)#
in_channel_names: List[str]#
n_history_levels: int = 1#
out_channel_names: List[str]#
time_step: timedelta#
class earth2mip.time_loop.TimeStepper(*args, **kwargs)#

Bases: Protocol[StateT]

A functional interface that can be used for time stepping

state -> (state, output)

This uses a generic state, but concrete Tensors as input and output. This allows users to directly control the time-stepping logic and potentially modify the state in a model-specific manner, but the basic initial conditions and running outputs are concrete torch Tensors.

One example is the graphcast time stepper. Graphcast uses jax and xarray to handle the state.

It should be used like this:

stepper = MyStepper()
state = stepper.initialize(x, time)

outputs = []
for i in range(10):
    state, output = stepper.step(state)
    outputs.append(output)

One benefit is that the state can be saved and reloaded trivially to restart the simulation.
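The save-and-reload benefit can be illustrated with a toy stepper. DecayStepper is hypothetical and uses plain lists in place of tensors; it only mirrors the initialize/step shape described above, and pickle is used here as one possible way to persist a state.

```python
import pickle
from datetime import datetime, timedelta
from typing import List, Tuple

State = Tuple[datetime, List[float]]


class DecayStepper:
    """Toy stepper: each step halves the field and advances the time."""

    time_step = timedelta(hours=6)

    def initialize(self, x: List[float], time: datetime) -> State:
        return (time, list(x))

    def step(self, state: State) -> Tuple[State, List[float]]:
        time, x = state
        new_x = [0.5 * value for value in x]
        return (time + self.time_step, new_x), new_x


stepper = DecayStepper()
state = stepper.initialize([8.0, 4.0], datetime(2018, 1, 1))
state, _ = stepper.step(state)

# Save the state, reload it, and continue stepping from the restored copy.
restored = pickle.loads(pickle.dumps(state))
restored, output = stepper.step(restored)
```

Because step takes and returns the full state, nothing else needs to be carried across the restart.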

property device: device#
property dtype: device#
initialize(x, time)#

x is described by self.input_info

Parameters:
  • x (Tensor) –

  • time (datetime) –

Return type:

StateT

property input_info: GeoTensorInfo#
property output_info: GeoTensorInfo#
step(state)#

step the state and return the ml output as a tensor

The output tensor is described by self.output_info

Parameters:

state (StateT) –

Return type:

tuple[StateT, Tensor]

property time_step: timedelta#
class earth2mip.time_loop.TimeStepperLoop(stepper)#

Bases: TimeLoop

Turn a TimeStepper into a TimeLoop

Parameters:

stepper (TimeStepper) –

property device: device#
property dtype: dtype#
property grid: LatLonGrid#
property history_time_step: timedelta#

property in_channel_names: List[str]#
property n_history_levels: int#

property out_channel_names: List[str]#
property time_step: timedelta#
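The adaptation from a stepper to a loop can be sketched as a generator. This is an illustration under assumptions, not the TimeStepperLoop implementation: CountingStepper and loop_from_stepper are hypothetical, the loop yields (time, output) pairs, and it does not yield the initial condition, which the real class may handle differently.

```python
from datetime import datetime, timedelta
from typing import Iterator, Tuple


class CountingStepper:
    """Toy stepper whose state is (time, value) and whose output is the value."""

    time_step = timedelta(hours=6)

    def initialize(self, x: float, time: datetime):
        return (time, x)

    def step(self, state):
        time, x = state
        new_state = (time + self.time_step, x + 1.0)
        return new_state, new_state[1]


def loop_from_stepper(stepper, x: float, time: datetime) -> Iterator[Tuple[datetime, float]]:
    """Advance the stepper indefinitely, yielding (valid time, output) pairs."""
    state = stepper.initialize(x, time)
    valid_time = time
    while True:
        state, output = stepper.step(state)
        valid_time += stepper.time_step
        yield valid_time, output


loop = loop_from_stepper(CountingStepper(), 0.0, datetime(2018, 1, 1))
first_three = [next(loop) for _ in range(3)]
```

The generator is unbounded, so the caller decides how many lead times to consume.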

earth2mip.weather_events module#

class earth2mip.weather_events.CWBDomain(*, type, name, path='/lustre/fsw/sw_climate_fno/nbrenowitz/2023-01-24-cwb-4years.zarr', diagnostics)#

Bases: BaseModel

Parameters:
  • type (Literal['CWBDomain']) –

  • name (str) –

  • path (str) –

  • diagnostics (List[Diagnostic]) –

diagnostics: List[Diagnostic]#
name: str#
path: str#
type: Literal['CWBDomain']#
class earth2mip.weather_events.Diagnostic(*, type, function='', channels, nbins=10)#

Bases: BaseModel

Parameters:
  • type (str) –

  • function (str) –

  • channels (List[str]) –

  • nbins (int) –

channels: List[str]#
function: str#
nbins: int#
type: str#
class earth2mip.weather_events.InitialConditionSource(value)#

Bases: Enum

An enumeration.

cds: str = 'cds'#
era5: str = 'era5'#
gfs: str = 'gfs'#
hrmip: str = 'hrmip'#
ifs: str = 'ifs'#
class earth2mip.weather_events.MultiPoint(*, type, name, lat, lon, diagnostics)#

Bases: BaseModel

Parameters:
  • type (Literal['MultiPoint']) –

  • name (str) –

  • lat (List[float]) –

  • lon (List[float]) –

  • diagnostics (List[Diagnostic]) –

diagnostics: List[Diagnostic]#
lat: List[float]#
lon: List[float]#
name: str#
type: Literal['MultiPoint']#
class earth2mip.weather_events.WeatherEvent(*, properties, domains)#

Bases: BaseModel

domains: List[Window | CWBDomain | MultiPoint]#
properties: WeatherEventProperties#
class earth2mip.weather_events.WeatherEventProperties(*, name, start_time=None, initial_condition_source=InitialConditionSource.era5, netcdf='', restart='')#

Bases: BaseModel

Parameters:
  • name (str) –

  • start_time (datetime | None) –

  • initial_condition_source (InitialConditionSource) –

  • netcdf (str) –

  • restart (str) –

netcdf#

load the initial conditions from this path if given

Type:

str

initial_condition_source: InitialConditionSource#
name: str#
netcdf: str#
restart: str#
start_time: datetime | None#
class earth2mip.weather_events.Window(*, type='Window', name, lat_min=-90, lat_max=90, lon_min=0, lon_max=360, diagnostics)#

Bases: BaseModel

Parameters:
  • type (Literal['Window']) –

  • name (str) –

  • lat_min (float) –

  • lat_max (float) –

  • lon_min (float) –

  • lon_max (float) –

  • diagnostics (List[Diagnostic]) –

diagnostics: List[Diagnostic]#
lat_max: float#
lat_min: float#
lon_max: float#
lon_min: float#
name: str#
type: Literal['Window']#
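How a window's bounds might select part of a grid can be sketched as follows. WindowSketch is a hypothetical stand-in that reuses the documented bound names and defaults; treating the bounds as inclusive is an assumption, and the diagnostics field is left out.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class WindowSketch:
    """Hypothetical stand-in using the documented Window bounds and defaults."""

    name: str
    lat_min: float = -90
    lat_max: float = 90
    lon_min: float = 0
    lon_max: float = 360

    def indices(self, lat: List[float], lon: List[float]) -> Tuple[List[int], List[int]]:
        # Assumption: bounds are inclusive on both ends.
        lat_idx = [i for i, v in enumerate(lat) if self.lat_min <= v <= self.lat_max]
        lon_idx = [j for j, v in enumerate(lon) if self.lon_min <= v <= self.lon_max]
        return lat_idx, lon_idx


lat = [90.0, 45.0, 0.0, -45.0, -90.0]
lon = [0.0, 90.0, 180.0, 270.0]
band = WindowSketch(name="band", lat_min=-45, lat_max=45, lon_max=180)
lat_idx, lon_idx = band.indices(lat, lon)
```

With the default bounds, the window covers every point of a grid whose longitudes run from 0 up to 360.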
earth2mip.weather_events.list_()#
earth2mip.weather_events.read(forecast_name)#
Parameters:

forecast_name (str) –

Return type:

WeatherEvent