cot.transport_torch
- cot.transport_torch(DA, SB, C, eps, device)
This function computes an additive approximation to the optimal transport (OT) problem between two discrete distributions and returns the transport plan, the dual variables, and the total cost. It is a PyTorch implementation of the parallelizable combinatorial algorithm for OT from [2].
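For reference, the exact problem being approximated is the standard discrete OT linear program, written here with the parameter names used on this page. This is the textbook formulation, not a description of how [2] solves it. An additive approximation returns a feasible plan F whose cost exceeds the optimum of this program by at most an additive amount governed by eps; see [2] for the precise bound and any assumptions on the scale of C.

```latex
\min_{F \ge 0} \; \sum_{i=1}^{\mathrm{dim}_b} \sum_{j=1}^{\mathrm{dim}_a} C_{ij} F_{ij}
\quad \text{subject to} \quad
\sum_{j=1}^{\mathrm{dim}_a} F_{ij} = SB_i \;\; (i = 1,\dots,\mathrm{dim}_b),
\qquad
\sum_{i=1}^{\mathrm{dim}_b} F_{ij} = DA_j \;\; (j = 1,\dots,\mathrm{dim}_a).
```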
- Parameters:
DA (tensor, shape (dim_a,)) – A tensor of length dim_a giving the weights of samples from the demand distribution (type a); DA[i] is the mass of demand at the i-th type a vertex. DA should sum to 1.
SB (tensor, shape (dim_b,)) – A tensor of length dim_b giving the weights of samples from the source distribution (type b); SB[i] is the mass of supply at the i-th type b vertex. SB should sum to 1.
C (tensor, shape (dim_b, dim_a)) – A dim_b by dim_a cost matrix; C[i, j] is the cost between the i-th type b vertex and the j-th type a vertex.
eps (float) – The additive error allowed in the optimal transport distance, i.e. the value of \(\epsilon\) in [2].
device (torch.device) – The device on which the computation is executed, e.g. torch.device('cuda:0') for a GPU.
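The shape conventions above are easy to mix up, since C is indexed (type b, type a) while the weight vectors are listed a first. The sketch below builds random inputs that follow them. make_ot_inputs is a hypothetical helper, not part of cot; the Euclidean costs and the division by the largest entry are illustrative choices, and this page does not say whether transport_torch expects normalized costs or a particular dtype.

```python
import torch


def make_ot_inputs(dim_a: int, dim_b: int, device: torch.device, seed: int = 0):
    """Hypothetical helper: random (DA, SB, C) with the documented shapes."""
    g = torch.Generator().manual_seed(seed)

    # Demand weights DA (type a), shape (dim_a,), normalized to sum to 1.
    DA = torch.rand(dim_a, generator=g)
    DA = DA / DA.sum()

    # Supply weights SB (type b), shape (dim_b,), normalized to sum to 1.
    SB = torch.rand(dim_b, generator=g)
    SB = SB / SB.sum()

    # Cost matrix C, shape (dim_b, dim_a): rows index type b, columns type a.
    # Illustrative choice: Euclidean distances between random 2-D points,
    # divided by the largest entry so every cost lies in [0, 1].
    pts_a = torch.rand(dim_a, 2, generator=g)
    pts_b = torch.rand(dim_b, 2, generator=g)
    C = torch.cdist(pts_b, pts_a)
    C = C / C.max()

    # Move everything to the requested device.
    return DA.to(device), SB.to(device), C.to(device)
```

With real data, the three tensors would then be passed in the documented order, e.g. `F, yA, yB, total_cost = cot.transport_torch(DA, SB, C, eps, device)`.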
- Returns:
F (tensor, shape (dim_b, dim_a)) – A dim_b by dim_a tensor; F[i, j] is the flow (transport plan) between the i-th type b vertex and the j-th type a vertex.
yA (tensor, shape (dim_a,)) – A tensor of length dim_a; yA[i] is the final dual weight of the i-th type a vertex.
yB (tensor, shape (dim_b,)) – A tensor of length dim_b; yB[i] is the final dual weight of the i-th type b vertex.
total_cost (float) – The total cost of the returned transport plan F, which approximates the optimal transport cost up to the additive error eps.
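The outputs can be cross-checked against each other. The sketch below assumes, which this page does not state explicitly, that F is a feasible plan (its row sums match SB and its column sums match DA) and that total_cost equals the sum over all i, j of F[i, j] * C[i, j]. check_plan and its tolerance are hypothetical, not part of cot.

```python
import torch


def check_plan(F, DA, SB, C, total_cost, tol=1e-5):
    """Hypothetical sanity check on the outputs of transport_torch."""
    # F and C share the layout (dim_b, dim_a): row i is type b, column j is type a.
    assert F.shape == C.shape == (SB.shape[0], DA.shape[0])

    # Row sums: mass shipped out of each supply (type b) vertex.
    supply_err = (F.sum(dim=1) - SB).abs().max().item()
    # Column sums: mass delivered to each demand (type a) vertex.
    demand_err = (F.sum(dim=0) - DA).abs().max().item()

    # Cost of the plan recomputed from F and C; total_cost may be a float or 0-d tensor.
    recomputed = (F * C).sum().item()
    cost_err = abs(recomputed - float(total_cost))

    return {
        "supply_err": supply_err,
        "demand_err": demand_err,
        "cost_err": cost_err,
        "ok": max(supply_err, demand_err, cost_err) <= tol,
    }
```

Marginal errors of the order of floating-point rounding are expected on real outputs; a loose tol keeps the check from flagging them.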
References