import sympy as sym

def test():
    """Solve a small Lagrange-multiplier problem numerically and print it.

    Looks for stationary points of f(x, y) = x**2 + y subject to the
    constraint g(x, y) = x - y = 0 by solving the system

        df/dv - L * dg/dv = 0   for v in (x, y)
        g = 0

    with ``sympy.nsolve``, starting from (x, y, L) = (1, 1, 1).
    Analytically the system gives L = -1 and x = y = -1/2.

    Returns:
        The solution vector ``[x, y, L]`` produced by ``sympy.nsolve``
        (it is also printed).
    """
    x, y, L = sym.symbols('x, y, L')
    variables = [x, y]  # decision variables; L is the multiplier
    f = x**2 + y
    g = x - y

    def finding_equations(f, g, variables):
        """Return the Lagrange system ``[df/dv - L*dg/dv for v] + [g]``.

        Differentiating with respect to an explicit variable list (instead
        of ``f.atoms(sym.Symbol)``) keeps the equation order deterministic
        and does not drop variables that only occur in the constraint ``g``.
        """
        eqs = [sym.diff(f, v) - L * sym.diff(g, v) for v in variables]
        eqs.append(g)
        return eqs

    Result = sym.nsolve(finding_equations(f, g, variables), variables + [L], [1, 1, 1])
    print(Result)
    return Result

def oppg1(N=6):
    """Find the stationary distribution of sum(p*log p) under normalisation.

    Builds f(p) = sum_i p_i * log(p_i) over ``N`` probabilities
    ``p0 .. p{N-1}`` with the constraint g(p) = sum_i p_i - 1 = 0, and
    solves the Lagrange system

        df/dp_i - L1 * dg/dp_i = 0   (i.e. log(p_i) + 1 = L1)
        g = 0

    numerically with ``sympy.nsolve``, starting every unknown at 0.5.

    Args:
        N: number of probabilities (default 6). Must be at least 1.

    Returns:
        The solution vector ``[p0, ..., p{N-1}, L1]`` from ``sympy.nsolve``
        (it is also printed). The expected answer is the uniform
        probability distribution p_i = 1/N.

    Raises:
        ValueError: if ``N < 1`` (the system would have no probabilities).
    """
    if N < 1:
        raise ValueError("N must be at least 1")

    p_list = [sym.Symbol('p' + str(i)) for i in range(N)]
    L = sym.Symbol('L1')

    f = sum(p * sym.log(p) for p in p_list)
    g = sum(p_list) - 1

    def solver(f, g, variables):
        """Return the Lagrange system ``[df/dv - L*dg/dv for v] + [g]``.

        Uses the explicit variable list rather than ``f.atoms(sym.Symbol)``
        so the equation order is deterministic and no constrained variable
        is silently skipped.
        """
        eqs = [sym.diff(f, v) - L * sym.diff(g, v) for v in variables]
        eqs.append(g)
        return eqs

    res = sym.nsolve(solver(f, g, p_list), p_list + [L], [0.5] * (N + 1))
    print(res)  # uniform probability distribution p_i = 1/N
    return res

# Run the entropy example only when executed as a script, not on import.
if __name__ == "__main__":
    oppg1()

