import numpy as np
import matplotlib.pyplot as plt

### Testdaten ##################################################################

# Test systems A x = b for the Cholesky solver. Every matrix below is
# symmetric positive definite, so a factorization A = L L^T exists.

# 3x3 system (integer matrix); its Cholesky factor is
# L = [[2, 0, 0], [-1, 2, 0], [3, 1, 4]].
A1 = np.array([
    [  4, -2,  6 ],
    [ -2,  5, -1 ],
    [  6, -1, 26 ]])
b1 = np.array([ 18., 5., 82. ])
# Solution: x[0]=1, x[1]=2, x[2]=3

# 4x4 system (integer matrix); its Cholesky factor is the lower-triangular
# L = [[1, 0, 0, 0], [2, 1, 0, 0], [3, 2, 1, 0], [4, 3, 2, 1]].
A2 = np.array([
    [ 1,  2,  3,  4 ],
    [ 2,  5,  8, 11 ],
    [ 3,  8, 14, 20 ],
    [ 4, 11, 20, 30 ]])
b2 = np.array([ 20., 50., 84., 119. ])
# Solution: x[0]=4, x[1]=3, x[2]=2, x[3]=1


# 6x6 system: symmetric and strictly diagonally dominant with a positive
# diagonal, hence positive definite. b3 holds the row sums of A3.
A3 = np.array([
    [ 14.00,   2.00,   1.00,   0.00,   5.00,   5.00 ],
    [  2.00,  21.00,   2.00,  -1.00,  -7.00,  -8.00 ],
    [  1.00,   2.00,  18.00,   4.00,   7.00,   3.00 ],
    [  0.00,  -1.00,   4.00,  17.00,  -4.00,   7.00 ],
    [  5.00,  -7.00,   7.00,  -4.00,  27.00,   3.00 ],
    [  5.00,  -8.00,   3.00,   7.00,   3.00,  27.00 ]])
b3 = np.array([ 27., 9., 35., 23., 31., 37. ])
# Solution: x[0]=x[1]=x[2]=x[3]=x[4]=x[5]=1

# Sample points (x[i], y[i]) for the quadratic least-squares regression below:
# 100 x-values (roughly between 0 and 10) and the 100 matching y-values.
x = np.array([
7.61140777, 2.20109759, 8.473945,   7.9622403,  3.71569903, 7.38779576,
3.69423995, 8.78176216, 9.75840765, 2.47716811, 9.60433195, 5.86239345,
4.76454684, 4.62151257, 9.34374595, 1.88478971, 3.90655169, 0.48407258,
5.90897446, 7.04775339, 1.11991606, 1.82348696, 0.84422047, 5.36019023,
3.36669157, 6.19707563, 8.20458692, 1.45761093, 1.62116771, 6.70426269,
3.28125517, 8.8490541,  5.51356509, 7.73201112, 0.19289779, 8.92472744,
4.19224485, 7.98578066, 5.56160055, 4.96955235, 3.96283061, 2.43605089,
9.70986683, 0.18785886, 1.52139034, 0.75487988, 4.40224127, 9.77317488,
7.99285738, 4.86498488, 0.47211344, 3.79034615, 7.21912461, 1.18856846,
8.88552759, 0.09202936, 5.55129456, 8.13473149, 2.75762673, 6.45076597,
0.07820677, 3.86539247, 4.78198412, 4.12499408, 5.51493086, 5.44047311,
5.45079386, 5.02701465, 2.76154843, 1.45277656, 7.99921608, 8.13391689,
0.3046652,  2.68948383, 7.80507823, 0.92339673, 7.42860453, 8.58662785,
5.52412425, 0.1852239,  6.53463929, 6.68181489, 2.79980494, 7.92474694,
9.91247826, 8.83802389, 4.78746683, 0.27127638, 7.41657859, 4.3834377,
8.33655461, 3.94148332, 0.850049,   0.60700168, 3.74630722, 5.08879511,
4.3424114,  9.74492076, 1.65109393, 7.31989442])

# Observed values y[i] belonging to x[i] (same order, same length).
y = np.array([
56.4727578,  41.35726843, 57.39035574, 62.62986224, 52.20771439, 61.53750913,
50.04515502, 53.44452148, 55.87912565, 39.63049667, 53.30591679, 57.65954763,
55.72485158, 54.56257799, 59.11674928, 37.953877,   52.80602281, 20.23286717,
54.26917125, 56.9105931,  25.82142356, 29.7034276,  18.3451838,  58.22850602,
43.17021398, 63.07431915, 57.55090583, 31.94600123, 32.69664343, 60.75009243,
48.24610024, 52.80492933, 58.01184934, 60.30299395, 15.54641592, 53.41707347,
48.0705824,  62.85720943, 56.9227596,  60.72987077, 54.13600575, 38.18800369,
52.15014853, 10.24132353, 32.90557741, 18.67311204, 58.00766636, 53.64894993,
56.90254205, 52.65438046, 13.90547046, 46.35637745, 60.13661996, 25.74842184,
54.8820216,  10.23763224, 53.21679804, 62.90266011, 40.43304088, 55.26838054,
 8.75392988, 49.64932225, 58.61322146, 51.37082173, 59.86744347, 55.73011622,
57.19225078, 57.26652723, 45.54804956, 30.06454571, 55.54620831, 54.33788305,
14.55478431, 40.58104794, 61.87279834, 21.2149382,  63.62968802, 59.37698581,
55.37713749, 13.21886682, 59.47688241, 59.65925422, 43.78937539, 59.7055115,
48.46603811, 54.17291754, 50.87688383, 12.62596824, 57.88991956, 49.95665097,
57.56191917, 49.52531773, 17.32964482, 17.84998854, 46.94568906, 53.91080838,
50.51409764, 47.66378772, 32.16199637, 58.5622816])

### Funktionen von Aufgabenblatt 3 #############################################

def forwardSubstitution(L, b):
    """Solve the lower-triangular system L y = b by forward substitution.

    Parameters
    ----------
    L : array_like, shape (n, n)
        Lower-triangular matrix. Only the diagonal and the entries below it
        are read; entries above the diagonal are ignored.
    b : array_like, shape (n,)
        Right-hand side.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Solution vector y (float). An empty system (n == 0) yields an empty
        array.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a diagonal entry of L is zero (L is singular).
    """
    L = np.asarray(L, dtype=float)
    b = np.asarray(b, dtype=float)
    n = len(L)
    y = np.zeros(n)
    for i in range(n):
        if L[i, i] == 0:
            # Dividing would silently produce inf/nan; fail loudly instead.
            raise np.linalg.LinAlgError(
                f"Singular triangular matrix: zero pivot in row {i}")
        # y[i] depends only on the already computed y[0..i-1]; for i == 0
        # the dot product of two empty slices is 0.
        y[i] = (b[i] - np.dot(L[i, :i], y[:i])) / L[i, i]
    return y

def backwardSubstitution(R, y):
    """Solve the upper-triangular system R x = y by backward substitution.

    Parameters
    ----------
    R : array_like, shape (n, n)
        Upper-triangular matrix. Only the diagonal and the entries above it
        are read; entries below the diagonal are ignored.
    y : array_like, shape (n,)
        Right-hand side.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Solution vector x (float). An empty system (n == 0) yields an empty
        array.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a diagonal entry of R is zero (R is singular).
    """
    R = np.asarray(R, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(R)
    x = np.zeros(n)
    for i in reversed(range(n)):
        if R[i, i] == 0:
            # Dividing would silently produce inf/nan; fail loudly instead.
            raise np.linalg.LinAlgError(
                f"Singular triangular matrix: zero pivot in row {i}")
        # x[i] depends only on the already computed x[i+1..n-1]; for the
        # last row both slices are empty and the dot product is 0.
        x[i] = (y[i] - np.dot(R[i, i + 1:], x[i + 1:])) / R[i, i]
    return x

### Aufgabe 1 ##################################################################

def choleskyFactorization(A):
    """Compute the Cholesky factor of a symmetric positive definite matrix.

    Returns the lower-triangular matrix L with positive diagonal such that
    A = L @ L.T. Only the lower triangle of A (including the diagonal) is
    read; symmetry of A is assumed, not checked.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Symmetric positive definite matrix.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Lower-triangular Cholesky factor L (float).

    Raises
    ------
    numpy.linalg.LinAlgError
        If a pivot a_ii - sum_k l_ik^2 is not strictly positive, i.e. A is
        not positive definite. (Previously this silently produced NaNs or a
        division by zero.)
    """
    A = np.asarray(A, dtype=float)
    n = len(A)
    L = np.zeros((n, n))

    for i in range(n):
        for j in range(i):
            # l_ij = (a_ij - sum_{k<j} l_ik * l_jk) / l_jj ;
            # l_jj > 0 is guaranteed by the pivot check of row j.
            L[i, j] = (A[i, j] - np.dot(L[i, :j], L[j, :j])) / L[j, j]
        # l_ii = sqrt(a_ii - sum_{k<i} l_ik^2); the radicand must be > 0.
        pivot = A[i, i] - np.dot(L[i, :i], L[i, :i])
        if not pivot > 0.0:  # also catches NaN
            raise np.linalg.LinAlgError(
                f"Matrix is not positive definite (pivot {i} is {pivot})")
        L[i, i] = np.sqrt(pivot)
    return L

### Aufgabe 2 ##################################################################

def solveLESwithCholesky(A, b):
    """Solve A x = b for a symmetric positive definite matrix A.

    With the Cholesky factorization A = L L^T the system splits into two
    triangular solves: L y = b (forward substitution), then L^T x = y
    (backward substitution). Returns the solution x.
    """
    lower = choleskyFactorization(A)
    upper = lower.T
    return backwardSubstitution(upper, forwardSubstitution(lower, b))

### Aufgabe 3 ##################################################################

# For each test system: print the Cholesky factor L, print L L^T (which
# should reproduce the input matrix) and print the computed solution.
for matrix, rhs in ((A1, b1), (A2, b2), (A3, b3)):
    L = choleskyFactorization(matrix)
    print(L)
    print(np.dot(L, L.T))
    print(solveLESwithCholesky(matrix, rhs))

def quadraticReg(x, y):
    """Fit the quadratic p(t) = a*t**2 + b*t + c to the data by least squares.

    Builds the 3x3 normal equations of the least-squares problem
    min sum_i (a*x_i**2 + b*x_i + c - y_i)**2 and solves them with
    solveLESwithCholesky.

    Parameters
    ----------
    x, y : array_like, 1-D, equal length
        Data points (x[i], y[i]).

    Returns
    -------
    numpy.ndarray, shape (3,)
        Coefficients [a, b, c].

    Raises
    ------
    ValueError
        If x and y are not 1-D arrays of the same length (mismatched shapes
        could otherwise broadcast silently), or if x has fewer than three
        distinct values (the normal matrix is then singular).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError(
            f"x and y must be 1-D and of equal length, got shapes {x.shape} and {y.shape}")
    if np.unique(x).size < 3:
        raise ValueError("quadratic regression needs at least 3 distinct x values")
    n = x.size  # number of data points

    # Power sums S_k = sum_i x_i**k that make up the normal matrix.
    s1, s2, s3, s4 = (np.sum(x ** k) for k in range(1, 5))
    A = np.array([[s4, s3, s2],
                  [s3, s2, s1],
                  [s2, s1, n]])

    # Right-hand side of the normal equations.
    d = np.array([np.sum((x ** 2) * y), np.sum(x * y), np.sum(y)])

    return solveLESwithCholesky(A, d)

# Fit the quadratic regression polynomial to the sample data and report it.
a, b, c = quadraticReg(x, y)
print(f"a = {a}, b = {b}, c = {c}")


def p(x):
    """Evaluate the fitted regression polynomial a*x**2 + b*x + c."""
    return a * (x ** 2) + b * x + c


# Sample the polynomial on 100 equally spaced points over the data range.
xmin, xmax = x.min(), x.max()
xx = np.linspace(xmin, xmax, 100)
yy = p(xx)

# Scatter the data and overlay the regression curve.
plt.scatter(x, y, alpha=0.5, label='Daten')
plt.plot(xx, yy, color='red', label='Regressionspolynom')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.title('Daten mit Regressionspolynom')
plt.show()
