cot.transport_torch

cot.transport_torch(DA, SB, C, eps, device)

This function solves an additive approximation of the optimal transport (OT) problem between two discrete distributions and returns the transport plan, the dual variables, and the total cost. It is a PyTorch implementation of the parallelizable combinatorial algorithm for OT from [2].
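
For reference, the problem being approximated is the standard discrete OT linear program. It is written here with the indexing used by C and F on this page (rows are type b vertices, columns are type a vertices). The second display reads the eps parameter as the additive error on the optimal cost, as its description states; it is a restatement of that description, not a bound quoted from [2].

\[
\begin{aligned}
\mathrm{OPT} = \min_{F \ge 0} \quad & \sum_{i=1}^{\mathrm{dim}_b} \sum_{j=1}^{\mathrm{dim}_a} C_{ij} F_{ij} \\
\text{subject to} \quad & \sum_{j=1}^{\mathrm{dim}_a} F_{ij} = \mathrm{SB}_i, \qquad i = 1, \dots, \mathrm{dim}_b, \\
& \sum_{i=1}^{\mathrm{dim}_b} F_{ij} = \mathrm{DA}_j, \qquad j = 1, \dots, \mathrm{dim}_a,
\end{aligned}
\]

and the returned plan F is intended to satisfy

\[
\sum_{i=1}^{\mathrm{dim}_b} \sum_{j=1}^{\mathrm{dim}_a} C_{ij} F_{ij} \le \mathrm{OPT} + \epsilon .
\]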

Parameters:
  • DA (tensor, shape (dim_a,)) – A 1-D tensor of length dim_a holding the weights of the samples from the demand distribution (type a); DA[i] is the mass demanded by the i-th type a vertex. The entries of DA should sum to 1.

  • SB (tensor, shape (dim_b,)) – A 1-D tensor of length dim_b holding the weights of the samples from the source (supply) distribution (type b); SB[i] is the mass supplied by the i-th type b vertex. The entries of SB should sum to 1.

  • C (tensor, shape (dim_b, dim_a)) – The cost matrix; C[i, j] is the cost between the i-th type b vertex and the j-th type a vertex.

  • eps (float) – The additive error allowed in the optimal transport distance, i.e. the value of \(\epsilon\) in [2].

  • device (torch.device) – The device on which the computation is executed (e.g. torch.device('cuda:0') for a GPU).
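
One way to build inputs with the documented shapes and normalization is sketched below, using random points in the plane with Euclidean costs. The helper make_inputs is hypothetical, and the use of torch.cdist for costs, the float32 dtype, and eps=0.01 are illustrative assumptions rather than requirements stated above. The actual call to cot.transport_torch is left as a comment so the snippet runs without the cot package installed.

```python
import torch


def make_inputs(supply_pts, demand_pts, supply_w, demand_w, device):
    """Hypothetical helper: build DA, SB and C in the documented shapes.

    supply_pts: (dim_b, d) coordinates of the type b (supply) vertices
    demand_pts: (dim_a, d) coordinates of the type a (demand) vertices
    supply_w, demand_w: nonnegative weights, rescaled here to sum to 1
    """
    # float32 is an assumption; the documentation does not state a dtype.
    opts = dict(device=device, dtype=torch.float32)
    DA = demand_w.to(**opts)
    SB = supply_w.to(**opts)
    DA = DA / DA.sum()  # DA must sum to 1
    SB = SB / SB.sum()  # SB must sum to 1
    # Row i of C belongs to the i-th type b vertex and column j to the j-th
    # type a vertex, so C has shape (dim_b, dim_a).
    C = torch.cdist(supply_pts.to(**opts), demand_pts.to(**opts))
    return DA, SB, C


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
supply_pts = torch.rand(5, 2)  # dim_b = 5 points in 2-D
demand_pts = torch.rand(4, 2)  # dim_a = 4 points in 2-D
DA, SB, C = make_inputs(supply_pts, demand_pts, torch.rand(5), torch.rand(4), device)

# With the cot package installed, the documented call would be:
# F, yA, yB, total_cost = cot.transport_torch(DA, SB, C, 0.01, device)
```

Note the transposed roles: DA and SB are indexed by type a and type b vertices respectively, while C is indexed (type b, type a), so its first dimension matches SB and its second matches DA.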

Returns:

  • F (tensor, shape (dim_b, dim_a)) – The transport plan; F[i, j] is the flow between the i-th type b vertex and the j-th type a vertex.

  • yA (tensor, shape (dim_a,)) – A 1-D tensor of length dim_a; yA[i] is the final dual value of the i-th type a vertex.

  • yB (tensor, shape (dim_b,)) – A 1-D tensor of length dim_b; yB[i] is the final dual value of the i-th type b vertex.

  • total_cost (float) – The total cost of the returned additive-approximate transport plan F.
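
A simple way to sanity-check the returned values is to recompute the marginals and the cost of F. The sketch below assumes, beyond what is stated above, that F is a feasible plan (row sums equal SB, column sums equal DA) and that total_cost equals the sum of F[i, j] * C[i, j]; the helper plan_diagnostics is hypothetical. To stay runnable without the cot package, it is demonstrated on the independent coupling torch.outer(SB, DA), which is feasible but not optimal for this cost matrix.

```python
import torch


def plan_diagnostics(F, C, DA, SB):
    """Hypothetical helper: recompute the marginals and cost of a plan F.

    F and C have shape (dim_b, dim_a): rows are type b vertices,
    columns are type a vertices.
    """
    # Mass leaving each type b vertex should match its supply SB[i].
    max_row_err = (F.sum(dim=1) - SB).abs().max().item()
    # Mass arriving at each type a vertex should match its demand DA[j].
    max_col_err = (F.sum(dim=0) - DA).abs().max().item()
    # Cost of the plan: sum over i, j of F[i, j] * C[i, j].
    cost = (F * C).sum().item()
    return {"max_row_err": max_row_err, "max_col_err": max_col_err, "cost": cost}


SB = torch.tensor([0.5, 0.5])    # dim_b = 2
DA = torch.tensor([0.25, 0.75])  # dim_a = 2
C = torch.tensor([[0.0, 1.0],
                  [1.0, 0.0]])   # shape (dim_b, dim_a)
F = torch.outer(SB, DA)          # independent coupling: feasible, not optimal
diag = plan_diagnostics(F, C, DA, SB)
print(diag)  # {'max_row_err': 0.0, 'max_col_err': 0.0, 'cost': 0.5}
```

With real outputs, pass the F returned by cot.transport_torch together with the same C, DA and SB, and compare diag["cost"] against total_cost; small nonzero marginal errors can come from floating-point arithmetic.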

References