function WassersteinDistance = WassersteinDistance_2DGrid(Sample1,Sample2,dim1_numCells,dim2_numCells,maxit,tol)
%WASSERSTEINDISTANCE_2DGRID  Optimal-transport cost between two binned 2D samples.
%
%   Both samples are binned onto one shared regular 2D grid, the occupied
%   cells are turned into two discrete distributions located at the cell
%   centers, and the optimal transport plan between them is computed with
%   a linear program (perform_linprog). The returned value is
%   sum_{i,j} gamma*_{i,j} C_{i,j}, where C is the Euclidean distance
%   between cell centers.
%
% Inputs
%   Sample1       --> 2 x n1 matrix (row 1 = dim1, row 2 = dim2 coordinates)
%   Sample2       --> 2 x n2 matrix (row 1 = dim1, row 2 = dim2 coordinates)
%   dim1_numCells --> number of grid cells along dim1 (positive integer)
%   dim2_numCells --> number of grid cells along dim2 (positive integer)
%   maxit         --> (optional) iteration limit passed to perform_linprog,
%                     default 1e5
%   tol           --> (optional) tolerance passed to perform_linprog,
%                     default 1e-9
%
% Output
%   WassersteinDistance --> scalar transport cost between the two binned
%                           distributions
%
% Things to pay attention to:
% (1) The grid spans floor(min) .. ceil(max) of the pooled data in each
%     dimension. If the coordinate values are small (e.g. all below 1),
%     the grid therefore covers a whole unit interval and most cells stay
%     empty; multiply the data by 10^a, according to its scale, before
%     calling so that the grid limits follow the data more closely.
% (2) Calling this function prepends 'toolbox_signal/' and
%     'toolbox_general/' to the MATLAB path.

% ---- Argument handling --------------------------------------------------
if nargin < 4
    error('WassersteinDistance_2DGrid:notEnoughInputs', ...
        'Sample1, Sample2, dim1_numCells and dim2_numCells are required.');
end
if nargin < 5 || isempty(maxit)
    maxit = 1e5;
end
if nargin < 6 || isempty(tol)
    tol = 1e-9;
end
if isempty(Sample1) || isempty(Sample2) || size(Sample1,1) < 2 || size(Sample2,1) < 2
    error('WassersteinDistance_2DGrid:badSample', ...
        'Sample1 and Sample2 must be non-empty matrices with at least 2 rows (one column per point).');
end
if ~isscalar(dim1_numCells) || ~isscalar(dim2_numCells) || ...
        dim1_numCells < 1 || dim2_numCells < 1 || ...
        dim1_numCells ~= round(dim1_numCells) || dim2_numCells ~= round(dim2_numCells)
    error('WassersteinDistance_2DGrid:badGrid', ...
        'dim1_numCells and dim2_numCells must be positive integer scalars.');
end

% ---- Toolbox path -------------------------------------------------------
% perform_linprog is expected to be reachable through these folders
% (NOTE(review): confirm which of the two actually provides it).
getd = @(p)path(p,path);
getd('toolbox_signal/');
getd('toolbox_general/');

% ---- Linear-program helpers ---------------------------------------------
% Rows(n0,n1) / Cols(n0,n1) are sparse 0/1 matrices that, applied to a
% column-major flattened n0 x n1 coupling, return its row sums / column
% sums. Sigma stacks both, giving the marginal constraints of the problem.
flat = @(x)x(:);
Cols = @(n0,n1)sparse( flat(repmat(1:n1, [n0 1])), ...
             flat(reshape(1:n0*n1,n0,n1) ), ...
             ones(n0*n1,1) );
Rows = @(n0,n1)sparse( flat(repmat(1:n0, [n1 1])), ...
             flat(reshape(1:n0*n1,n0,n1)' ), ...
             ones(n0*n1,1) );
Sigma = @(n0,n1)[Rows(n0,n1);Cols(n0,n1)];

% Optimal transport coupling gamma* obtained from the linear program
% min <C,gamma> subject to Sigma*gamma(:) = [p0;p1].
otransp  = @(C,p0,p1)reshape( perform_linprog( ...
        Sigma(length(p0),length(p1)), ...
        [p0(:);p1(:)], C(:), 0, maxit, tol), [length(p0) length(p1)] );

% ---- Common grid --------------------------------------------------------
% Pool both samples and reduce along the point dimension explicitly, so a
% sample holding a single point is handled like any other.
pooledDim1 = [Sample1(1,:), Sample2(1,:)];
pooledDim2 = [Sample1(2,:), Sample2(2,:)];
dim1_lo = floor(min(pooledDim1));
dim1_hi = ceil(max(pooledDim1));
dim2_lo = floor(min(pooledDim2));
dim2_hi = ceil(max(pooledDim2));

% If every value along a dimension is the same integer, floor and ceil
% coincide and the grid would have zero width; widen it to one unit.
if dim1_hi == dim1_lo
    dim1_hi = dim1_lo + 1;
end
if dim2_hi == dim2_lo
    dim2_hi = dim2_lo + 1;
end

% linspace guarantees exactly numCells+1 edges with the last edge equal to
% the upper limit, which a floating-point colon range does not.
Xedges = linspace(dim1_lo, dim1_hi, dim1_numCells + 1);
Yedges = linspace(dim2_lo, dim2_hi, dim2_numCells + 1);

% ---- Histograms ---------------------------------------------------------
% Probability mass of each sample per grid cell; element (i,j) belongs to
% the i-th dim1 bin and the j-th dim2 bin.
pdf_Sample1 = histcounts2(Sample1(1,:),Sample1(2,:),Xedges,Yedges,'Normalization','probability');
pdf_Sample2 = histcounts2(Sample2(1,:),Sample2(2,:),Xedges,Yedges,'Normalization','probability');

% Test to make sure the probabilities add up to 1
% sum(sum(pdf_Sample1))
% sum(sum(pdf_Sample2))

% ---- Occupied cells -----------------------------------------------------
% Cell centers along each dimension.
dim1_centers = (Xedges(1:end-1) + Xedges(2:end)) / 2;
dim2_centers = (Yedges(1:end-1) + Yedges(2:end)) / 2;

% Enumerate the cells with the dim2 index varying fastest (dim1 outer,
% dim2 inner), i.e. column-major order of the transposed histograms.
centerDim1 = repmat(dim1_centers,   dim2_numCells, 1);
centerDim2 = repmat(dim2_centers.', 1, dim1_numCells);
allCenters = [centerDim1(:).'; centerDim2(:).'];

mass1 = flat(pdf_Sample1.');
mass2 = flat(pdf_Sample2.');

% To keep the linear program small, only the cells with a probability
% mass greater than 0 are kept for each sample.
occupied1 = mass1 > 0;
occupied2 = mass2 > 0;
X0 = allCenters(:, occupied1); % 2 x n0 cell centers with mass > 0 for Sample1
X1 = allCenters(:, occupied2); % 2 x n1 cell centers with mass > 0 for Sample2
p0 = mass1(occupied1);         % n0 x 1 masses for Sample1
p1 = mass2(occupied2);         % n1 x 1 masses for Sample2

% Both samples share the same grid, but X0 and X1 generally differ because
% each one lists only the cells occupied by its own sample.
n0 = length(p0);
n1 = length(p1);

% % Display the point clouds. The size of each dot is proportional to its probability mass
% myplot = @(x,y,ms,col)plot(x,y, 'o', 'MarkerSize', ms, 'MarkerEdgeColor', 'k', 'MarkerFaceColor', col, 'LineWidth', 2);
% clf; hold on;
% for i=1:length(p0)
%     myplot(X0(1,i), X0(2,i), p0(i)*50, 'b');
% end
% for i=1:length(p1)
%     myplot(X1(1,i), X1(2,i), p1(i)*50, 'r');
% end
% axis([dim1_lo dim1_hi dim2_lo dim2_hi]); axis off;

% ---- Cost matrix --------------------------------------------------------
% C(i,j) = Euclidean distance between the i-th occupied cell of Sample1
% and the j-th occupied cell of Sample2, computed for all pairs at once.
deltaDim1 = bsxfun(@minus, X0(1,:).', X1(1,:));
deltaDim2 = bsxfun(@minus, X0(2,:).', X1(2,:));
C = sqrt(deltaDim1.^2 + deltaDim2.^2);

% ---- Optimal transport plan ---------------------------------------------
if n0 == 1 || n1 == 1
    % With a single occupied cell on either side the marginal constraints
    % admit exactly one coupling, so no linear program is needed.
    gamma = p0(:) * p1(:).';
else
    gamma = otransp(C,p0,p1);
end

% Transport cost \sum_{i,j} \gamma^\star_{i,j} C_{i,j}
Full_gamma = full(gamma);
WassersteinDistance = dot(Full_gamma(:),C(:));


% Check that the number of non-zero entries in \gamma^\star is n0+n1-1.
%fprintf('Number of non-zero: %d (n0+n1-1=%d)\n', full(sum(gamma(:)~=0)), n0+n1-1);

% Check that the solution satisfies the marginal constraints.
%fprintf('Constraints deviation (should be 0): %.2e, %.2e.\n', norm(sum(gamma,2)-p0(:)),  norm(sum(gamma,1)'-p1(:)));

%fprintf('The Wasserstein Distance is: %.2e.\n', WassersteinDistance);