earth2mip.diagnostic package#

earth2mip.diagnostic.base module#

class earth2mip.diagnostic.base.DiagnosticBase(*args, **kwargs)#

Bases: Module, GeoOperator

Diagnostic model base class

classmethod load_config_type(*args, **kwargs)#

Class function used to access the Pydantic config class of the diagnostic, if one has been implemented. Note that this returns a class reference, not an instantiated object.

Return type:

Type[BaseModel]

classmethod load_diagnostic(package, *args, **kwargs)#

Class function used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.

Note

This function always accepts a package as the first argument. For many functions this may be irrelevant or could be None, but the requirement persists for a consistent API and data flow.

Parameters:

package (Package | None) –

classmethod load_package(*args, **kwargs)#

Class function used to create the diagnostic model package (if needed). Any explicit download functions should be orchestrated here.

Return type:

Package
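The two-step pattern described above (build a package, then load a diagnostic from it) can be sketched with plain Python stand-ins. ToyPackage and ToyDiagnostic are hypothetical names invented for this illustration; they are not part of earth2mip and do not subclass DiagnosticBase.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyPackage:
    """Hypothetical stand-in for a Package: it only holds a scale factor."""

    scale: float


class ToyDiagnostic:
    """Hypothetical diagnostic following the load_package / load_diagnostic pattern."""

    def __init__(self, scale: float):
        self.scale = scale

    @classmethod
    def load_package(cls) -> ToyPackage:
        # A real diagnostic would orchestrate any downloads here.
        return ToyPackage(scale=2.0)

    @classmethod
    def load_diagnostic(cls, package: ToyPackage | None) -> ToyDiagnostic:
        # The package is always the first argument, even when it may be None.
        scale = package.scale if package is not None else 1.0
        return cls(scale)

    def __call__(self, values: list[float]) -> list[float]:
        # Toy "diagnostic": multiply every value by the loaded scale.
        return [self.scale * v for v in values]


package = ToyDiagnostic.load_package()
diagnostic = ToyDiagnostic.load_diagnostic(package)
result = diagnostic([1.0, 2.5])
```

Keeping the package argument even when it is None lets calling code treat every diagnostic the same way.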

earth2mip.diagnostic.climate_net module#

class earth2mip.diagnostic.climate_net.BNPReLU(nOut)#

Bases: Module

Batch Norm with PReLU layer.

forward(input)#
Parameters:
  • input – input feature map

  • return – normalized and thresholded feature map
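As a rough illustration of what "normalized and thresholded" means here, the sketch below applies a simplified batch normalization followed by PReLU to a flat list of values. It is an assumption-laden toy: it omits the learnable affine parameters of batch norm, and it fixes the PReLU slope at 0.25 even though the slope is normally learned.

```python
import math


def batch_norm(values, eps=1e-5):
    """Normalize values to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def prelu(value, slope=0.25):
    """Keep non-negative values unchanged; scale negative values by the slope."""
    return value if value >= 0 else slope * value


def bn_prelu(values, slope=0.25):
    """Batch norm followed by PReLU, applied element-wise."""
    return [prelu(v, slope) for v in batch_norm(values)]


activated = bn_prelu([1.0, 2.0, 3.0])
```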

class earth2mip.diagnostic.climate_net.CGNetModule(classes=19, channels=4, M=3, N=21, dropout_flag=False)#

Bases: Module

CGNet (Wu et al., 2018: https://arxiv.org/pdf/1811.08201.pdf) implementation. This is taken from the authors' implementation; we do not claim credit for it.

forward(input)#
Parameters:
  • input – Receives the input RGB image

  • return – segmentation map

class earth2mip.diagnostic.climate_net.ChannelWiseConv(nIn, nOut, kSize, stride=1)#

Bases: Module

Channel Wise convolutional layer.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.ChannelWiseDilatedConv(nIn, nOut, kSize, stride=1, d=1)#

Bases: Module

Channel-wise Convolutional Layer with Dilation.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.ClimateNet(model, in_center, in_scale)#

Bases: DiagnosticBase

Climate Net diagnostic model, built into Earth-2 MIP. This model can be used to create prediction labels for tropical cyclones and atmospheric rivers. It produces the non-standard output channels climnet_bg, climnet_tc and climnet_ar, representing the background, tropical cyclone and atmospheric river labels.

Note

This model and checkpoint are from Prabhat et al. 2021, https://doi.org/10.5194/gmd-14-107-2021, and the andregraubner/ClimateNet repository.

Example

>>> package = ClimateNet.load_package()
>>> model = ClimateNet.load_diagnostic(package)
>>> x = torch.randn(1, 4, 721, 1440)
>>> out = model(x)
>>> out.shape
(1, 3, 721, 1440)

Parameters:
  • model (Module) –

  • in_center (Tensor) –

  • in_scale (Tensor) –

property in_channel_names: list[str]#
property in_grid: LatLonGrid#
classmethod load_diagnostic(package, device='cuda:0')#

Class function used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.

Note

This function always accepts a package as the first argument. For many functions this may be irrelevant or could be None, but the requirement persists for a consistent API and data flow.

Parameters:

package (Package) –

classmethod load_package(registry='/root/.cache/earth2mip/models/diagnostics')#

Class function used to create the diagnostic model package (if needed). Any explicit download functions should be orchestrated here.

Parameters:

registry (str) –

Return type:

Package

property out_channel_names: list[str]#
property out_grid: LatLonGrid#
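One way to turn the three ClimateNet output channels into a single label per grid point is to pick the channel with the largest value. The sketch below does this on nested lists. It assumes that the channels hold per-class scores and that they are ordered climnet_bg, climnet_tc, climnet_ar; neither assumption is confirmed here, and label_pixel and label_map are hypothetical helpers.

```python
# Assumed channel order, taken from the order in which the names are listed above.
CLASS_NAMES = ["climnet_bg", "climnet_tc", "climnet_ar"]


def label_pixel(scores):
    """Return the name of the class with the largest score."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return CLASS_NAMES[best]


def label_map(channels):
    """Convert nested lists shaped [class][lat][lon] into a [lat][lon] grid of names."""
    n_lat = len(channels[0])
    n_lon = len(channels[0][0])
    return [
        [label_pixel([c[i][j] for c in channels]) for j in range(n_lon)]
        for i in range(n_lat)
    ]


# Made-up scores on a 1 x 2 grid.
scores = [
    [[0.7, 0.1]],  # climnet_bg
    [[0.2, 0.1]],  # climnet_tc
    [[0.1, 0.8]],  # climnet_ar
]
labels = label_map(scores)
```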
class earth2mip.diagnostic.climate_net.ContextGuidedBlock(nIn, nOut, dilation_rate=2, reduction=16, add=True)#

Bases: Module

Context Guided Block.

forward(input)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.climate_net.ContextGuidedBlock_Down(nIn, nOut, dilation_rate=2, reduction=16)#

Bases: Module

Downsampling block. The spatial size of the feature map is halved and the number of channels is doubled: (H, W, C) -> (H/2, W/2, 2C).

forward(input)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.climate_net.Conv(nIn, nOut, kSize, stride=1)#

Bases: Module

Ordinary Convolutional Layer.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.ConvBN(nIn, nOut, kSize, stride=1)#

Bases: Module

Convolutional Layer with Batch Norm.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.ConvBNPReLU(nIn, nOut, kSize, stride=1)#

Bases: Module

Convolutional Layer with Batch Norm and PReLU.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.DilatedConv(nIn, nOut, kSize, stride=1, d=1)#

Bases: Module

Convolutional Layer with Dilation.

forward(input)#
Parameters:
  • input – input feature map

  • return – transformed feature map

class earth2mip.diagnostic.climate_net.FGlo(channel, reduction=16)#

Bases: Module

The FGlo class is employed to refine the joint feature of both local feature and surrounding context.

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.climate_net.InputInjection(downsamplingRatio)#

Bases: Module

Inject Input with pooling.

forward(input)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.climate_net.Wrap(padding)#

Bases: Module

Climate net wrapper for padding.

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

earth2mip.diagnostic.precipitation_afno module#

class earth2mip.diagnostic.precipitation_afno.PeriodicPad2d(pad_width)#

Bases: Module

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.precipitation_afno.PrecipNet(*args, **kwargs)#

Bases: Module

forward(x)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class earth2mip.diagnostic.precipitation_afno.PrecipitationAFNO(model, in_center, in_scale)#

Bases: DiagnosticBase

Precipitation AFNO model. Predicts the total precipitation parameter, which is the accumulated amount of liquid and frozen water (rain or snow), in units of m.

Note

This checkpoint is from Pathak et al. 2022, https://arxiv.org/abs/2202.11214, and the NVlabs/FourCastNet repository.

Example

>>> package = PrecipitationAFNO.load_package()
>>> model = PrecipitationAFNO.load_diagnostic(package)
>>> x = torch.randn(1, 4, 720, 1440)
>>> out = model(x)
>>> out.shape
(1, 1, 721, 1440)

Parameters:
  • model (Module) –

  • in_center (Tensor) –

  • in_scale (Tensor) –

property in_channel_names: list[str]#
property in_grid: LatLonGrid#
classmethod load_diagnostic(package, device='cuda:0')#

Class function used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.

Note

This function always accepts a package as the first argument. For many functions this may be irrelevant or could be None, but the requirement persists for a consistent API and data flow.

Parameters:

package (Package) –

classmethod load_package(registry='/root/.cache/earth2mip/models/diagnostics')#

Class function used to create the diagnostic model package (if needed). Any explicit download functions should be orchestrated here.

Parameters:

registry (str) –

Return type:

Package

property out_channel_names: list[str]#
property out_grid: LatLonGrid#

earth2mip.diagnostic.time_loop module#

class earth2mip.diagnostic.time_loop.DiagnosticTimeLoop(diagnostics, model, concat=True)#

Bases: TimeLoop

Diagnostic Timeloop. This is an iterator that executes a list of diagnostic models on top of a model Timeloop.

Note

Presently, grids must be consistent between the diagnostics and the model.

Parameters:
  • diagnostics (List[DiagnosticBase]) – List of diagnostic functions to execute

  • model (TimeLoop) – Model inferencer iterator

  • concat (bool, optional) – Concatenate diagnostic outputs with model outputs. Defaults to True.

property device#
property grid#
property in_channel_names#
property out_channel_names#
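The idea of running diagnostics on top of a model iterator can be sketched with plain Python generators. Everything in this block is a hypothetical stand-in: toy_model, t2m_anomaly and diagnostic_loop are not earth2mip APIs, the channel values are made up, and the behaviour shown for concat=False (yielding only the diagnostic outputs) is an assumption.

```python
def toy_model(n_steps):
    """Hypothetical model loop: yields (step, channels) with made-up values."""
    for step in range(n_steps):
        yield step, {"t2m": 285.0 + step}


def t2m_anomaly(channels):
    """Hypothetical diagnostic: temperature relative to a fixed 280.0 reference."""
    return {"t2m_anomaly": channels["t2m"] - 280.0}


def diagnostic_loop(diagnostics, model, concat=True):
    """Apply each diagnostic to every step that the model iterator yields."""
    for step, channels in model:
        derived = {}
        for diagnostic in diagnostics:
            derived.update(diagnostic(channels))
        # With concat, the diagnostic outputs are appended to the model outputs.
        yield step, ({**channels, **derived} if concat else derived)


steps = list(diagnostic_loop([t2m_anomaly], toy_model(2)))
only_diagnostics = list(diagnostic_loop([t2m_anomaly], toy_model(1), concat=False))
```

Because the wrapper is itself an iterator, it can be consumed wherever the underlying model loop would be.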

earth2mip.diagnostic.utils module#

earth2mip.diagnostic.utils.filter_channels(input, in_channels, out_channels)#

Utility function used for selecting a subset of channels.

Note

Right now this assumes that the channels are in the third-to-last axis.

Parameters:
  • input (torch.Tensor) – Input tensor of shape […, channels, lat, lon]

  • in_channels (list[str]) – Input channel list

  • out_channels (list[str]) – Output channel list

Return type:

Tensor
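Selecting channels by name comes down to looking up the position of each requested name in the input channel list. The sketch below shows that lookup on nested lists shaped [channels][lat][lon]. It is not the library's implementation, and channel_indices and select_channels are hypothetical helper names; on a tensor, the same indices would be applied to the third-to-last axis.

```python
def channel_indices(in_channels, out_channels):
    """Return the position of each requested channel in the input channel list."""
    missing = [name for name in out_channels if name not in in_channels]
    if missing:
        raise ValueError(f"channels not found in input: {missing}")
    return [in_channels.index(name) for name in out_channels]


def select_channels(data, in_channels, out_channels):
    """Pick channels from nested lists shaped [channels][lat][lon]."""
    return [data[i] for i in channel_indices(in_channels, out_channels)]


in_names = ["u10m", "v10m", "t2m"]
data = [[[1.0]], [[2.0]], [[3.0]]]  # three channels on a 1 x 1 grid
indices = channel_indices(in_names, ["t2m", "u10m"])
selected = select_channels(data, in_names, ["t2m", "u10m"])
```

The output follows the order of the requested names, not the order of the input list.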

earth2mip.diagnostic.wind_speed module#

class earth2mip.diagnostic.wind_speed.WindSpeed(level, grid)#

Bases: DiagnosticBase

Computes the wind speed at a given level. This is largely just an example of what a diagnostic calculation could look like.

Example

>>> windspeed = WindSpeed('10m', Grid.grid_721x1440)
>>> x = torch.randn(1, 2, 721, 1440)
>>> out = windspeed(x)
>>> out.shape
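(1, 1, 721, 1440)

Parameters:
  • level – level at which the wind speed is computed, e.g. '10m'

  • grid – grid of the input data, e.g. Grid.grid_721x1440
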
property in_channel_names: list[str]#
property in_grid: LatLonGrid#
classmethod load_diagnostic(package, level, grid)#

Class function used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.

Note

This function always accepts a package as the first argument. For many functions this may be irrelevant or could be None, but the requirement persists for a consistent API and data flow.

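Parameters:
  • package – package, accepted as the first argument for API consistency (see the note above)

  • level – level at which the wind speed is computed

  • grid – grid of the input data
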
property out_channel_names: list[str]#
property out_grid: LatLonGrid#
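The wind-speed calculation itself is the magnitude of the horizontal wind vector, sqrt(u^2 + v^2). The sketch below computes it point by point on nested lists shaped [lat][lon]. It assumes the two input channels are the u and v wind components at the chosen level, which this page does not state explicitly, and wind_speed here is a hypothetical helper rather than the WindSpeed class.

```python
import math


def wind_speed(u, v):
    """Element-wise wind speed from u and v grids given as nested lists [lat][lon]."""
    return [
        [math.hypot(u_val, v_val) for u_val, v_val in zip(u_row, v_row)]
        for u_row, v_row in zip(u, v)
    ]


# Made-up wind components on a 1 x 2 grid.
u = [[3.0, 0.0]]
v = [[4.0, -2.0]]
speed = wind_speed(u, v)
```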

Module contents#