tacco.utils.solve_OT
- solve_OT(a, b, M, epsilon=0.005, lambda_a=None, lambda_b=None, numItermax=1000, stopThr=1e-09, inplace=False)
Solve an optimal transport problem with entropy regularization and, optionally, Kullback-Leibler divergence penalty terms in place of exact marginal conservation. The problem is solved with the generalized Sinkhorn-Knopp matrix scaling algorithm from [Chizat16]. The Python implementation is based on [Flamary21].
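The scaling iterations behind this description can be sketched in NumPy as below. This is a minimal illustration, not tacco's code. It assumes the common formulation in which the Gibbs kernel is exp(-M/epsilon) and each KL-relaxed marginal enters its update with the exponent lambda/(lambda + epsilon). The function name `generalized_sinkhorn` and its relative-change stopping rule are hypothetical choices. It also omits the numerical stabilization a production solver would need for small epsilon.

```python
import numpy as np


def generalized_sinkhorn(a, b, M, epsilon=0.005, lambda_a=None, lambda_b=None,
                         numItermax=1000, stopThr=1e-9):
    """Sketch of generalized Sinkhorn-Knopp scaling (hypothetical, not tacco's code)."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)
    # Gibbs kernel. There is no log-domain stabilization here, so an epsilon
    # that is very small relative to the entries of M can underflow to zero.
    K = np.exp(-M / epsilon)
    # Exponent 1 enforces a marginal exactly (balanced OT); a finite lambda
    # gives lambda / (lambda + epsilon) < 1, i.e. a KL-relaxed marginal.
    fa = 1.0 if lambda_a is None else lambda_a / (lambda_a + epsilon)
    fb = 1.0 if lambda_b is None else lambda_b / (lambda_b + epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(numItermax):
        u_prev, v_prev = u, v
        u = (a / (K @ v)) ** fa
        v = (b / (K.T @ u)) ** fb
        # Assumed stopping rule: relative change of the scaling vectors.
        err_u = np.max(np.abs(u - u_prev)) / max(np.max(np.abs(u)), np.max(np.abs(u_prev)), 1.0)
        err_v = np.max(np.abs(v - v_prev)) / max(np.max(np.abs(v)), np.max(np.abs(v_prev)), 1.0)
        if max(err_u, err_v) < stopThr:
            break
    # Transport coupling T = diag(u) K diag(v).
    return u[:, None] * K * v[None, :]
```

For example, with a = b = [0.5, 0.5] and M = [[0, 1], [1, 0]], the balanced call (both lambdas None) returns a coupling whose rows and columns sum to 0.5, concentrated on the zero-cost diagonal. Passing finite lambda_a and lambda_b instead lets the row and column sums deviate from a and b.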
- Parameters:
a – A 1d array-like containing the left marginal distribution
b – A 1d array-like containing the right marginal distribution
M – A 2d array-like containing the loss matrix
epsilon – The entropy regularization parameter
lambda_a – The left marginal relaxation parameter; if None, the left marginal is enforced exactly, as in balanced OT
lambda_b – The right marginal relaxation parameter; if None, the right marginal is enforced exactly, as in balanced OT
numItermax – The maximum number of iterations
stopThr – The error threshold for the stopping criterion
inplace – Whether M is overwritten with the transport matrix upon completion or left unchanged. For in-place operation, M has to be an ndarray with dtype=np.float64.
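How epsilon, lambda_a and lambda_b interact can be written out for one common formulation of entropic unbalanced OT. This is an assumption: tacco's exact entropy term may differ, for instance by being taken relative to a b^T. The KL terms replace the marginal constraints, and a lambda of None corresponds to the limit λ → ∞, which restores the hard constraint:

```latex
\begin{aligned}
&\min_{T \ge 0}\; \langle T, M\rangle
  + \epsilon \sum_{i,j} T_{ij}\,(\log T_{ij} - 1)
  + \lambda_a\,\mathrm{KL}(T\mathbf{1} \,\|\, a)
  + \lambda_b\,\mathrm{KL}(T^\top\mathbf{1} \,\|\, b),\\
&\mathrm{KL}(x \,\|\, y) = \sum_i \Big( x_i \log\frac{x_i}{y_i} - x_i + y_i \Big),\\
&K = e^{-M/\epsilon},\quad
 u \leftarrow \Big(\frac{a}{K v}\Big)^{\frac{\lambda_a}{\lambda_a+\epsilon}},\quad
 v \leftarrow \Big(\frac{b}{K^\top u}\Big)^{\frac{\lambda_b}{\lambda_b+\epsilon}},\quad
 T = \mathrm{diag}(u)\,K\,\mathrm{diag}(v).
\end{aligned}
```

Divisions and powers are elementwise. As λ → ∞ the exponent tends to 1 and the updates reduce to standard balanced Sinkhorn. A smaller λ, or a larger ε, shrinks the exponent and loosens the corresponding marginal.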
- Returns:
A 2d ndarray containing the transport couplings. Depending on inplace, this is either M updated in place or a newly allocated array.
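The inplace option can be pictured as choosing the output buffer for the final coupling diag(u) K diag(v). The sketch below uses a hypothetical helper, `write_coupling`, which is not part of tacco. It assumes the scaling vectors u and v are already known. It illustrates why a float64 ndarray is needed: NumPy ufuncs called with `out=` write directly into that buffer, which must be able to hold float64 results without a copy or a lossy cast.

```python
import numpy as np


def write_coupling(M, u, v, epsilon, inplace=False):
    """Hypothetical helper: build T = diag(u) exp(-M / epsilon) diag(v).

    With inplace=True the result overwrites M's buffer; otherwise M is left
    unchanged and a newly allocated array is returned.
    """
    if inplace:
        # Only a float64 ndarray can receive the float64 results directly.
        if not isinstance(M, np.ndarray) or M.dtype != np.float64:
            raise TypeError("inplace=True requires M to be an np.ndarray with dtype=np.float64")
        out = M
    else:
        out = np.array(M, dtype=np.float64)  # independent copy; caller's M untouched
    np.divide(out, -epsilon, out=out)  # out = -M / epsilon
    np.exp(out, out=out)               # out = K, the Gibbs kernel
    out *= np.asarray(u, dtype=np.float64)[:, None]  # scale rows by u
    out *= np.asarray(v, dtype=np.float64)[None, :]  # scale columns by v
    return out
```

Returning `out` in both branches lets a caller use the return value uniformly. With inplace=True, the returned array is the same object as M.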