earth2mip.diagnostic package
earth2mip.diagnostic.base module
- class earth2mip.diagnostic.base.DiagnosticBase(*args, **kwargs)
Bases: Module, GeoOperator
Diagnostic model base class.
- classmethod load_config_type(*args, **kwargs)
Class method used to access the Pydantic config class of the diagnostic, if one has been implemented. Note that this returns a class reference, not an instantiated object.
- Return type:
Type[BaseModel]
- classmethod load_diagnostic(package, *args, **kwargs)
Class method used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.
Note
This method always accepts a package as the first argument. For many diagnostics the package is irrelevant and can be None, but the argument is still required for a consistent API and data flow.
- Parameters:
package (Package | None) –
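The base class describes a small contract: a diagnostic declares its input and output channel names and grids, is created through a load_diagnostic class method whose first argument is a package (possibly None), and is then called on an array. The sketch below illustrates that contract with a hypothetical stand-in class, ToyDiagnostic, which is not part of earth2mip: it uses NumPy arrays and plain (nlat, nlon) tuples in place of DiagnosticBase, Package and LatLonGrid, and simply sums two channels.

```python
import numpy as np


class ToyDiagnostic:
    """Hypothetical stand-in that mirrors the DiagnosticBase contract.

    It is NOT earth2mip code: grids are plain (nlat, nlon) tuples and
    arrays are NumPy arrays of shape [batch, channels, lat, lon].
    """

    def __init__(self, nlat: int, nlon: int):
        self._grid = (nlat, nlon)

    @property
    def in_channel_names(self) -> list[str]:
        return ["a", "b"]

    @property
    def out_channel_names(self) -> list[str]:
        return ["a_plus_b"]

    @property
    def in_grid(self) -> tuple[int, int]:
        return self._grid

    @property
    def out_grid(self) -> tuple[int, int]:
        return self._grid

    @classmethod
    def load_diagnostic(cls, package, nlat: int = 721, nlon: int = 1440):
        # The package argument always comes first, even though this toy
        # diagnostic has no checkpoint and ignores it (None is fine).
        return cls(nlat, nlon)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Sum the two input channels, keeping a channel axis of size 1.
        return x[:, 0:1] + x[:, 1:2]


diag = ToyDiagnostic.load_diagnostic(None, nlat=4, nlon=8)
x = np.ones((1, len(diag.in_channel_names), *diag.in_grid))
out = diag(x)
print(out.shape)  # (1, 1, 4, 8)
```

Passing None as the package mirrors the note above: the argument is part of the calling convention even when a diagnostic has nothing to load.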
earth2mip.diagnostic.climate_net module
- class earth2mip.diagnostic.climate_net.BNPReLU(nOut)
Bases: Module
Batch Norm with PReLU layer.
- forward(input)
- Parameters:
input – input feature map
- Returns:
normalized and thresholded feature map
- class earth2mip.diagnostic.climate_net.CGNetModule(classes=19, channels=4, M=3, N=21, dropout_flag=False)
Bases: Module
CGNet (Wu et al., 2018: https://arxiv.org/pdf/1811.08201.pdf) implementation. This code is taken from the authors' implementation; we do not claim credit for it.
- forward(input)
- Parameters:
input – input image tensor, whose channel count is set by channels (the original CGNet takes an RGB image)
- Returns:
segmentation map
- class earth2mip.diagnostic.climate_net.ChannelWiseConv(nIn, nOut, kSize, stride=1)
Bases: Module
Channel-wise convolutional layer.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
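Channel-wise convolutions are cheap because each filter only sees a slice of the input channels. Assuming ChannelWiseConv is a bias-free grouped convolution with one group per input channel (groups=nIn), which is how CGNet's channel-wise layer is usually written but is not stated above, the weight count compared with an ordinary convolution works out as follows. The helper conv_weight_count is hypothetical.

```python
def conv_weight_count(n_in: int, n_out: int, k: int, groups: int = 1) -> int:
    """Weights in a bias-free 2D convolution with a k x k kernel.

    Each of the n_out filters only sees n_in // groups input channels.
    """
    if n_in % groups or n_out % groups:
        raise ValueError("n_in and n_out must be divisible by groups")
    return n_out * (n_in // groups) * k * k


n_in = n_out = 64
ordinary = conv_weight_count(n_in, n_out, 3)                   # 64 * 64 * 9
channel_wise = conv_weight_count(n_in, n_out, 3, groups=n_in)  # 64 * 1 * 9
print(ordinary, channel_wise)  # 36864 576
```

Under this assumption the channel-wise layer uses nIn times fewer weights, at the cost of not mixing information across channels.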
- class earth2mip.diagnostic.climate_net.ChannelWiseDilatedConv(nIn, nOut, kSize, stride=1, d=1)
Bases: Module
Channel-wise convolutional layer with dilation.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
- class earth2mip.diagnostic.climate_net.ClimateNet(model, in_center, in_scale)
Bases: DiagnosticBase
ClimateNet diagnostic model, built into Earth-2 MIP. This model can be used to create prediction labels for tropical cyclones and atmospheric rivers. It produces the non-standard output channels climnet_bg, climnet_tc and climnet_ar, which represent the background, tropical cyclone and atmospheric river labels, respectively.
Note
This model and checkpoint are from Prabhat et al. (2021), https://doi.org/10.5194/gmd-14-107-2021, and the andregraubner/ClimateNet repository.
Example
>>> import torch
>>> from earth2mip.diagnostic.climate_net import ClimateNet
>>> package = ClimateNet.load_package()
>>> model = ClimateNet.load_diagnostic(package)
>>> x = torch.randn(1, 4, 721, 1440, device="cuda:0")
>>> out = model(x)
>>> out.shape
torch.Size([1, 3, 721, 1440])
- Parameters:
model (Module) –
in_center (Tensor) –
in_scale (Tensor) –
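The constructor takes in_center and in_scale alongside the network, which suggests the usual standardization of inputs before the forward pass. The NumPy sketch below assumes that (x - in_center) / in_scale is applied per channel and that the three output channels (climnet_bg, climnet_tc, climnet_ar) hold per-class scores, so an integer label map can be taken with an argmax over the channel axis. Both points are assumptions, and standardize, to_label_map and LABELS are hypothetical names.

```python
import numpy as np

# Hypothetical label ids, following the documented output channel order.
LABELS = {0: "climnet_bg", 1: "climnet_tc", 2: "climnet_ar"}


def standardize(x: np.ndarray, center: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Per-channel standardization; center and scale have shape [1, C, 1, 1]."""
    return (x - center) / scale


def to_label_map(scores: np.ndarray) -> np.ndarray:
    """[batch, 3, lat, lon] class scores -> [batch, lat, lon] integer labels."""
    return scores.argmax(axis=1)


x = np.full((1, 4, 2, 3), 5.0)
center = np.full((1, 4, 1, 1), 1.0)
scale = np.full((1, 4, 1, 1), 2.0)
print(standardize(x, center, scale)[0, 0, 0, 0])  # 2.0

scores = np.zeros((1, 3, 2, 3))
scores[0, 1, 0, 0] = 1.0  # the tropical cyclone class wins at one grid point
labels = to_label_map(scores)
print(LABELS[int(labels[0, 0, 0])])  # climnet_tc
```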
- property in_channel_names: list[str]
- property in_grid: LatLonGrid
- classmethod load_diagnostic(package, device='cuda:0')
Class method used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.
Note
This method always accepts a package as the first argument. For many diagnostics the package is irrelevant and can be None, but the argument is still required for a consistent API and data flow.
- Parameters:
package (Package) –
- classmethod load_package(registry='/root/.cache/earth2mip/models/diagnostics')
Class method used to create the diagnostic model package (if needed). Any explicit download logic should be orchestrated here.
- Parameters:
registry (str) –
- Return type:
- property out_channel_names: list[str]
- property out_grid: LatLonGrid
- class earth2mip.diagnostic.climate_net.ContextGuidedBlock(nIn, nOut, dilation_rate=2, reduction=16, add=True)
Bases: Module
Context Guided Block.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class earth2mip.diagnostic.climate_net.ContextGuidedBlock_Down(nIn, nOut, dilation_rate=2, reduction=16)
Bases: Module
Halves the spatial size of the feature map: (H, W, C) -> (H/2, W/2, 2C).
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
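The H/2 in the shape above is exact only for even sizes. Assuming the downsampling is done by a stride-2 convolution with a 3x3 kernel and padding 1 (an assumption about this block's internals, not stated above), the standard convolution output-size formula gives the following for the 721 x 1440 grid used in the ClimateNet example. The channel count c = 32 is an arbitrary illustrative value.

```python
def conv_out_size(size: int, k: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of a convolution along one spatial axis (PyTorch convention)."""
    return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1


h, w, c = 721, 1440, 32
h2 = conv_out_size(h, k=3, stride=2, padding=1)
w2 = conv_out_size(w, k=3, stride=2, padding=1)
print((h, w, c), "->", (h2, w2, 2 * c))  # (721, 1440, 32) -> (361, 720, 64)
```

An odd latitude count such as 721 therefore rounds up to 361 rather than halving exactly.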
- class earth2mip.diagnostic.climate_net.Conv(nIn, nOut, kSize, stride=1)
Bases: Module
Ordinary Convolutional Layer.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
- class earth2mip.diagnostic.climate_net.ConvBN(nIn, nOut, kSize, stride=1)
Bases: Module
Convolutional Layer with Batch Norm.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
- class earth2mip.diagnostic.climate_net.ConvBNPReLU(nIn, nOut, kSize, stride=1)
Bases: Module
Convolutional Layer with Batch Norm and PReLU.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
- class earth2mip.diagnostic.climate_net.DilatedConv(nIn, nOut, kSize, stride=1, d=1)
Bases: Module
Convolutional Layer with Dilation.
- forward(input)
- Parameters:
input – input feature map
- Returns:
transformed feature map
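Dilation spreads the kernel taps apart, enlarging the receptive field without adding weights. The sketch below computes the span of a dilated kernel and checks that, assuming padding of ((k - 1) // 2) * d for odd k (an assumption about this layer's internals), a stride-1 dilated convolution preserves the spatial size. Both helpers are hypothetical.

```python
def effective_kernel(k: int, d: int) -> int:
    """Span of a k-tap kernel with dilation d: k + (k - 1) * (d - 1)."""
    return k + (k - 1) * (d - 1)


def dilated_out_size(size: int, k: int, d: int, stride: int = 1) -> int:
    # Assumed "same"-style padding for odd k: ((k - 1) // 2) * d.
    padding = ((k - 1) // 2) * d
    return (size + 2 * padding - d * (k - 1) - 1) // stride + 1


for d in (1, 2, 4):
    print(d, effective_kernel(3, d), dilated_out_size(361, 3, d))
# 1 3 361
# 2 5 361
# 4 9 361
```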
- class earth2mip.diagnostic.climate_net.FGlo(channel, reduction=16)
Bases: Module
Refines the joint feature formed from the local feature and its surrounding context.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
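As we understand the CGNet design, this refinement is a channel-attention step: the feature map is averaged globally per channel, passed through a small two-layer bottleneck whose width is set by reduction, squashed with a sigmoid, and used to rescale each channel (squeeze-and-excitation style). The NumPy sketch below illustrates that idea with random weights; global_context_refine and its weight layout are hypothetical and not taken from earth2mip.

```python
import numpy as np


def global_context_refine(x: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Squeeze-and-excitation style channel reweighting (hypothetical sketch).

    x:  feature map of shape [batch, C, H, W]
    w1: bottleneck weights of shape [C // reduction, C]
    w2: expansion weights of shape [C, C // reduction]
    """
    pooled = x.mean(axis=(2, 3))                   # [batch, C], global average pool
    hidden = np.maximum(pooled @ w1.T, 0.0)        # [batch, C // reduction], ReLU
    gate = 1.0 / (1.0 + np.exp(-(hidden @ w2.T)))  # [batch, C], sigmoid gate
    return x * gate[:, :, None, None]              # rescale every channel


rng = np.random.default_rng(0)
channels, reduction = 32, 16
x = rng.standard_normal((2, channels, 8, 16))
w1 = rng.standard_normal((channels // reduction, channels))
w2 = rng.standard_normal((channels, channels // reduction))
y = global_context_refine(x, w1, w2)
print(y.shape)  # (2, 32, 8, 16)
```

Because the gate lies between 0 and 1, the refinement can only damp channels, never amplify them.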
- class earth2mip.diagnostic.climate_net.InputInjection(downsamplingRatio)
Bases: Module
Injects the input after downsampling it with pooling.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class earth2mip.diagnostic.climate_net.Wrap(padding)
Bases: Module
ClimateNet wrapper for padding.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
earth2mip.diagnostic.precipitation_afno module
- class earth2mip.diagnostic.precipitation_afno.PeriodicPad2d(pad_width)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
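Periodic padding exploits the fact that longitude wraps around, so columns copied from the opposite edge let convolutions see across the date line. The NumPy sketch below assumes circular padding along longitude (the last axis) and zero padding along latitude, the convention we understand FourCastNet's own PeriodicPad2d to follow; periodic_pad_2d is a hypothetical illustration rather than this module's exact behavior.

```python
import numpy as np


def periodic_pad_2d(x: np.ndarray, pad_width: int) -> np.ndarray:
    """Pad [..., lat, lon]: wrap in longitude, zeros in latitude (assumed)."""
    lead = [(0, 0)] * (x.ndim - 2)
    # Longitude is periodic: copy columns from the opposite edge.
    x = np.pad(x, lead + [(0, 0), (pad_width, pad_width)], mode="wrap")
    # Latitude is not periodic: pad beyond the poles with zeros.
    return np.pad(x, lead + [(pad_width, pad_width), (0, 0)], mode="constant")


x = np.arange(6, dtype=float).reshape(1, 1, 2, 3)  # rows [0, 1, 2] and [3, 4, 5]
y = periodic_pad_2d(x, 1)
# Row [0, 1, 2] becomes [2, 0, 1, 2, 0]: the last column wraps to the front
# and the first column wraps to the back; new top and bottom rows are zero.
print(y.shape)  # (1, 1, 4, 5)
```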
- class earth2mip.diagnostic.precipitation_afno.PrecipNet(*args, **kwargs)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class earth2mip.diagnostic.precipitation_afno.PrecipitationAFNO(model, in_center, in_scale)
Bases: DiagnosticBase
Precipitation AFNO model. Predicts the total precipitation parameter, which is the accumulated amount of liquid and frozen water (rain or snow), in units of m.
Note
This checkpoint is from Pathak et al. (2022), https://arxiv.org/abs/2202.11214, and the NVlabs/FourCastNet repository.
Example
>>> import torch
>>> from earth2mip.diagnostic.precipitation_afno import PrecipitationAFNO
>>> package = PrecipitationAFNO.load_package()
>>> model = PrecipitationAFNO.load_diagnostic(package)
>>> x = torch.randn(1, len(model.in_channel_names), 720, 1440, device="cuda:0")
>>> out = model(x)
>>> out.shape
torch.Size([1, 1, 720, 1440])
- Parameters:
model (Module) –
in_center (Tensor) –
in_scale (Tensor) –
- property in_channel_names: list[str]
- property in_grid: LatLonGrid
- classmethod load_diagnostic(package, device='cuda:0')
Class method used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.
Note
This method always accepts a package as the first argument. For many diagnostics the package is irrelevant and can be None, but the argument is still required for a consistent API and data flow.
- Parameters:
package (Package) –
- classmethod load_package(registry='/root/.cache/earth2mip/models/diagnostics')
Class method used to create the diagnostic model package (if needed). Any explicit download logic should be orchestrated here.
- Parameters:
registry (str) –
- Return type:
- property out_channel_names: list[str]
- property out_grid: LatLonGrid
earth2mip.diagnostic.time_loop module
- class earth2mip.diagnostic.time_loop.DiagnosticTimeLoop(diagnostics, model, concat=True)
Bases: TimeLoop
Diagnostic TimeLoop. This is an iterator that executes a list of diagnostic models on top of a model TimeLoop.
Note
Presently, the grids of the diagnostics and the model must be consistent.
- Parameters:
diagnostics (List[DiagnosticBase]) – List of diagnostic functions to execute
model (TimeLoop) – Model inferencer iterator
concat (bool, optional) – Concatenate diagnostic outputs with model outputs. Defaults to True.
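Conceptually, the diagnostic time loop takes each state produced by the wrapped model, runs every diagnostic on it, and, with concat=True, appends the diagnostic channels to the model channels. The generator below is a hypothetical NumPy sketch of that flow: it does not reproduce earth2mip's TimeLoop interface, and the (time, state) pairs it consumes, as well as fake_model and speed, are assumptions made for illustration.

```python
from typing import Callable, Iterable, Iterator

import numpy as np

Diagnostic = Callable[[np.ndarray], np.ndarray]


def diagnostic_time_loop(
    model_steps: Iterable[tuple[int, np.ndarray]],
    diagnostics: list[Diagnostic],
    concat: bool = True,
) -> Iterator[tuple[int, np.ndarray]]:
    """Run each diagnostic on every model output of shape [batch, C, lat, lon]."""
    for time, state in model_steps:
        outputs = [diag(state) for diag in diagnostics]
        if concat:
            # Model channels first, then each diagnostic's channels.
            outputs = [state] + outputs
        yield time, np.concatenate(outputs, axis=1)


def fake_model(n_steps: int) -> Iterator[tuple[int, np.ndarray]]:
    # Stand-in for a model time loop: two channels (u, v) on a 2 x 3 grid.
    for t in range(n_steps):
        yield t, np.full((1, 2, 2, 3), float(t + 1))


def speed(x: np.ndarray) -> np.ndarray:
    return np.hypot(x[:, 0:1], x[:, 1:2])


steps = list(diagnostic_time_loop(fake_model(3), [speed]))
print([out.shape for _, out in steps])  # [(1, 3, 2, 3), (1, 3, 2, 3), (1, 3, 2, 3)]
```

Concatenating along the channel axis only works when every diagnostic returns data on the model's grid, which matches the note that grids must be consistent.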
- property device
- property grid
- property in_channel_names
- property out_channel_names
earth2mip.diagnostic.utils module
- earth2mip.diagnostic.utils.filter_channels(input, in_channels, out_channels)
Utility function for selecting a subset of channels.
Note
Currently this assumes that the channels are on the third-to-last axis.
- Parameters:
input (torch.Tensor) – Input tensor of shape […, channels, lat, lon]
in_channels (list[str]) – Input channel list
out_channels (list[str]) – Output channel list
- Return type:
Tensor
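With the channels on the third-to-last axis, the selection amounts to looking up the index of each requested name in in_channels and gathering those slices in order. The sketch below is a hypothetical NumPy re-implementation, filter_channels_np, not earth2mip's code; the channel names are illustrative only.

```python
import numpy as np


def filter_channels_np(x: np.ndarray, in_channels: list[str], out_channels: list[str]) -> np.ndarray:
    """Select out_channels (in that order) from x of shape [..., channels, lat, lon]."""
    missing = [name for name in out_channels if name not in in_channels]
    if missing:
        raise KeyError(f"channels not in input: {missing}")
    indices = [in_channels.index(name) for name in out_channels]
    return np.take(x, indices, axis=-3)


names = ["u10m", "v10m", "t2m"]
# Channel i is filled with the value i, so the selection is easy to check.
x = np.stack([np.full((2, 3), i, dtype=float) for i in range(3)])[None]  # [1, 3, 2, 3]
y = filter_channels_np(x, names, ["t2m", "u10m"])
print(y.shape)  # (1, 2, 2, 3)
```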
earth2mip.diagnostic.wind_speed module
- class earth2mip.diagnostic.wind_speed.WindSpeed(level, grid)
Bases: DiagnosticBase
Computes the wind speed at a given level. This is largely just an example of what a diagnostic calculation could look like.
Example
>>> import torch
>>> from earth2mip.diagnostic.wind_speed import WindSpeed
>>> windspeed = WindSpeed('10m', Grid.grid_721x1440)
>>> x = torch.randn(1, 2, 721, 1440)
>>> out = windspeed(x)
>>> out.shape
torch.Size([1, 1, 721, 1440])
- Parameters:
level (str) –
grid (LatLonGrid) –
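Wind speed is the magnitude of the horizontal wind vector, sqrt(u**2 + v**2), computed pointwise from the two input channels. A NumPy sketch, assuming the input holds the u and v components in that order on the channel axis (the actual channel names for a given level come from in_channel_names); wind_speed is a hypothetical helper:

```python
import numpy as np


def wind_speed(x: np.ndarray) -> np.ndarray:
    """[batch, 2, lat, lon] (u, v) -> [batch, 1, lat, lon] speed."""
    u, v = x[:, 0:1], x[:, 1:2]  # slicing keeps the channel axis
    return np.hypot(u, v)        # sqrt(u**2 + v**2), avoiding intermediate overflow


x = np.zeros((1, 2, 721, 1440))
x[:, 0], x[:, 1] = 3.0, 4.0
out = wind_speed(x)
print(out.shape, out[0, 0, 0, 0])  # (1, 1, 721, 1440) 5.0
```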
- property in_channel_names: list[str]
- property in_grid: LatLonGrid
- classmethod load_diagnostic(package, level, grid)
Class method used to load the diagnostic model onto device memory and create an instance of the diagnostic for use.
Note
This method always accepts a package as the first argument. For many diagnostics the package is irrelevant and can be None, but the argument is still required for a consistent API and data flow.
- Parameters:
package (Package | None) –
level (str) –
grid (LatLonGrid) –
- property out_channel_names: list[str]
- property out_grid: LatLonGrid