commit
89ccf50bac
|
@ -70,6 +70,13 @@ class LocalizationNetwork(object):
|
|||
return initial_bias
|
||||
|
||||
def __call__(self, image):
|
||||
"""
|
||||
Estimating parameters of geometric transformation
|
||||
Args:
|
||||
image: input
|
||||
Return:
|
||||
batch_C_prime: the matrix of the geometric transformation
|
||||
"""
|
||||
F = self.F
|
||||
loc_lr = self.loc_lr
|
||||
if self.model_name == "large":
|
||||
|
@ -215,6 +222,14 @@ class GridGenerator(object):
|
|||
return batch_C_ex_part_tensor
|
||||
|
||||
def __call__(self, batch_C_prime, I_r_size):
|
||||
"""
|
||||
Generate the grid for the grid_sampler.
|
||||
Args:
|
||||
batch_C_prime: the matrix of the geometric transformation
|
||||
I_r_size: the shape of the input image
|
||||
Return:
|
||||
batch_P_prime: the grid for the grid_sampler
|
||||
"""
|
||||
C = self.build_C()
|
||||
P = self.build_P(I_r_size)
|
||||
inv_delta_C = self.build_inv_delta_C(C).astype('float32')
|
||||
|
|
Loading…
Reference in New Issue