# Source code for parts.centaur.batch_norm
# Copyright (c) 2019 NVIDIA Corporation
import tensorflow as tf
class BatchNorm1D:
    """
    1D batch normalization layer.

    Thin wrapper around ``tf.layers.BatchNormalization``. On each call a
    size-1 axis is inserted at position 1 of the input, the wrapped layer is
    applied, and the inserted axis is removed again.
    """

    def __init__(self, *args, **kwargs):
        """Create the wrapped batch normalization layer.

        Args:
            *args: positional arguments forwarded unchanged to
                ``tf.layers.BatchNormalization``.
            **kwargs: keyword arguments forwarded unchanged to
                ``tf.layers.BatchNormalization``.
        """
        super(BatchNorm1D, self).__init__()
        self.norm = tf.layers.BatchNormalization(*args, **kwargs)

    def __call__(self, x, training):
        """Apply batch normalization to ``x``.

        Args:
            x: input tensor. Presumably shaped (batch, length, channels) --
                TODO confirm against callers.
            training: passed through to the wrapped layer as its
                ``training`` argument.

        Returns:
            The output of the wrapped layer with the temporary axis at
            position 1 squeezed out again.
        """
        with tf.variable_scope("batch_norm_1d"):
            # Insert a singleton axis at position 1 before normalizing.
            # NOTE(review): presumably done so the wrapped layer receives a
            # 4-D input -- confirm the motivation before changing.
            y = tf.expand_dims(x, axis=1)
            y = self.norm(y, training=training)
            # Drop the axis added above.
            y = tf.squeeze(y, axis=1)
            return y