API#
earth2mip.diagnostics module#
- class earth2mip.diagnostics.Diagnostics(group, domain, grid, diagnostic, lat, lon, device)#
Bases:
object
- Parameters:
group (Group) –
domain (CWBDomain | Window | MultiPoint) –
grid (Grid) –
diagnostic (Diagnostic) –
lat (ndarray) –
lon (ndarray) –
device (device) –
- get_dimensions()#
- get_dtype()#
- get_variables()#
- update()#
- class earth2mip.diagnostics.Raw(group, domain, grid, diagnostic, lat, lon, device)#
Bases:
Diagnostics
- Parameters:
group (Group) –
domain (CWBDomain | Window | MultiPoint) –
grid (Grid) –
diagnostic (Diagnostic) –
lat (ndarray) –
lon (ndarray) –
device (device) –
- get_dimensions()#
- get_dtype()#
- update(output, time_index, batch_id, batch_size)#
- Parameters:
output (Tensor) –
time_index (int) –
batch_id (int) –
batch_size (int) –
earth2mip.ensemble_utils module#
- class earth2mip.ensemble_utils.GaussianRandomFieldS2(nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid='equiangular', dtype=torch.float32)#
Bases:
Module
- cuda(*args, **kwargs)#
Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.
Note
This method modifies the module in-place.
- Parameters:
device (int, optional) – if specified, all parameters will be copied to that device
- Returns:
self
- Return type:
Module
- forward(N, xi=None)#
Sample random functions from a spherical GRF.
- Parameters:
N (int) – Number of functions to sample.
xi (torch.Tensor, default is None) – Noise is a complex tensor of size (N, nlat, nlat+1). If None, new Gaussian noise is sampled. If xi is provided, N is ignored.
- Returns:
u (torch.Tensor) – N random samples from the GRF, returned as a tensor of size (N, nlat, 2*nlat) on an equiangular grid.
- to(*args, **kwargs)#
Moves and/or casts the parameters and buffers.
This can be called as
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
- to(memory_format=torch.channels_last)
Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.
Note
This method modifies the module in-place.
- Parameters:
device (torch.device) – the desired device of the parameters and buffers in this module
dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module
tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- Returns:
self
- Return type:
Module
Examples:
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- earth2mip.ensemble_utils.brown_noise(shape, reddening=2)#
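brown_noise has no docstring here. The sketch below illustrates what spectrally "reddened" noise generally means: white Gaussian noise is shaped in 2D Fourier space so that its amplitude falls off with wavenumber. It is an illustration only. The helper name reddened_noise is hypothetical. The exact spectral shape is an assumption and does not necessarily match brown_noise's implementation. That shape has amplitude proportional to 1 / (|k_lat|**reddening + |k_lon|**reddening), with the zero-frequency mode removed.

```python
import numpy as np


def reddened_noise(shape, reddening=2.0, rng=None):
    """Hypothetical sketch: white noise shaped by a red spectrum over the last two axes."""
    rng = np.random.default_rng() if rng is None else rng
    white = rng.standard_normal(shape)

    # Absolute frequencies along the last two axes (e.g. lat, lon).
    k_y = np.abs(np.fft.fftfreq(shape[-2]))[:, None]
    k_x = np.abs(np.fft.fftfreq(shape[-1]))[None, :]
    spectrum = k_y**reddening + k_x**reddening

    # Amplitude decays with wavenumber; the zero-frequency (mean) mode is dropped.
    amplitude = np.zeros_like(spectrum)
    nonzero = spectrum > 0
    amplitude[nonzero] = 1.0 / spectrum[nonzero]
    # Normalize so the output has roughly unit variance (Parseval).
    amplitude /= np.sqrt(np.mean(amplitude**2))

    shaped = np.fft.fft2(white) * amplitude  # fft2 acts on the last two axes
    return np.fft.ifft2(shaped).real


noise = reddened_noise((2, 32, 64), reddening=2.0, rng=np.random.default_rng(0))
print(noise.shape)  # (2, 32, 64)
```

Larger reddening values push more variance into the large-scale (low-wavenumber) modes, which is why such noise looks smoother than white noise.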
- earth2mip.ensemble_utils.generate_bred_vector(x, model, noise_amplitude, time=None, integration_steps=40, inflate=False)#
- Parameters:
x (Tensor) –
model (TimeLoop) –
noise_amplitude (Tensor) –
time (datetime | None) –
integration_steps (int) –
- Return type:
Tensor
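The breeding idea behind generate_bred_vector can be illustrated without a real TimeLoop. Below is a simplified version of the classic bred-vector cycle:

1. Perturb the state.
2. Advance the control and perturbed states one step.
3. Take their difference.
4. Rescale the difference back to the target amplitude.

This is a sketch under stated assumptions. The toy step function, the RMS rescaling, and the name breed_vector are all hypothetical. The real function takes a TimeLoop, a Tensor noise_amplitude, and an inflate flag, none of which this sketch models.

```python
import numpy as np


def breed_vector(x, step, amplitude, integration_steps=40, rng=None):
    """Hypothetical bred-vector cycle with RMS rescaling after every step."""
    rng = np.random.default_rng() if rng is None else rng
    dx = amplitude * rng.standard_normal(x.shape)  # initial random perturbation
    for _ in range(integration_steps):
        control = step(x)
        perturbed = step(x + dx)
        dx = perturbed - control
        dx *= amplitude / np.sqrt(np.mean(dx**2))  # rescale to the target RMS
        x = control  # advance the control trajectory
    return dx


# Toy linear model: the first component grows, the second decays.
growth = np.array([2.0, 0.5])
bred = breed_vector(
    np.array([1.0, 1.0]), lambda v: growth * v, amplitude=0.1, rng=np.random.default_rng(0)
)
```

After enough cycles the perturbation aligns with the fastest-growing direction (here the first component). That property is what makes bred vectors useful for ensemble initialization.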
- earth2mip.ensemble_utils.generate_noise_grf(shape, grid, alpha, sigma, tau, device=None)#
earth2mip.filesystem module#
- earth2mip.filesystem.download_cached(path, recursive=False)#
- Parameters:
path (str) –
recursive (bool) –
- Return type:
str
- earth2mip.filesystem.glob(pattern, maxdepth=1)#
- Parameters:
pattern (str) –
- Return type:
List[str]
- earth2mip.filesystem.ls(path)#
- earth2mip.filesystem.open(path, mode='r')#
- earth2mip.filesystem.pipe(dest, value)#
Save a string to dest.
earth2mip.forecast_metrics_io module#
Routines for reading and writing forecast metrics to a directory of csv files.
The csv files contain the records:
initial_time_iso, lead_time_hours, channel, metric, value
2022-01-01T00:00:00,24,t2m,rmse,25.6
- earth2mip.forecast_metrics_io.read_metrics(directory)#
Reads all csv files in the given directory and returns a pandas Series containing all the metric values.
- Parameters:
directory (str) –
- Return type:
Series
- earth2mip.forecast_metrics_io.write_metric(f, initial_time, lead_time, channel, metric, value)#
Writes a single metric value to the given file object in csv format.
- Parameters:
f (IO[str]) –
initial_time (datetime) –
lead_time (timedelta) –
channel (str) –
metric (str) –
value (float) –
- Return type:
None
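The record layout above maps directly onto Python's csv module. The sketch below is a standard-library approximation of writing and reading one record. Its helpers write_metric_row and read_metric_rows are hypothetical, and it assumes lead times are written in hours, as in the example row. For real use, prefer write_metric and read_metrics, which may differ in details such as headers.

```python
import csv
import datetime
import io


def write_metric_row(f, initial_time, lead_time, channel, metric, value):
    """Hypothetical helper mirroring the documented record layout."""
    lead_hours = lead_time / datetime.timedelta(hours=1)  # timedelta -> float hours
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow([initial_time.isoformat(), f"{lead_hours:g}", channel, metric, value])


def read_metric_rows(f):
    """Parse records back into (initial_time, lead_time, channel, metric, value)."""
    for iso, hours, channel, metric, value in csv.reader(f):
        yield (
            datetime.datetime.fromisoformat(iso),
            datetime.timedelta(hours=float(hours)),
            channel,
            metric,
            float(value),
        )


buf = io.StringIO()
write_metric_row(
    buf, datetime.datetime(2022, 1, 1), datetime.timedelta(hours=24), "t2m", "rmse", 25.6
)
print(buf.getvalue(), end="")  # 2022-01-01T00:00:00,24,t2m,rmse,25.6
```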
earth2mip.forecasts module#
Forecast abstractions
A forecast is a discrete array of (n_initial_times, n_lead_times). However, because a forecast evolves forward in time and the whole forecast is not necessarily stored, algorithms in fcn-mip should access the lead times in sequential order. This is the purpose of the abstractions here.
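Here is a persistence forecast written as a generator, to make the sequential-access idea concrete. Lead times can only be consumed in order, and nothing beyond the current step has to be held in memory. This is a sketch under stated assumptions: the function name and the plain-NumPy interface are hypothetical, not the Forecast protocol's actual methods.

```python
import datetime
from typing import Iterator, Tuple

import numpy as np


def persistence_lead_times(
    initial_state: np.ndarray, time_step: datetime.timedelta, n_lead_times: int
) -> Iterator[Tuple[datetime.timedelta, np.ndarray]]:
    """Yield (lead_time, field) pairs in order; each field has shape (channel, lat, lon)."""
    for i in range(n_lead_times):
        # A persistence forecast always returns the initial condition.
        yield i * time_step, initial_state


x0 = np.random.default_rng(0).standard_normal((2, 3, 4))  # (channel, lat, lon)
for lead, field in persistence_lead_times(x0, datetime.timedelta(hours=6), 3):
    print(lead, field.shape)
```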
- class earth2mip.forecasts.Forecast(*args, **kwargs)#
Bases:
Protocol
- property channel_names: List[str]#
- property grid: LatLonGrid#
- class earth2mip.forecasts.Persistence(observations)#
Bases:
object
Persistence forecast. This forecast always returns the initial condition.
Yields arrays of shape (channel, lat, lon).
- Parameters:
observations (Any) –
- property channel_names#
- class earth2mip.forecasts.TimeLoopForecast(time_loop, times, data_source)#
Bases:
Forecast
Wrap an fcn-mip TimeLoop object as a forecast
- Parameters:
time_loop (TimeLoop) –
times (Sequence[datetime]) –
data_source (DataSource) –
- property channel_names#
- property grid#
earth2mip.geo_operator module#
- class earth2mip.geo_operator.GeoOperator(*args, **kwargs)#
Bases:
Protocol
Geo Operator
This is the most primitive functional of Earth-2 MIP and represents an operator on geospatial fields. This implies the following two requirements:
- The operation must define in and out channel variables representing the fields in the input/output arrays.
- The operation must define the in and out grid schemas.
Many autoregressive models can be represented as a GeoOperator and can maintain an internal state. Diagnostic models must be a GeoOperator by definition.
Warning
Geo Function is a concept not yet fully adopted in Earth-2 MIP and is being adopted progressively.
- property in_channel_names: list[str]#
- property in_grid: LatLonGrid#
- property out_channel_names: list[str]#
- property out_grid: LatLonGrid#
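The sketch below shows how a class might satisfy the four properties above. Because GeoOperator is a typing Protocol, conformance is structural: any object exposing these properties matches.

This is a sketch under stated assumptions:
- The LatLonGrid dataclass is a local stand-in for earth2mip.grid.LatLonGrid.
- GeoOperatorLike re-declares only the four listed properties.
- KelvinToCelsius is a hypothetical operator.
- Its __call__ method is an assumption, since the protocol's call signature is not documented here.

```python
from dataclasses import dataclass
from typing import List, Protocol, runtime_checkable

import numpy as np


@dataclass
class LatLonGrid:  # local stand-in for earth2mip.grid.LatLonGrid
    lat: List[float]
    lon: List[float]

    @property
    def shape(self):
        return (len(self.lat), len(self.lon))


@runtime_checkable
class GeoOperatorLike(Protocol):
    @property
    def in_channel_names(self) -> list[str]: ...
    @property
    def in_grid(self) -> LatLonGrid: ...
    @property
    def out_channel_names(self) -> list[str]: ...
    @property
    def out_grid(self) -> LatLonGrid: ...


class KelvinToCelsius:
    """Hypothetical diagnostic: converts t2m from K to degC on the same grid."""

    def __init__(self, grid: LatLonGrid):
        self._grid = grid

    @property
    def in_channel_names(self) -> list[str]:
        return ["t2m"]

    @property
    def out_channel_names(self) -> list[str]:
        return ["t2m_celsius"]

    @property
    def in_grid(self) -> LatLonGrid:
        return self._grid

    @property
    def out_grid(self) -> LatLonGrid:
        return self._grid

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (..., channel, lat, lon) with channels ordered as in_channel_names
        return x - 273.15


op = KelvinToCelsius(LatLonGrid([90.0, 0.0, -90.0], [0.0, 180.0]))
```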
earth2mip.geometry module#
Routines for working with geometry
- earth2mip.geometry.bilinear(data, dims, source_coords, target_coords)#
- Parameters:
data (tensor) –
- earth2mip.geometry.get_batch_size(data)#
- earth2mip.geometry.get_bounds_window(geom, lat, lon)#
- earth2mip.geometry.select_space(data, lat, lon, domain)#
earth2mip.grid module#
- class earth2mip.grid.LatLonGrid(lat: List[float], lon: List[float])#
Bases:
object
- Parameters:
lat (List[float]) –
lon (List[float]) –
- lat: List[float]#
- lon: List[float]#
- property shape#
- earth2mip.grid.equiangular_lat_lon_grid(nlat, nlon, includes_south_pole=True)#
A regular lat-lon grid
Lat is ordered from 90 to -90 and includes -90 if and only if includes_south_pole is True. Lon is ordered from 0 to 360; it includes 0 but not 360.
- Parameters:
nlat (int) –
nlon (int) –
includes_south_pole (bool) –
- Return type:
LatLonGrid
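The ordering rules above can be reproduced with numpy.linspace. The sketch uses the 721x1440 grid that appears in the metadata.json example later in this document. The helper name equiangular_lat_lon is hypothetical. The linspace construction is an assumption that is consistent with the documented ordering, not a copy of the library code.

```python
import numpy as np


def equiangular_lat_lon(nlat, nlon, includes_south_pole=True):
    """Sketch of the documented ordering: lat 90 -> -90, lon 0 -> 360 (exclusive)."""
    # -90 is the last latitude only when includes_south_pole is True.
    lat = np.linspace(90.0, -90.0, nlat, endpoint=includes_south_pole)
    # 360 is the same meridian as 0, so it is always excluded.
    lon = np.linspace(0.0, 360.0, nlon, endpoint=False)
    return lat, lon


lat, lon = equiangular_lat_lon(721, 1440)
print(lat[0], lat[-1], lon[0], lon[-1])  # 90.0 -90.0 0.0 359.75
```

With includes_south_pole=False and nlat=720, the same 0.25 degree spacing stops at -89.75.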
- earth2mip.grid.from_enum(grid_enum)#
- Parameters:
grid_enum (Grid) –
- Return type:
earth2mip.inference_ensemble module#
- earth2mip.inference_ensemble.run_basic_inference(model, n, data_source, time)#
Run a basic inference
- Parameters:
model (TimeLoop) –
n (int) –
data_source (Any) –
time (datetime) –
- earth2mip.inference_ensemble.run_inference(model, config, perturb=None, group=None, progress=True, data_source=None)#
Run an ensemble inference for a given config and a perturb function
- Parameters:
group (Any | None) – the torch distributed group to use for the calculation
progress (bool) – if True use tqdm to show a progress bar
data_source (Any | None) – a Mapping object indexed by datetime and returning an xarray.Dataset object.
model (TimeLoop) –
config (EnsembleRun) –
perturb (Any | None) –
earth2mip.inference_medium_range module#
- earth2mip.inference_medium_range.score_deterministic(model, n, initial_times, data_source, time_mean)#
Compute deterministic ACCs and RMSEs.
- Parameters:
model (TimeLoop) – the inference class
n (int) – the number of lead times
initial_times – the initial_times to compute over
data_source – a mapping from time to dataset, used for the initial condition and the scoring
time_mean – a (channel, lat, lon) numpy array containing the time_mean. Used for ACC.
- Returns:
an xarray dataset with this structure:
netcdf dlwp.baseline {
dimensions:
    lead_time = 57 ;
    channel = 7 ;
    initial_time = 1 ;
variables:
    int64 lead_time(lead_time) ;
        lead_time:units = "hours" ;
    string channel(channel) ;
    double acc(lead_time, channel) ;
        acc:_FillValue = NaN ;
    double rmse(lead_time, channel) ;
        rmse:_FillValue = NaN ;
    int64 initial_times(initial_time) ;
        initial_times:units = "days since 2018-11-30 12:00:00" ;
        initial_times:calendar = "proleptic_gregorian" ;
}
- Return type:
metrics
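score_deterministic returns ACC and RMSE per lead time and channel. The sketch below shows what these scores usually mean for a single (lat, lon) field. It uses cos(latitude) area weights and computes the anomaly correlation against time_mean. Both the weighting and the uncentered anomaly definition are common conventions assumed here. earth2mip's own implementation may differ in those details, and the helper names are hypothetical.

```python
import numpy as np


def _lat_weights(lat_deg):
    """cos(latitude) weights normalized to mean 1, shaped to broadcast over lon."""
    w = np.cos(np.deg2rad(np.asarray(lat_deg, dtype=float)))
    return (w / w.mean())[:, None]


def weighted_rmse(pred, truth, lat_deg):
    """Latitude-weighted RMSE of two (lat, lon) fields."""
    w = _lat_weights(lat_deg)
    return float(np.sqrt(np.mean(w * (pred - truth) ** 2)))


def weighted_acc(pred, truth, time_mean, lat_deg):
    """Latitude-weighted anomaly correlation, using time_mean as the climatology."""
    w = _lat_weights(lat_deg)
    pa, ta = pred - time_mean, truth - time_mean
    return float(np.sum(w * pa * ta) / np.sqrt(np.sum(w * pa**2) * np.sum(w * ta**2)))


lat = np.array([60.0, 0.0, -60.0])
truth = np.random.default_rng(0).standard_normal((3, 4))
clim = np.zeros((3, 4))
print(weighted_rmse(truth + 1.0, truth, lat))  # approximately 1.0
```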
earth2mip.loaders module#
- class earth2mip.loaders.LoaderProtocol(*args, **kwargs)#
Bases:
Protocol
- earth2mip.loaders.torchscript(package, pretrained=True)#
Load a checkpoint into a model.
earth2mip.make_job module#
- earth2mip.make_job.get_time(times)#
- earth2mip.make_job.get_time_s2s_calibration()#
- earth2mip.make_job.get_times_2018()#
- earth2mip.make_job.get_times_s2s_test()#
- earth2mip.make_job.main(model, config, output)#
- Parameters:
model (str) –
config (str) –
output (str) –
earth2mip.model_registry module#
Create-read-update-delete (CRUD) operations for the FCN model registry
The location of the registry is configured using config.MODEL_REGISTRY. Both s3:// and local paths are supported.
The top-level structure of the registry is like this:
afno_26ch_v/
baseline_afno_26/
gfno_26ch_sc3_layers8_tt64/
hafno_baseline_26ch_edim512_mlp2/
modulus_afno_20/
sfno_73ch/
tfno_no-patching_lr5e-4_full_epochs/
The name of the model is the folder name. Each of these folders has the following structure:
sfno_73ch/about.txt # optional information (e.g. source path)
sfno_73ch/global_means.npy
sfno_73ch/global_stds.npy
sfno_73ch/weights.tar # model checkpoint
sfno_73ch/metadata.json
The metadata.json file contains data necessary to use the model for forecasts:
{
"architecture": "sfno_73ch",
"n_history": 0,
"grid": "721x1440",
"in_channels": [
0,
1
],
"out_channels": [
0,
1
]
}
Its schema is provided by earth2mip.schema.Model.
The checkpoint file weights.tar should have a dictionary of model weights and parameters in the model_state key. For backwards compatibility with FCN checkpoints produced as of March 1, 2023, the keys should include the module. prefix. This checkpoint format may change in the future.
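The directory layout and metadata.json fields above can be produced with the standard library alone. This is a sketch under stated assumptions:
- write_package_metadata, list_model_names, and add_module_prefix are hypothetical helpers, not the ModelRegistry API.
- The metadata values are copied from the example above.

```python
import json
import pathlib
import tempfile


def add_module_prefix(state_dict):
    """Prefix each key with 'module.' unless it already has it."""
    return {
        k if k.startswith("module.") else "module." + k: v
        for k, v in state_dict.items()
    }


def write_package_metadata(registry_root, name, metadata):
    """Create <registry_root>/<name>/metadata.json and return the package folder."""
    package = pathlib.Path(registry_root) / name
    package.mkdir(parents=True, exist_ok=True)
    (package / "metadata.json").write_text(json.dumps(metadata, indent=2))
    return package


def list_model_names(registry_root):
    """A model's name is its folder name at the top level of the registry."""
    return sorted(p.name for p in pathlib.Path(registry_root).iterdir() if p.is_dir())


meta = {
    "architecture": "sfno_73ch",
    "n_history": 0,
    "grid": "721x1440",
    "in_channels": [0, 1],
    "out_channels": [0, 1],
}
with tempfile.TemporaryDirectory() as root:
    package = write_package_metadata(root, "sfno_73ch", meta)
    names = list_model_names(root)
    loaded = json.loads((package / "metadata.json").read_text())
```

The prefixed state dictionary would then be stored under the model_state key of weights.tar, for example with torch.save({"model_state": prefixed}, path).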
Scoring FCNs under active development#
One can use fcn-mip to score models not packaged in fcn-mip using a metadata file like this:
{
"architecture": "pickle",
...
}
This will load weights.tar using torch.load. This is not recommended for long-term archival of model checkpoints, but it does allow scoring models under active development. Once reasonable skill is achieved, the model's source code can be stabilized and packaged within fcn-mip for long-term archival.
- earth2mip.model_registry.DLWPPackage(root, seperator)#
- Parameters:
root (str) –
seperator (str) –
- earth2mip.model_registry.FCNPackage(root, seperator)#
- Parameters:
root (str) –
seperator (str) –
- earth2mip.model_registry.FCNv2Package(root, seperator)#
- Parameters:
root (str) –
seperator (str) –
- class earth2mip.model_registry.ModelRegistry(path)#
Bases:
object
- Parameters:
path (str) –
- SEPERATOR: str = '/'#
- get_builtin_model(name)#
Built-in models that have globally buildable packages.
- Parameters:
name (str) –
- get_center_path(name)#
- Parameters:
name (str) –
- get_metadata(name)#
- Parameters:
name (str) –
- Return type:
Model
- get_model(name)#
- Parameters:
name (str) –
- get_model_path(name)#
- Parameters:
name (str) –
- get_path(name, *args)#
- get_scale_path(name)#
- Parameters:
name (str) –
- get_weight_path(name)#
- Parameters:
name (str) –
- list_models()#
- put_metadata(name, metadata)#
- Parameters:
name (str) –
metadata (Model) –
- earth2mip.model_registry.NGCDiagnosticPackage(root, seperator, name)#
- Parameters:
root (str) –
seperator (str) –
name (Literal['precipitation_afno', 'climatenet']) –
- class earth2mip.model_registry.Package(root, seperator)#
Bases:
object
A model package
Simple file system operations and quick metadata access
- Parameters:
root (str) –
seperator (str) –
- get(path, recursive=False)#
- Parameters:
recursive (bool) –
- metadata()#
- Return type:
Model
- earth2mip.model_registry.PanguPackage(root, seperator)#
- Parameters:
root (str) –
seperator (str) –
- earth2mip.model_registry.download_ngc_package(root, url, zip_file)#
- Parameters:
root (str) –
url (str) –
zip_file (str) –
earth2mip.netcdf module#
Routines to save domains to a netCDF file
- earth2mip.netcdf.initialize_netcdf(nc, domains, grid, n_ensemble, device)#
- Parameters:
domains (Iterable[Window | CWBDomain | MultiPoint]) –
grid (LatLonGrid) –
- Return type:
List[List[Diagnostics]]
- earth2mip.netcdf.update_netcdf(data, total_diagnostics, domains, batch_id, time_count, grid, channel_names_of_data)#
- Parameters:
data (Tensor) –
total_diagnostics (List[List[Diagnostics]]) –
domains (List[Window | CWBDomain | MultiPoint]) –
grid (LatLonGrid) –
channel_names_of_data (List[str]) –
earth2mip.regrid module#
- class earth2mip.regrid.Identity(*args, **kwargs)#
Bases:
Module
- forward(x)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class earth2mip.regrid.RegridLatLon(src_grid, dest_grid)#
Bases:
Module
- Parameters:
src_grid (LatLonGrid) –
dest_grid (LatLonGrid) –
- forward(x)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class earth2mip.regrid.TempestRegridder(file_path)#
Bases:
Module
- forward(x)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- to(device)#
Moves and/or casts the parameters and buffers.
This can be called as
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
- to(memory_format=torch.channels_last)
Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.
Note
This method modifies the module in-place.
- Parameters:
device (torch.device) – the desired device of the parameters and buffers in this module
dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module
tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- Returns:
self
- Return type:
Module
Examples:
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- earth2mip.regrid.get_regridder(src, dest)#
- Parameters:
src (LatLonGrid) –
dest (LatLonGrid) –
- Return type:
Module
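get_regridder(src, dest) returns a torch Module whose algorithm is not documented here. The sketch below illustrates what regridding between two lat-lon grids involves. It performs separable bilinear interpolation with NumPy, treats longitude as periodic, and handles the 90 to -90 latitude ordering. It is a hedged illustration only. The function name and the bilinear choice are assumptions, not a description of RegridLatLon or TempestRegridder.

```python
import numpy as np


def regrid_bilinear(field, src_lat, src_lon, dst_lat, dst_lon):
    """Bilinearly interpolate a (lat, lon) field from one regular grid to another."""
    src_lat = np.asarray(src_lat, dtype=float)
    # 1) Interpolate along longitude; period=360 wraps e.g. 315 between 270 and 360 (=0).
    along_lon = np.stack(
        [np.interp(dst_lon, src_lon, row, period=360.0) for row in field]
    )  # shape (n_src_lat, n_dst_lon)
    # 2) Interpolate along latitude; np.interp needs increasing sample points.
    order = np.argsort(src_lat)
    return np.stack(
        [np.interp(dst_lat, src_lat[order], col[order]) for col in along_lon.T], axis=1
    )  # shape (n_dst_lat, n_dst_lon)


src_lat = np.linspace(90.0, -90.0, 5)                 # 90, 45, 0, -45, -90
src_lon = np.linspace(0.0, 360.0, 4, endpoint=False)  # 0, 90, 180, 270
field = np.add.outer(src_lat, np.zeros_like(src_lon))  # value equals latitude
out = regrid_bilinear(field, src_lat, src_lon, [60.0, 0.0, -30.0], [0.0, 45.0, 315.0])
```

Because bilinear interpolation reproduces linear functions exactly, a field equal to latitude is carried over unchanged.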
earth2mip.schema module#
- class earth2mip.schema.EnsembleRun(*, weather_model, simulation_length, perturbation_strategy=PerturbationStrategy.correlated, perturbation_channels=None, noise_reddening=2.0, noise_amplitude=0.05, output_frequency=1, output_grid=None, ensemble_members=1, seed=1, ensemble_batch_size=1, forecast_name=None, weather_event=None, output_dir=None, output_path=None, restart_frequency=None, grf_noise_alpha=2.0, grf_noise_sigma=5.0, grf_noise_tau=2.0)#
Bases:
BaseModel
A configuration for running an ensemble weather forecast
- Parameters:
weather_model (str) –
simulation_length (int) –
perturbation_strategy (PerturbationStrategy) –
perturbation_channels (List[str] | None) –
noise_reddening (float) –
noise_amplitude (float) –
output_frequency (int) –
output_grid (Grid | None) –
ensemble_members (int) –
seed (int) –
ensemble_batch_size (int) –
forecast_name (str | None) –
weather_event (WeatherEvent | None) –
output_dir (str | None) –
output_path (str | None) –
restart_frequency (int | None) –
grf_noise_alpha (float) –
grf_noise_sigma (float) –
grf_noise_tau (float) –
- weather_model#
The name of the fully convolutional neural network (FCN) model to use for the forecast.
- Type:
str
- ensemble_members#
The number of ensemble members to use in the forecast.
- Type:
int
- noise_amplitude#
The amplitude of the Gaussian noise to add to the initial conditions.
- Type:
float
- noise_reddening#
The noise reddening amplitude; 2.0 was the default set by A.G.'s work.
- Type:
float
- simulation_length#
The length of the simulation in timesteps.
- Type:
int
- output_frequency#
The frequency at which to write the output to file, in timesteps.
- Type:
int
- use_cuda_graphs#
Whether to use CUDA graphs to optimize the computation.
- seed#
The random seed for the simulation.
- Type:
int
- ensemble_batch_size#
The batch size to use for the ensemble.
- Type:
int
- autocast_fp16#
Whether to use automatic mixed precision (AMP) with FP16 data types.
- perturbation_strategy#
The strategy to use for perturbing the initial conditions.
- perturbation_channels#
Channel(s) perturbed by the initial condition perturbation strategy; None means all channels.
- Type:
List[str] | None
- forecast_name#
The name of the forecast to use (alternative to weather_event).
- Type:
optional
- weather_event#
The weather event to use for the forecast (alternative to forecast_name).
- Type:
optional
- output_dir#
The directory to save the output files in (alternative to output_path).
- Type:
optional
- output_path#
The path to the output file (alternative to output_dir).
- Type:
optional
- restart_frequency#
If provided, save a restart at the end of the run and at the specified frequency; 0 means only save at the end.
- Type:
int | None
- grf_noise_alpha#
Tuning parameter of the Gaussian random field; see ensemble_utils.generate_noise_grf for details.
- Type:
float
- grf_noise_sigma#
Tuning parameter of the Gaussian random field; see ensemble_utils.generate_noise_grf for details.
- Type:
float
- grf_noise_tau#
Tuning parameter of the Gaussian random field; see ensemble_utils.generate_noise_grf for details.
- Type:
float
- ensemble_batch_size: int#
- ensemble_members: int#
- forecast_name: str | None#
- get_weather_event()#
- Return type:
- grf_noise_alpha: float#
- grf_noise_sigma: float#
- grf_noise_tau: float#
- noise_amplitude: float#
- noise_reddening: float#
- output_dir: str | None#
- output_frequency: int#
- output_grid: Grid | None#
- output_path: str | None#
- perturbation_channels: List[str] | None#
- perturbation_strategy: PerturbationStrategy#
- restart_frequency: int | None#
- seed: int#
- simulation_length: int#
- weather_event: WeatherEvent | None#
- weather_model: str#
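The sketch below assembles an EnsembleRun configuration as a plain dictionary. It shows how the nested WeatherEvent, WeatherEventProperties, Window, and Diagnostic fields fit together. Only the field names come from the signatures in this document. The model name, dates, channel, and the "raw" diagnostic type are placeholders. Because EnsembleRun is a pydantic BaseModel, EnsembleRun(**config) should validate such a dictionary, though this sketch does not import earth2mip.

```python
import json

# Hypothetical EnsembleRun configuration: field names follow the documented
# signature; all values are placeholders chosen for illustration.
config = {
    "weather_model": "sfno_73ch",
    "simulation_length": 20,       # timesteps
    "output_frequency": 1,         # write every timestep
    "ensemble_members": 4,
    "ensemble_batch_size": 2,
    "seed": 1,
    "perturbation_strategy": "gaussian",
    "noise_amplitude": 0.05,
    "weather_event": {             # alternative to forecast_name
        "properties": {
            "name": "example_event",
            "start_time": "2018-01-01 00:00:00",
            "initial_condition_source": "era5",
        },
        "domains": [
            {
                "type": "Window",
                "name": "global",  # Window bounds default to the whole globe
                "diagnostics": [{"type": "raw", "channels": ["t2m"]}],
            }
        ],
    },
    "output_path": "outputs/example",  # alternative to output_dir
}

text = json.dumps(config, indent=2)
```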
- class earth2mip.schema.InferenceEntrypoint(*, name='', kwargs=None)#
Bases:
BaseModel
- Attrs:
- name: an entrypoint string like my_package:model_entrypoint. This points to a function model_entrypoint(package) which returns an Inference object given a package.
- kwargs: the arguments to pass to the constructor
- Parameters:
name (str) –
kwargs (Mapping[Any, Any]) –
- kwargs: Mapping[Any, Any]#
- name: str#
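An entrypoint string of the form my_package:model_entrypoint can be resolved with importlib. The sketch below is a hypothetical helper and does not reproduce earth2mip's own loader. It splits on the colon, imports the module, then walks any dotted attribute path.

```python
import importlib


def resolve_entrypoint(name: str):
    """Resolve a "module:attribute" string (hypothetical helper)."""
    module_name, _, attr_path = name.partition(":")
    obj = importlib.import_module(module_name)
    for part in attr_path.split("."):  # allow nested attributes such as "Class.method"
        obj = getattr(obj, part)
    return obj


sqrt = resolve_entrypoint("math:sqrt")
print(sqrt(9.0))  # 3.0
```

Consider an InferenceEntrypoint named entry. One would then presumably call resolve_entrypoint(entry.name)(package, **entry.kwargs). Passing kwargs alongside the package this way is an assumption, not documented behavior.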
- class earth2mip.schema.InitialConditionSource(value)#
Bases:
Enum
An enumeration.
- cds: str = 'cds'#
- era5: str = 'era5'#
- gfs: str = 'gfs'#
- hrmip: str = 'hrmip'#
- ifs: str = 'ifs'#
- class earth2mip.schema.PerturbationStrategy(value)#
Bases:
Enum
An enumeration.
- bred_vector = 'bred_vector'#
- gaussian = 'gaussian'#
- none = 'none'#
- spherical_grf = 'spherical_grf'#
- class earth2mip.schema.WeatherEvent(*, properties, domains)#
Bases:
BaseModel
- Parameters:
properties (WeatherEventProperties) –
domains (List[Window | CWBDomain | MultiPoint]) –
- domains: List[Window | CWBDomain | MultiPoint]#
- properties: WeatherEventProperties#
earth2mip.score_ensemble_outputs module#
- earth2mip.score_ensemble_outputs.main(input_path, output_path=None, time_averaging_window='', score=True, save_ensemble=False)#
- Parameters:
input_path (str) –
output_path (str | None) –
time_averaging_window (str) –
score (bool) –
save_ensemble (bool) –
- Return type:
None
- earth2mip.score_ensemble_outputs.open_ensemble(path, group)#
- earth2mip.score_ensemble_outputs.open_verification(time)#
- earth2mip.score_ensemble_outputs.read_weather_event(dir)#
- earth2mip.score_ensemble_outputs.save_dataset(out, path)#
earth2mip.time module#
- earth2mip.time.convert_to_datetime(time)#
- Return type:
datetime
- earth2mip.time.datetime_to_timestamp(time)#
- Parameters:
time (datetime) –
- Return type:
float
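datetime_to_timestamp has no docstring. The usual conversion to a POSIX timestamp looks like the sketch below. The naive-datetime case matters: datetime.timestamp() interprets naive values in the local time zone, so the sketch explicitly assumes naive datetimes are UTC. Whether datetime_to_timestamp makes the same assumption is not documented here, and the helper name to_timestamp is hypothetical.

```python
import datetime


def to_timestamp(time: datetime.datetime) -> float:
    """Seconds since the Unix epoch, treating naive datetimes as UTC (an assumption)."""
    if time.tzinfo is None:
        time = time.replace(tzinfo=datetime.timezone.utc)
    return time.timestamp()


print(to_timestamp(datetime.datetime(1970, 1, 2)))  # 86400.0
```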
earth2mip.time_collection module#
- earth2mip.time_collection.run_over_initial_times(*, time_loop, data_source, initial_times, config, output_path, time_averaging_window='', score=False, save_ensemble=False, shard=0, n_shards=1, n_post_processing_workers=32)#
Perform a set of forecasts across many initial conditions in parallel with post processing
Once complete, the data at output_path can be opened as an xarray object using earth2mip.datasets.hindcast.open_forecast(). Parallelizes across the available GPUs using MPI, and can be further parallelized across multiple MPI jobs using the shard/n_shards flags. It can be resumed after interruption.
- Parameters:
time_loop (TimeLoop) – the earth2mip TimeLoop to be evaluated. Often returned by earth2mip.networks.get_model
data_source (DataSource | None) – the data source used to initialize the time_loop; overrides any data source specified in config
initial_times (list[datetime]) – the initial times evaluated over
time_averaging_window (str) – if provided, average the output over this interval. Same syntax as pandas.Timedelta (e.g. "2w"). Default is no time averaging.
score (bool) – if true, score the times during the post processing
save_ensemble (bool) – if true, then save all the ensemble members in addition to the mean
shard (int) – index of the shard. Useful for SLURM array jobs
n_shards (int) – the total number of shards; the input times are split into this many shards
n_post_processing_workers (int) – The number of dask distributed workers to devote to ensemble post processing.
config (EnsembleRun) –
output_path (str) –
None
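The shard/n_shards parameters split initial_times across independent jobs. The sketch below shows one plausible round-robin split and checks that the shards are disjoint and together cover every time. The split scheme and the helper name select_shard are assumptions: run_over_initial_times may divide the times differently, for example into contiguous blocks.

```python
def select_shard(initial_times, shard, n_shards):
    """Return this shard's share of initial_times (round-robin; hypothetical scheme)."""
    if not 0 <= shard < n_shards:
        raise ValueError(f"shard must be in [0, {n_shards}), got {shard}")
    return initial_times[shard::n_shards]


times = list(range(10))  # stand-ins for datetime initial times
shards = [select_shard(times, i, 3) for i in range(3)]
print(shards)
```

In a SLURM array job, each task could take its shard index from the SLURM_ARRAY_TASK_ID environment variable.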
earth2mip.time_loop module#
- class earth2mip.time_loop.GeoTensorInfo(channel_names, grid, n_history_levels=1, history_time_step=datetime.timedelta(0))#
Bases:
object
Metadata explaining how a tensor maps onto the Earth.
Describes a tensor x with shape (batch, history, channel, lat, lon).
- Parameters:
channel_names (List[str]) –
grid (LatLonGrid) –
n_history_levels (int) –
history_time_step (timedelta) –
- channel_names: List[str]#
- grid: LatLonGrid#
- history_time_step: timedelta = datetime.timedelta(0)#
- n_history_levels: int = 1#
- class earth2mip.time_loop.TimeLoop(*args, **kwargs)#
Bases:
Protocol
Abstract protocol that a custom time loop must follow
This is a callable which yields time and output information. Some attributes are required to define the input and output data required.
The expectation is that this class and the data passed to it are on the same device. While torch modules can be moved between devices easily, this is not true for all frameworks.
- in_channel_names#
- Type:
List[str]
- out_channel_names#
- Type:
List[str]
- grid#
- n_history_levels#
- Type:
int
- history_time_step#
- Type:
datetime.timedelta
- time_step#
- Type:
datetime.timedelta
- device#
- Type:
torch.device
- device: device#
- dtype: dtype = torch.float32#
- grid: LatLonGrid#
- history_time_step: timedelta = datetime.timedelta(0)#
- in_channel_names: List[str]#
- n_history_levels: int = 1#
- out_channel_names: List[str]#
- time_step: timedelta#
- class earth2mip.time_loop.TimeStepper(*args, **kwargs)#
Bases:
Protocol[StateT]
A functional interface that can be used for time stepping:
state -> (state, output)
This uses a generic state, but concrete Tensors as input and output. This allows users to directly control the time-stepping logic and potentially modify the state in a model-specific manner, while the basic initial conditions and running outputs are concrete torch Tensors.
One example is the GraphCast time stepper. GraphCast uses JAX and xarray to handle the state.
It should be used like this:
stepper = MyStepper()
state = stepper.initialize(x, time)
outputs = []
for i in range(10):
    state, output = stepper.step(state)
    outputs.append(output)
One benefit is that the state can be saved and reloaded trivially to restart the simulation.
- property device: device#
- property dtype: device#
- initialize(x, time)#
x is described by self.input_info.
- Parameters:
x (Tensor) –
time (datetime) –
- Return type:
StateT
- property input_info: GeoTensorInfo#
- property output_info: GeoTensorInfo#
- step(state)#
Step the state and return the ML output as a tensor. The output tensor is described by self.output_info.
- Parameters:
state (StateT) –
- Return type:
tuple[StateT, Tensor]
- property time_step: timedelta#
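The sketch below follows the initialize/step pattern above with NumPy in place of torch. It also shows a generator that turns the stepper into a stream of (time, output) pairs, in the spirit of TimeStepperLoop. The toy dynamics and all names (ToyState, DampingStepper, run_loop) are hypothetical. How TimeStepperLoop actually yields its outputs is an assumption. Keeping the state as plain data is what makes saving and restarting simple.

```python
import datetime
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyState:
    """Plain-data state: easy to save and reload for restarts."""
    time: datetime.datetime
    x: np.ndarray


class DampingStepper:
    """Hypothetical TimeStepper-like object: each step scales the field by `damping`."""

    time_step = datetime.timedelta(hours=6)

    def __init__(self, damping: float = 0.5):
        self.damping = damping

    def initialize(self, x: np.ndarray, time: datetime.datetime) -> ToyState:
        return ToyState(time=time, x=x.copy())

    def step(self, state: ToyState) -> tuple[ToyState, np.ndarray]:
        new_x = self.damping * state.x
        return ToyState(time=state.time + self.time_step, x=new_x), new_x


def run_loop(stepper, x, time, n_steps):
    """Yield (valid_time, output) after each step (assumed loop behavior)."""
    state = stepper.initialize(x, time)
    for _ in range(n_steps):
        state, output = stepper.step(state)
        yield state.time, output


outputs = list(
    run_loop(DampingStepper(0.5), np.ones((1, 2, 2)), datetime.datetime(2022, 1, 1), 3)
)
```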
- class earth2mip.time_loop.TimeStepperLoop(stepper)#
Bases:
TimeLoop
Turn a TimeStepper into a TimeLoop
- Parameters:
stepper (TimeStepper) –
- property device: device#
- property dtype: dtype#
- property grid: LatLonGrid#
- property history_time_step: timedelta#
- property in_channel_names: List[str]#
- property n_history_levels: int#
- property out_channel_names: List[str]#
- property time_step: timedelta#
earth2mip.weather_events module#
- class earth2mip.weather_events.CWBDomain(*, type, name, path='/lustre/fsw/sw_climate_fno/nbrenowitz/2023-01-24-cwb-4years.zarr', diagnostics)#
Bases:
BaseModel
- Parameters:
type (Literal['CWBDomain']) –
name (str) –
path (str) –
diagnostics (List[Diagnostic]) –
- diagnostics: List[Diagnostic]#
- name: str#
- path: str#
- type: Literal['CWBDomain']#
- class earth2mip.weather_events.Diagnostic(*, type, function='', channels, nbins=10)#
Bases:
BaseModel
- Parameters:
type (str) –
function (str) –
channels (List[str]) –
nbins (int) –
- channels: List[str]#
- function: str#
- nbins: int#
- type: str#
- class earth2mip.weather_events.InitialConditionSource(value)#
Bases:
Enum
An enumeration.
- cds: str = 'cds'#
- era5: str = 'era5'#
- gfs: str = 'gfs'#
- hrmip: str = 'hrmip'#
- ifs: str = 'ifs'#
- class earth2mip.weather_events.MultiPoint(*, type, name, lat, lon, diagnostics)#
Bases:
BaseModel
- Parameters:
type (Literal['MultiPoint']) –
name (str) –
lat (List[float]) –
lon (List[float]) –
diagnostics (List[Diagnostic]) –
- diagnostics: List[Diagnostic]#
- lat: List[float]#
- lon: List[float]#
- name: str#
- type: Literal['MultiPoint']#
- class earth2mip.weather_events.WeatherEvent(*, properties, domains)#
Bases:
BaseModel
- Parameters:
properties (WeatherEventProperties) –
domains (List[Window | CWBDomain | MultiPoint]) –
- domains: List[Window | CWBDomain | MultiPoint]#
- properties: WeatherEventProperties#
- class earth2mip.weather_events.WeatherEventProperties(*, name, start_time=None, initial_condition_source=InitialConditionSource.era5, netcdf='', restart='')#
Bases:
BaseModel
- Parameters:
name (str) –
start_time (datetime | None) –
initial_condition_source (InitialConditionSource) –
netcdf (str) –
restart (str) –
- netcdf#
Load the initial conditions from this path if given.
- Type:
str
- initial_condition_source: InitialConditionSource#
- name: str#
- netcdf: str#
- restart: str#
- start_time: datetime | None#
- class earth2mip.weather_events.Window(*, type='Window', name, lat_min=-90, lat_max=90, lon_min=0, lon_max=360, diagnostics)#
Bases:
BaseModel
- Parameters:
type (Literal['Window']) –
name (str) –
lat_min (float) –
lat_max (float) –
lon_min (float) –
lon_max (float) –
diagnostics (List[Diagnostic]) –
- diagnostics: List[Diagnostic]#
- lat_max: float#
- lat_min: float#
- lon_max: float#
- lon_min: float#
- name: str#
- type: Literal['Window']#
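A Window is a latitude/longitude box whose defaults (lat_min=-90, lat_max=90, lon_min=0, lon_max=360) cover the globe. The sketch below selects the grid points inside such a box on a 90 to -90, 0 to 360 grid. Inclusive bounds and the lack of dateline wrap-around are assumptions, and the helper name window_mask is hypothetical. earth2mip's own selection (e.g. earth2mip.geometry.select_space) may behave differently.

```python
import numpy as np


def window_mask(lat, lon, lat_min=-90.0, lat_max=90.0, lon_min=0.0, lon_max=360.0):
    """Boolean (lat, lon) mask for a rectangular window (inclusive bounds, no wrap)."""
    lat = np.asarray(lat, dtype=float)[:, None]
    lon = np.asarray(lon, dtype=float)[None, :]
    return (lat >= lat_min) & (lat <= lat_max) & (lon >= lon_min) & (lon <= lon_max)


lat = np.linspace(90.0, -90.0, 7)                 # 90, 60, ..., -90
lon = np.linspace(0.0, 360.0, 8, endpoint=False)  # 0, 45, ..., 315
mask = window_mask(lat, lon, lat_min=0, lat_max=60, lon_min=90, lon_max=180)
rows = lat[mask.any(axis=1)]  # latitudes inside the window
```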
- earth2mip.weather_events.list_()#
- earth2mip.weather_events.read(forecast_name)#
- Parameters:
forecast_name (str) –
- Return type: