Source code for cot.transport

import numpy as np
import torch

def flow_validation(F, f, FreeA, FreeA_ori, FreeB, FreeB_ori):
    # check if flow is valid
    if (FreeA < 0).any():
        print("negative FreeA")
    if (FreeB < 0).any():
        print("negative FreeB")
    if (F < 0).any():
        print("negative flow")    
    if (torch.sum(F,0) + FreeA != FreeA_ori).any():
        print("flow not valid for type A")
    if (torch.sum(F,1) + FreeB != FreeB_ori).any():
        print("flow not valid for type B")
    if f != torch.sum(FreeB):
        print("flow not valid for free type B")

def flow_validation_final(F, FreeA, FreeA_ori, FreeB, FreeB_ori, device):
    # check if flow is valid
    if (FreeA < -1e-9).any():
        print("negative FreeA")
    if (FreeB < -1e-9).any():
        print("negative FreeB")
    if (F < 0).any():
        print("negative flow")  
    # torch.testing.assert_close(torch.sum(F,0) + FreeA - FreeA_ori, torch.zeros(FreeA.shape, dtype=torch.float64), msg="Flow not valid for type A!")
    # torch.testing.assert_close(torch.sum(F,1) + FreeB - FreeB_ori, torch.zeros(FreeB.shape, dtype=torch.float64), msg="Flow not valid for type B!")
    torch.testing.assert_close(torch.sum(F, 0) - FreeA_ori, torch.zeros(FreeA.shape, dtype=torch.float64, device=device), msg="Flow not valid for type A!")
    torch.testing.assert_close(torch.sum(F, 1) - FreeB_ori, torch.zeros(FreeB.shape, dtype=torch.float64, device=device), msg="Flow not valid for type B!")

def feasibilty_validation(yFA, yB, yA, F, C):
    # check feasibility
    zero_f_ind = torch.where(F == 0)
    nonzero_f_ind = torch.where(F > 0)
    if len(zero_f_ind[0]) > 0 and (yA[zero_f_ind[1]] + yB[zero_f_ind[0]] > C[zero_f_ind] + 1).any():
        print("first feasibility condition not valid (F=0)")
    if len(nonzero_f_ind[0]) > 0 and (yFA[nonzero_f_ind] + yB[nonzero_f_ind[0]] > C[nonzero_f_ind] + 1).any():
        print("first feasibility condition not valid (F>0)")
    if len(nonzero_f_ind[0]) > 0 and (yFA[nonzero_f_ind] + yB[nonzero_f_ind[0]] < C[nonzero_f_ind]).any():
        print("second feasibility condition not valid")

def slack_validation(yB, yA, S, C):
    # check slack
    # Only need to check S with yA and yB
    if (S != C + 1 - yB[:,None] - yA[None, :]).any():
        print("slack not valid")

def unique2(x, input_sorted = False):
    # Returns the unique elements of array x, and the indices of the first occurrences of the unique values in the original array
    # Method 2

    unique, inverse_ind, unique_count = torch.unique(x, return_inverse=True, return_counts=True)
    unique_ind = unique_count.cumsum(0)
    if not unique_ind.size()[0] == 0:
        unique_ind = torch.cat((torch.tensor([0], dtype=x.dtype, device=x.device), unique_ind[:-1]))
    if not input_sorted:
        _, sort2ori_ind = torch.sort(inverse_ind, stable=True)
        unique_ind = sort2ori_ind[unique_ind]
    return unique, unique_ind
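
# A minimal sanity check for unique2 (illustrative, not part of the library
# API): for x = [3, 1, 3, 2, 1] it returns the sorted unique values
# [1, 2, 3] together with the positions of their first occurrences in the
# original array, i.e. [1, 3, 0].
def _demo_unique2():
    x = torch.tensor([3, 1, 3, 2, 1])
    unique, unique_ind = unique2(x)
    assert torch.equal(unique, torch.tensor([1, 2, 3]))
    assert torch.equal(unique_ind, torch.tensor([1, 3, 0]))
    assert torch.equal(x[unique_ind], unique)  # first occurrences map back to the unique values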

def subset_sum_filter(F, sum, dim=0):
    # Cap the cumulative sum of F along `dim` at the budget `sum` (broadcast
    # along the remaining dimension): entries before the budget is reached are
    # kept, the entry that crosses it is reduced to the remainder, and all
    # later entries are zeroed.
    F_cum = torch.cumsum(F, dim)
    F_cum_clamp = torch.clamp(F_cum, max=sum)
    F = F - (F_cum - F_cum_clamp)  # subtract the accumulated overflow; entries past the budget go negative
    F_mask = F.lt(0)
    F_mask_full = (F_cum.le(sum) & F.gt(0))  # entries kept in full within the budget
    F[F_mask] = 0
    return F, F_mask_full
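
# A small worked example for subset_sum_filter (illustrative values): with a
# single column [3, 2, 4] and budget 4, the cumulative sums [3, 5, 9] are
# clamped at 4, so the kept flows are [3, 1, 0]: the first entry fits, the
# second is cut down to the remainder 1, and the rest is dropped.
def _demo_subset_sum_filter():
    F = torch.tensor([[3.0], [2.0], [4.0]])
    budget = torch.tensor([4.0])
    F_kept, _ = subset_sum_filter(F.clone(), budget, dim=0)
    assert torch.equal(F_kept, torch.tensor([[3.0], [1.0], [0.0]]))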


def transport_torch(DA, SB, C, eps, device):
    r"""
    This function solves the additive approximation of the optimal transport
    problem between two discrete distributions and returns the transport plan,
    the dual variables and the total cost. It is a PyTorch implementation of
    the parallelizable combinatorial algorithm [2]_ for OT.

    Parameters
    ----------------
    DA : tensor, shape (dim_a,)
        A dim_a by 1 tensor holding the weights of samples from the demand
        distribution (type a); each DA[i] is the mass of demand on the i-th
        type a vertex. DA should sum to 1.
    SB : tensor, shape (dim_b,)
        A dim_b by 1 tensor holding the weights of samples from the source
        distribution (type b); each SB[i] is the mass of supply on the i-th
        type b vertex. SB should sum to 1.
    C : tensor, shape (dim_b, dim_a)
        A dim_b by dim_a cost matrix; each C(i,j) is the cost between the
        i-th type b and the j-th type a vertex.
    eps : float
        The additive error of the optimal transport distance, the value of
        :math:`\epsilon` in paper [2]_.
    device : torch.device
        The device where the computation will be executed
        (e.g. torch.device('cuda:0') for GPU).

    Returns
    ----------------
    F : tensor, shape (dim_b, dim_a)
        A dim_b by dim_a tensor; F(i,j) is the flow (transport plan) between
        the i-th type b and the j-th type a vertex.
    yA : tensor, shape (dim_a,)
        A 1 by dim_a array; each yA[i] is the final dual weight of the i-th
        type a vertex.
    yB : tensor, shape (dim_b,)
        A 1 by dim_b array; each yB[i] is the final dual weight of the i-th
        type b vertex.
    total_cost : float
        The total cost of the final additive approximate optimal transport
        plan.

    References
    ----------------
    .. [2] Lahn, Nathaniel, Sharath Raghvendra, and Kaiyi Zhang. A
        combinatorial algorithm for approximating the optimal transport in
        the parallel and MPC settings. Advances in Neural Information
        Processing Systems (NeurIPS) 36, 2023.
    """
    torch.manual_seed(0)
    dtyp = torch.int32
    zero = torch.tensor([0], device=device, dtype=dtyp, requires_grad=False)[0]
    one = torch.tensor([1], device=device, dtype=dtyp, requires_grad=False)[0]
    m_one = torch.tensor([-1], device=device, dtype=dtyp, requires_grad=False)[0]
    m = torch.tensor(C.shape[0], device=device, requires_grad=False)
    n = torch.tensor(C.shape[1], device=device, requires_grad=False)
    yA = torch.zeros(n, dtype=dtyp, device=device)  # dual weight of type A vertices (the smaller one in absolute value)
    yB = torch.ones(m, dtype=dtyp, device=device)  # dual weight of type B vertices
    F = torch.zeros(C.shape, device=device, dtype=dtyp, requires_grad=False)
    yFA = torch.full(C.shape, one, device=device, dtype=dtyp, requires_grad=False)  # edge-wise dual weight of a type A vertex matched to a type B vertex; holds the default value 1 where there is no matching
    S = torch.div((3 * C), eps, rounding_mode='trunc').type(dtyp).to(device)
    # C_scaled = torch.div((3*C), eps, rounding_mode='trunc').type(dtyp).to(device)  # scaled cost for feasibility validation
    max_C = torch.max(C)
    alpha = 6 * n * max_C / eps
    FreeA_ = DA * alpha
    FreeA = torch.ceil(FreeA_).to(dtyp)
    FreeA_ori = FreeA.clone()
    FreeB_ = SB * alpha
    FreeB = FreeB_.to(dtyp)
    FreeB_ori = FreeB.clone()
    f = torch.sum(FreeB)  # flow remaining to push
    ff = 0
    iteration = 0

    # main loop
    while f > n:
        # extract the admissible graph
        ind_b_free = torch.where(FreeB > zero)
        ind_zero_slack_ind = torch.where(S[ind_b_free[0], :] == 0)

        # find push/release edges and the corresponding flow
        # find push edges
        ind_b_push_tent_ind, ind_a_tent_lt_inclusive = unique2(ind_zero_slack_ind[0], input_sorted=True)
        ind_b_push_tent = ind_b_free[0][ind_b_push_tent_ind]  # tentative B (push)
        ind_a_tent_rt_exclusive = torch.cat((ind_a_tent_lt_inclusive[1:], torch.tensor(ind_zero_slack_ind[0].shape, device=device, dtype=dtyp, requires_grad=False)))
        rand_b = torch.rand(ind_b_push_tent.shape[0], device=device)
        ind_a_ind = ind_a_tent_lt_inclusive + ((ind_a_tent_rt_exclusive - ind_a_tent_lt_inclusive) * rand_b).to(dtyp)
        ind_a = ind_zero_slack_ind[1][ind_a_ind]  # tentative A to push (maybe release)
        ind_a_push, ind_b_push_ind = unique2(ind_a, input_sorted=False)  # find the exact a to push, and the corresponding index
        ind_b_push = ind_b_push_tent[ind_b_push_ind]  # final type b vertices to push
        edge_push = (ind_b_push, ind_a_push)
        ind_b_no_push = ind_b_free[0][(ind_b_free[0][:, None] != ind_b_push_tent).all(dim=1)]

        # calculate the flow pushed to free copies of A
        push_flow_free, unsatisfied_vertices_ind = torch.min(torch.vstack((FreeB[ind_b_push], FreeA[ind_a_push])), 0)

        # find push-release edges B->A and the corresponding vertices A
        ind_release_ind = torch.where(FreeA[ind_a_push] == 0)
        ind_b_push_release = ind_b_push[ind_release_ind]
        ind_a_push_release = ind_a_push[ind_release_ind]
        edge_push_released = (ind_b_push_release, ind_a_push_release)
        ind_a_release = ind_a_push_release

        # find release edges A->B
        release_edge_tent_ind = torch.where(torch.t(yFA[:, ind_a_release]) == yA[ind_a_release][:, None])  # type a sorted, index 0
        ind_a_release, ind_b_release_tent_lt_inclusive = unique2(ind_a_release[release_edge_tent_ind[0]], input_sorted=True)
        ind_b_release_tent_rt_exclusive = torch.cat((ind_b_release_tent_lt_inclusive[1:], torch.tensor(release_edge_tent_ind[1].shape, device=device, dtype=dtyp, requires_grad=False)))
        rand_a = torch.rand(ind_a_release.shape[0], device=device)
        ind_b_ind = ind_b_release_tent_lt_inclusive + ((ind_b_release_tent_rt_exclusive - ind_b_release_tent_lt_inclusive) * rand_a).to(dtyp)
        ind_b_release = release_edge_tent_ind[1][ind_b_ind]
        edge_release = (ind_b_release, ind_a_release)
        FreeB[ind_b_push] -= push_flow_free

        # calculate the release flow and find the fully released edges A->B
        push_flow_release, part_release_ind = torch.min(torch.vstack((F[edge_release], FreeB[ind_b_push_release])), 0)
        part_release_ind = part_release_ind.to(bool)
        ind_full_release_ind = ~part_release_ind
        ind_a_full_release = ind_a_release[ind_full_release_ind]
        ind_b_full_release = ind_b_release[ind_full_release_ind]
        edge_full_release = (ind_b_full_release, ind_a_full_release)

        # update the transport configuration (flow/slack/dual weight)
        # update flow and free vertices
        f -= torch.sum(push_flow_free)
        ff += torch.sum(push_flow_free).item()
        F[edge_push] += push_flow_free
        F[edge_push_released] += push_flow_release
        F[edge_release] -= push_flow_release
        FreeA[ind_a_push] -= push_flow_free
        FreeB[ind_b_push_release] -= push_flow_release
        FreeB.index_add_(0, ind_b_release, push_flow_release)

        # dual weight/slack
        # update the dual weight and slack of type b vertices that could not be pushed at the current iteration
        yB[ind_b_no_push] += one
        S[ind_b_no_push, :] -= one
        b_no_push_edge_w_flow = torch.where(F[ind_b_no_push, :] != 0)
        b_no_push_edge_w_flow = (ind_b_no_push[b_no_push_edge_w_flow[0]], b_no_push_edge_w_flow[1])
        yFA[b_no_push_edge_w_flow] = yA[b_no_push_edge_w_flow[1]] - one

        # update the edge-wise dual weight of type A vertices that are pushed/released at the current iteration
        yFA[edge_push] = yA[ind_a_push] - one  # pushed edge: yFA = yA - 1
        yFA[edge_full_release] = one  # reset yFA to its default value when an edge is fully released

        # update the dual weight of type a vertices that are exhausted at the current iteration
        ind_a_exhausted_tent = torch.unique(torch.cat((ind_a_push, b_no_push_edge_w_flow[1])))
        ind_a_push_not_free = ind_a_exhausted_tent[FreeA[ind_a_exhausted_tent] == 0]
        yFA_mask = (yA[ind_a_push_not_free] > yFA[:, ind_a_push_not_free])
        yFA_mask[yFA[:, ind_a_push_not_free] == one] = True
        ind_a_exhausted_ind = yFA_mask.all(dim=0)
        ind_a_exhausted = torch.masked_select(ind_a_push_not_free, ind_a_exhausted_ind)
        yA[ind_a_exhausted] -= one
        S[:, ind_a_exhausted] += one

        iteration += 1

    # reverse the scaling
    scaling_error_A = FreeA_ori - FreeA_
    scaling_error_B = FreeB_ - FreeB_ori
    F = F / alpha
    ind_a_all_transported_after_scaling = torch.where(FreeA == 0)
    FreeA = (FreeA - scaling_error_A) / alpha
    FreeA[ind_a_all_transported_after_scaling] = 0
    FreeB = (FreeB + scaling_error_B) / alpha
    reverse_edges_ = torch.where(torch.t(F[:, ind_a_all_transported_after_scaling[0]]) != 0)
    _, reverse_edges_B_ind = unique2(reverse_edges_[0], input_sorted=True)
    reverse_edges = (reverse_edges_[1][reverse_edges_B_ind], ind_a_all_transported_after_scaling[0])
    reverse_flow = scaling_error_A[reverse_edges[1]] / alpha
    F[reverse_edges] -= reverse_flow
    FreeB.index_add_(0, reverse_edges[0], reverse_flow)

    # push the mass left free after reversing the scaling
    f_left = torch.sum(FreeB)
    while f_left > 1e-12:
        ind_b_left = torch.where(FreeB > 0)[0]
        ind_a_left = torch.where(FreeA > 0)[0]
        rand_b = torch.rand(ind_b_left.shape[0], device=device)
        ind_a_left_push_ind = (rand_b * ind_a_left.shape[0]).to(dtyp).long()
        ind_a_left_push = ind_a_left[ind_a_left_push_ind]
        F_ = torch.zeros(F.shape, dtype=torch.float64, device=device)
        F_[(ind_b_left, ind_a_left_push)] = torch.min(FreeB[:, None], FreeA[None, :])[(ind_b_left, ind_a_left_push)]
        ind_a_left_push_unique = torch.unique(ind_a_left_push)
        F_ = F_[ind_b_left, :][:, ind_a_left_push_unique]
        F_, _ = subset_sum_filter(F_, FreeA[ind_a_left_push_unique], dim=0)
        edge_push = torch.where(F_ > 0)
        push_flow_left = F_[edge_push]
        edge_push = (ind_b_left[edge_push[0]], ind_a_left_push_unique[edge_push[1]])
        F[edge_push] += push_flow_left
        FreeB[edge_push[0]] -= push_flow_left
        FreeA.index_add_(0, edge_push[1], m_one * push_flow_left)
        f_left -= torch.sum(push_flow_left)

    # flow_validation_final(F, FreeA, DA, FreeB, SB, device)
    total_cost = torch.sum(F * C)
    return F, yA, yB, total_cost
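
# A minimal usage sketch (assumptions: uniform toy distributions and a random
# cost matrix; eps = 0.01 is an arbitrary choice, not a recommended setting):
if __name__ == "__main__":
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    dim_a, dim_b = 50, 60
    DA = torch.full((dim_a,), 1.0 / dim_a, dtype=torch.float64, device=dev)
    SB = torch.full((dim_b,), 1.0 / dim_b, dtype=torch.float64, device=dev)
    C = torch.rand((dim_b, dim_a), dtype=torch.float64, device=dev)
    F, yA, yB, total_cost = transport_torch(DA, SB, C, eps=0.01, device=dev)
    print(f"approximate OT cost: {total_cost.item():.6f}")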