# From Geometric Tools Engine
# https://www.geometrictools.com/GTEngine/Include/Mathematics/GteDistSegmentSegment.h

import numpy as np

def SegmentsClosestPoints(P0, P1, Q0, Q1):
    """Return the closest points between segments [P0, P1] and [Q0, Q1].

    Vectorized NumPy port of the Geometric Tools Engine segment/segment
    distance query.  Each argument is a point of shape ``(N,)`` or a stack
    of points of shape ``(..., N)``; the four arguments are broadcast
    together.  Degenerate segments (``P0 == P1`` and/or ``Q0 == Q1``) are
    allowed.

    Returns a pair ``(closest_on_P, closest_on_Q)`` equal to
    ``(1 - s) * P0 + s * P1`` and ``(1 - t) * Q0 + t * Q1`` for the
    minimizing parameters ``(s, t)`` in ``[0, 1]^2``.  If every argument is
    one-dimensional, two arrays of shape ``(N,)`` are returned; otherwise the
    results have the broadcast shape of the inputs.
    """
    # Promote to float so every intermediate is floating point; the helpers'
    # np.divide calls write into preallocated output buffers that must be
    # able to hold a floating-point quotient.
    P0, P1, Q0, Q1 = (np.asarray(a, dtype=float) for a in (P0, P1, Q0, Q1))
    single = all(a.ndim <= 1 for a in (P0, P1, Q0, Q1))
    # Ensure at least one leading batch axis so the masked updates below
    # always operate on arrays; it is dropped again for single input.
    P0, P1, Q0, Q1 = np.broadcast_arrays(np.atleast_2d(P0), P1, Q0, Q1)
    shape = P0.shape[:-1]

    # The quadratic function for squared distance between the segments is
    #   R(s,t) = a*s^2 - 2*b*s*t + c*t^2 + 2*d*s - 2*e*t + f
    # for (s,t) in [0,1]^2 where
    #   a = Dot(P1-P0,P1-P0), b = Dot(P1-P0,Q1-Q0), c = Dot(Q1-Q0,Q1-Q0),
    #   d = Dot(P1-P0,P0-Q0), e = Dot(Q1-Q0,P0-Q0), f = Dot(P0-Q0,P0-Q0)
    P1mP0 = P1 - P0
    Q1mQ0 = Q1 - Q0
    P0mQ0 = P0 - Q0
    mA = np.einsum('...i,...i->...', P1mP0, P1mP0)
    mB = np.einsum('...i,...i->...', P1mP0, Q1mQ0)
    mC = np.einsum('...i,...i->...', Q1mQ0, Q1mQ0)
    mD = np.einsum('...i,...i->...', P1mP0, P0mQ0)
    mE = np.einsum('...i,...i->...', Q1mQ0, P0mQ0)

    # Half partial derivatives at the corners (i, j) of [0,1]^2:
    #   mFij = (1/2) dR/ds(i, j) = a*i - b*j + d
    #   mGij = (1/2) dR/dt(i, j) = -b*i + c*j - e
    mF00 = mD
    mF10 = mF00 + mA
    mF01 = mF00 - mB
    mF11 = mF10 - mB

    mG00 = -mE
    mG10 = mG00 - mB
    mG01 = mG00 + mC
    mG11 = mG10 + mC

    p_ok = mA > 0  # P-segment has nonzero length
    q_ok = mC > 0  # Q-segment has nonzero length
    both = p_ok & q_ok

    # Roots of dR/ds along the edges t = 0 and t = 1 (only where both
    # segments are non-degenerate; elsewhere left at 0).
    sValue = np.zeros(shape + (2,))
    sValue[..., 0][both] = _GetClampedRoot(mA[both], mF00[both], mF10[both])
    sValue[..., 1][both] = _GetClampedRoot(mA[both], mF01[both], mF11[both])

    # classify: -1 if the root is at s <= 0, +1 if at s >= 1, 0 if interior;
    # 0 wherever `both` is false.
    classify = np.where(sValue <= 0, -1, np.where(sValue >= 1, 1, 0))
    classify[~both] = 0
    all_neg = both & np.all(classify == -1, axis=-1)
    all_pos = both & np.all(classify == +1, axis=-1)

    # General case: the line dR/ds = 0 crosses [0,1]^2 in a segment.  The
    # helpers process every element; entries belonging to the other cases
    # are overwritten below.
    parameter = np.zeros(shape + (2,))
    edge = np.zeros(shape + (2,), dtype=int)
    end = np.zeros(shape + (2, 2))
    _ComputeIntersection(sValue, classify, edge, end, mB, mF00, mF10)
    _ComputeMinimumParameters(edge, end, parameter,
                              mB, mC, mE, mG00, mG01, mG10, mG11)

    # Minimum on edge s = 0: both edge roots at s <= 0, or the P-segment is
    # degenerate (P0 == P1) while Q is not.  Minimize over t alone.
    m = all_neg | (~both & q_ok)
    parameter[..., 0][m] = 0.
    parameter[..., 1][m] = _GetClampedRoot(mC[m], mG00[m], mG01[m])

    # Minimum on edge s = 1: both edge roots at s >= 1.
    m = all_pos
    parameter[..., 0][m] = 1.
    parameter[..., 1][m] = _GetClampedRoot(mC[m], mG10[m], mG11[m])

    # Q-segment degenerate (Q0 == Q1) while P is not: minimize over s alone.
    m = ~both & p_ok
    parameter[..., 0][m] = _GetClampedRoot(mA[m], mF00[m], mF10[m])
    parameter[..., 1][m] = 0.

    # Both segments degenerate: each is a single point, so use (0, 0).
    m = ~p_ok & ~q_ok
    parameter[m] = 0.

    s = parameter[..., :1]
    t = parameter[..., 1:]
    out1 = (1. - s) * P0 + s * P1
    out2 = (1. - t) * Q0 + t * Q1
    if single:
        return out1[0], out2[0]
    return out1, out2


def _GetClampedRoot(slope, h0, h1):
    # Theoretically, r is in (0,1).  However, when the slope is nearly zero,
    # then so are h0 and h1.  Significant numerical rounding problems can
    # occur when using floating-point arithmetic.  If the rounding causes r
    # to be outside the interval, clamp it.  It is possible that r is in
    # (0,1) and has rounding errors, but because h0 and h1 are both nearly
    # zero, the quadratic is nearly constant on (0,1).  Any choice of p
    # should not cause undesirable accuracy problems for the final distance
    # computation.
    #
    # NOTE:  You can use bisection to recompute the root or even use
    # bisection to compute the root and skip the division.  This is generally
    # slower, which might be a problem for high-performance applications.

    r = -np.divide(h0, slope, out=np.zeros_like(h0), where=(slope != 0))
    mh0 = h0 < 0
    mh1 = h1 > 0
    r[~mh0] = 0.
    r[mh0 & ~mh1] = 1.
    r[mh0 & mh1 & (r > 1.)] = 0.5
    return r


def _ComputeIntersection(sValue, classify, edge, end,
                         mB, mF00, mF10):
    # The divisions are theoretically numbers in [0,1].  Numerical rounding
    # errors might cause the result to be outside the interval.  When this
    # happens, it must be that both numerator and denominator are nearly
    # zero.  The denominator is nearly zero when the segments are nearly
    # perpendicular.  The numerator is nearly zero when the P-segment is
    # nearly degenerate (mF00 = a is small).  The choice of 0.5 should not
    # cause significant accuracy problems.
    #
    # NOTE:  You can use bisection to recompute the root or even use
    # bisection to compute the root and skip the division.  This is generally
    # slower, which might be a problem for high-performance applications.

    mC0lt0 = classify[..., 0] < 0
    mC0eq0 = classify[..., 0] == 0
    mC0gt0 = classify[..., 0] > 0
    mC1lt0 = classify[..., 1] < 0
    mC1eq0 = classify[..., 1] == 0
    mC1gt0 = classify[..., 1] > 0
    mF00_mb = np.divide(mF00, mB, out=np.zeros_like(mF00), where=(mB != 0))
    mF10_mb = np.divide(mF10, mB, out=np.zeros_like(mF10), where=(mB != 0))

    m = mC0lt0
    edge[..., 0][m] = 0
    end[..., 0, 0][m] = 0.
    end[..., 0, 1][m] = mF00_mb[m]
    end[..., 0, 1][m & (end[..., 0, 1] < 0.) | (end[..., 0, 1] > 1.)] = 0.5
    m = mC0lt0 & mC1eq0
    edge[..., 1][m] = 3
    end[..., 1, 0][m] = sValue[..., 1][m]
    end[..., 1, 1][m] = 1.
    m = mC0lt0 & ~mC1eq0
    edge[..., 1][m] = 1
    end[..., 1, 0][m] = 1.
    end[..., 1, 1][m] = mF10_mb[m]
    end[..., 1, 1][m & (end[..., 1, 1] < 0.) | (end[..., 1, 1] > 1.)] = 0.5

    m = mC0eq0
    edge[..., 0][m] = 2
    end[..., 0, 0][m] = sValue[..., 0][m]
    end[..., 0, 1][m] = 0.
    m = mC0eq0 & mC1lt0
    edge[..., 1][m] = 0
    end[..., 1, 0][m] = 0.
    end[..., 1, 1][m] = mF00_mb[m]
    end[..., 1, 1][m & (end[..., 1, 1] < 0.) | (end[..., 1, 1] > 1.)] = 0.5
    m = mC0eq0 & mC1eq0
    edge[..., 1][m] = 3
    end[..., 1, 0][m] = sValue[..., 1][m]
    end[..., 1, 1][m] = 1.
    m = mC0eq0 & mC1gt0
    edge[..., 1][m] = 1
    end[..., 1, 0][m] = 1.
    end[..., 1, 1][m] = mF10_mb[m]
    end[..., 1, 1][m & (end[..., 1, 1] < 0.) | (end[..., 1, 1] > 1.)] = 0.5

    m = mC0gt0
    edge[..., 0][m] = 1
    end[..., 0, 0][m] = 1.
    end[..., 0, 1][m] = mF10_mb[m]
    end[..., 0, 1][m & (end[..., 0, 1] < 0.) | (end[..., 0, 1] > 1.)] = 0.5
    m = mC0gt0 & mC1eq0
    edge[..., 1][m] = 3
    end[..., 1, 0][m] = sValue[..., 1][m]
    end[..., 1, 1][m] = 1.
    m = mC0gt0 & ~mC1eq0
    edge[..., 1][m] = 0
    end[..., 1, 0][m] = 0.
    end[..., 1, 1][m] = mF00_mb[m]
    end[..., 1, 1][m & (end[..., 1, 1] < 0.) | (end[..., 1, 1] > 1.)] = 0.5


def _ComputeMinimumParameters(edge, end, parameter,
                              mB, mC, mE, mG00, mG01, mG10, mG11):
    """Minimize R(s,t) along the segment where dR/ds = 0 crosses [0,1]^2.

    Vectorized port of GTE's ``ComputeMinimumParameters``.  ``end[..., i, :]``
    holds the (s, t) endpoint i of that segment and ``edge[..., i]`` the
    square edge it lies on (0: s=0, 1: s=1, 2: t=0, 3: t=1).  With
    delta = end[..., 1, 1] - end[..., 0, 1],
        H(z) = delta * (1/2) dR/dt((1-z)*end[0] + z*end[1]),  z in [0, 1],
    and h0 = H(0), h1 = H(1).  Per element the minimizer is:
      * h0 >= 0            -> endpoint 0,
      * h0 < 0 and h1 <= 0 -> endpoint 1,
      * otherwise          -> the root of H, interpolated between endpoints.
    When the selected endpoint lies on edge s = 0 (or s = 1), t is instead
    re-minimized along that whole edge with ``_GetClampedRoot``.

    Results are written in place into ``parameter[..., 0]`` (s) and
    ``parameter[..., 1]`` (t); every element is assigned exactly once.
    """
    delta = end[..., 1, 1] - end[..., 0, 1]
    h0 = delta * (-mB * end[..., 0, 0] + mC * end[..., 0, 1] - mE)
    h1 = delta * (-mB * end[..., 1, 0] + mC * end[..., 1, 1] - mE)

    use_end0 = h0 >= 0
    use_end1 = ~use_end0 & (h1 <= 0)
    interior = ~use_end0 & ~use_end1

    # H changes sign inside (0, 1): interpolate to its root.  Only the
    # interior elements are written here so the endpoint results below are
    # not clobbered.
    z = np.divide(h0, h0 - h1, out=np.zeros(np.shape(h0)), where=(h0 != h1))
    z = np.minimum(np.maximum(z, 0.), 1.)
    omz = 1. - z
    s_mid = omz * end[..., 0, 0] + z * end[..., 1, 0]
    t_mid = omz * end[..., 0, 1] + z * end[..., 1, 1]
    parameter[..., 0][interior] = s_mid[interior]
    parameter[..., 1][interior] = t_mid[interior]

    for i, chosen in ((0, use_end0), (1, use_end1)):
        on_s0 = chosen & (edge[..., i] == 0)
        on_s1 = chosen & (edge[..., i] == 1)
        on_t = chosen & ~on_s0 & ~on_s1
        # Endpoint on edge s = 0: minimize over t along that edge.
        parameter[..., 0][on_s0] = 0.
        parameter[..., 1][on_s0] = _GetClampedRoot(
            mC[on_s0], mG00[on_s0], mG01[on_s0])
        # Endpoint on edge s = 1: minimize over t along that edge.
        parameter[..., 0][on_s1] = 1.
        parameter[..., 1][on_s1] = _GetClampedRoot(
            mC[on_s1], mG10[on_s1], mG11[on_s1])
        # Endpoint on edge t = 0 or t = 1: use the endpoint itself.
        parameter[..., 0][on_t] = end[..., i, 0][on_t]
        parameter[..., 1][on_t] = end[..., i, 1][on_t]
