from sympy.core.singleton import S
from sympy.core.sympify import sympify
from sympy.sets.sets import (EmptySet, FiniteSet, Intersection,
    Interval, ProductSet, Set, Union, UniversalSet)
from sympy.sets.fancysets import (ComplexRegion, Naturals, Naturals0,
    Integers, Rationals, Reals)
from sympy.multipledispatch import Dispatcher


# Dispatcher for pairwise set unions: the handlers registered below are
# selected by the types of both arguments, (type(a), type(b)). A handler
# returns a Set for the union, a Python set of replacement arguments (see the
# Interval/Set and FiniteSet/Set handlers), or None when it has no
# simplification to offer.
# NOTE(review): presumably the caller (Union simplification) also tries the
# swapped argument order and treats None as "leave unevaluated" -- confirm.
union_sets = Dispatcher('union_sets')


@union_sets.register(Naturals0, Naturals)
def _(a, b):
    """Union of Naturals0 with Naturals: returns ``a`` (Naturals0) as is.

    The handler treats ``b`` (Naturals) as already contained in ``a``.
    """
    return a

@union_sets.register(Rationals, Naturals)
def _(a, b):
    """Union of Rationals with Naturals: returns ``a`` (Rationals) as is.

    The handler treats ``b`` (Naturals) as already contained in ``a``.
    """
    return a

@union_sets.register(Rationals, Naturals0)
def _(a, b):
    """Union of Rationals with Naturals0: returns ``a`` (Rationals) as is.

    The handler treats ``b`` (Naturals0) as already contained in ``a``.
    """
    return a

@union_sets.register(Reals, Naturals)
def _(a, b):
    """Union of Reals with Naturals: returns ``a`` (Reals) as is.

    The handler treats ``b`` (Naturals) as already contained in ``a``.
    """
    return a

@union_sets.register(Reals, Naturals0)
def _(a, b):
    """Union of Reals with Naturals0: returns ``a`` (Reals) as is.

    The handler treats ``b`` (Naturals0) as already contained in ``a``.
    """
    return a

@union_sets.register(Reals, Rationals)
def _(a, b):
    """Union of Reals with Rationals: returns ``a`` (Reals) as is.

    The handler treats ``b`` (Rationals) as already contained in ``a``.
    """
    return a

@union_sets.register(Integers, Set)
def _(a, b):
    """Union of the integers with an arbitrary set, decided by intersection.

    If ``a & b`` equals ``a``, then ``a`` lies inside ``b`` and ``b`` is the
    union. If it equals ``b``, then ``b`` lies inside ``a`` and ``a`` is the
    union. Otherwise no simplification is made and None is returned.
    """
    overlap = Intersection(a, b)
    if overlap == a:
        return b
    if overlap == b:
        return a
    return None

@union_sets.register(ComplexRegion, Set)
def _(a, b):
    """Union of a complex region with another set.

    A ``b`` that is a subset of the reals is first converted with
    ``ComplexRegion.from_real``. If ``b`` is then a complex region in the
    same form as ``a`` (both rectangular or both polar), the result is one
    complex region over ``Union(a.sets, b.sets)`` in that form. A non-region
    ``b``, or a mix of rectangular and polar forms, gives None.
    """
    if b.is_subset(S.Reals):
        # A real subset is handled as a (rectangular) complex region.
        b = ComplexRegion.from_real(b)

    if not b.is_ComplexRegion:
        return None

    a_is_polar = bool(a.polar)
    if a_is_polar != bool(b.polar):
        # One rectangular and one polar region: no merge rule here.
        return None

    merged_sets = Union(a.sets, b.sets)
    if a_is_polar:
        return ComplexRegion(merged_sets, polar=True)
    return ComplexRegion(merged_sets)

@union_sets.register(EmptySet, Set)
def _(a, b):
    """Union of the empty set with any set ``b``: returns ``b``."""
    return b


@union_sets.register(UniversalSet, Set)
def _(a, b):
    """Union of the universal set with any set: returns ``a`` (UniversalSet)."""
    return a

@union_sets.register(ProductSet, ProductSet)
def _(a, b):
    """Union of two product sets.

    Returns ``a`` when ``b.is_subset(a)`` is truthy. Otherwise, if both
    products have the same number of factors and agree on every factor except
    one, the union is formed in that single differing factor::

        A1 x .. x Ak x .. x An  U  A1 x .. x Bk x .. x An
            = A1 x .. x (Ak U Bk) x .. x An

    This holds for any number of factors, not just two. Returns None when no
    such simplification applies.
    """
    if b.is_subset(a):
        return a
    if len(b.sets) != len(a.sets):
        return None

    # Positions where the two products have structurally different factors.
    differing = [k for k, (x, y) in enumerate(zip(a.sets, b.sets))
                 if x != y]
    if not differing:
        # Factor-wise identical products.
        return a
    if len(differing) != 1:
        # Two or more factors differ: this handler has no rule for that.
        return None

    k = differing[0]
    factors = list(a.sets)
    factors[k] = Union(a.sets[k], b.sets[k])
    return ProductSet(*factors)

@union_sets.register(ProductSet, Set)
def _(a, b):
    """Absorb ``b`` into the product set ``a`` when it is a subset of it.

    Any falsy result of ``b.is_subset(a)`` (including an undecided one, if
    ``is_subset`` can return that) gives None, i.e. no simplification.
    """
    return a if b.is_subset(a) else None

@union_sets.register(Interval, Interval)
def _(a, b):
    """Merge two intervals with comparable endpoints into one interval.

    Returns None when the endpoints are not comparable, when the intervals
    are disjoint, or when they only meet at a single point that neither of
    them contains. Otherwise returns the interval spanning both. Each end of
    that interval is closed if some input reaches it with a closed endpoint,
    and open otherwise.
    """
    if not a._is_comparable(b):
        return None

    from sympy.functions.elementary.miscellaneous import Min, Max

    # Bounds of the region where a and b overlap.
    overlap_end = Min(a.end, b.end)
    overlap_start = Max(a.start, b.start)
    if overlap_end < overlap_start:
        # Disjoint intervals.
        return None
    if overlap_end == overlap_start and (
            overlap_end not in a and overlap_end not in b):
        # They touch at a single point that both of them exclude.
        return None

    lo = Min(a.start, b.start)
    hi = Max(a.end, b.end)

    closed_left = ((a.start == lo and not a.left_open) or
                   (b.start == lo and not b.left_open))
    closed_right = ((a.end == hi and not a.right_open) or
                    (b.end == hi and not b.right_open))
    return Interval(lo, hi, not closed_left, not closed_right)

@union_sets.register(Interval, UniversalSet)
def _(a, b):
    """Union of an interval with the universal set: returns S.UniversalSet."""
    return S.UniversalSet

@union_sets.register(Interval, Set)
def _(a, b):
    """Close open endpoints of interval ``a`` that ``b`` contains.

    If ``a`` has an open, finite endpoint that ``b`` definitely contains
    (``b.contains`` evaluates to ``S.true``), returns ``{new_a, b}``. Here
    ``new_a`` is ``a`` with every open endpoint that ``b`` definitely contains
    closed. Otherwise returns None.

    Membership is always tested via ``b.contains`` rather than the ``in``
    operator. ``point in b`` raises TypeError when membership cannot be
    decided, so an undecidable endpoint is simply left open here instead.
    """
    def definitely_in_b(point):
        # Only a definite S.true counts as membership.
        return sympify(b.contains(point)) is S.true

    left_in_b = bool(a.left_open) and definitely_in_b(a.start)
    right_in_b = bool(a.right_open) and definitely_in_b(a.end)

    # Intervals never contain oo or -oo, so only a finite endpoint can be
    # filled in to trigger the rewrite.
    if not ((left_in_b and a.start.is_finite) or
            (right_in_b and a.end.is_finite)):
        return None

    # Interval() itself keeps infinite endpoints open, so passing "closed"
    # for an infinite endpoint that b contains is harmless.
    new_a = Interval(a.start, a.end,
                     bool(a.left_open) and not left_in_b,
                     bool(a.right_open) and not right_in_b)
    return {new_a, b}

@union_sets.register(FiniteSet, FiniteSet)
def _(a, b):
    """Union of two finite sets: one finite set holding the elements of both."""
    merged = a._elements | b._elements
    return FiniteSet(*merged)

@union_sets.register(FiniteSet, Set)
def _(a, b):
    """Remove from finite set ``a`` every element that ``b`` contains.

    Only elements whose ``b.contains`` result compares equal to True are
    removed; undecided membership keeps the element in ``a``. Returns
    ``{remaining_finite_set, b}`` if at least one element was removed,
    otherwise None.
    """
    # Pair each element with whether b definitely contains it.
    flagged = [(elem, b.contains(elem) == True) for elem in a]  # noqa: E712
    if not any(hit for _elem, hit in flagged):
        return None
    survivors = [elem for elem, hit in flagged if not hit]
    return {FiniteSet(*survivors), b}

@union_sets.register(Set, Set)
def _(a, b):
    """Fallback for any pair of sets without a more specific handler.

    Returns None, i.e. no simplification is offered.
    """
    return None
