dimensiereductie_courtemanche.py 4.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #!/bin/env python3
  2. import ithildin as ith
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import sys
  6. from typing import List
  7. from scipy.interpolate import RectBivariateSpline
  8. from pydiffmap.diffusion_map import DiffusionMap
  9. from pydiffmap import visualization as diff_visualization
  10. import matplotlib
  11. # Courtemanche1998 model, 1 ruimtedimensie, 20 variabelen
  12. data = ith.SimData.from_stem("myokit12/myokit_1")
  13. # Vermijd randeffecten en initialiatie
  14. def snij_randen_weg(variabele,begintijd, eindtijd):
  15. def begin(index):
  16. if index == 0:
  17. return begintijd
  18. return variabele.shape[index]//8
  19. def eind(index):
  20. if index == 0:
  21. return eindtijd
  22. if variabele.shape[index] < 20:
  23. return variabele.shape[index]
  24. return (7 * variabele.shape[index])//8
  25. return variabele[begin(0):eind(0),begin(1):eind(1),begin(2):eind(2),begin(3):eind(3)]
  26. def snij_randen_weg2(variables,begintijd,eindtijd):
  27. newdata = dict()
  28. for key in variables.keys():
  29. newdata[key] = snij_randen_weg(variables[key], begintijd,eindtijd)
  30. return newdata
  31. def onderzoek(t0, t1, aantal_punten, eps):
  32. vars = snij_randen_weg2(data.vars, t0, t1)
  33. # We bekijken de faseruimte, tijds- en ruimtecoordinaten zijn daarin niet van belang.
  34. for key in vars.keys():
  35. vars[key] = vars[key].ravel()
  36. # Verlaag aantal punten om geheugengebruik en tijdsduur te beperken.
  37. # replace=True strikt genomen incorrect maar verlaagt geheugengebruik
  38. # en wegens de hoge hoeveelheid punten weinig belangrijk
  39. np.random.seed(1)
  40. keuzes = np.random.choice(vars['u'].shape[0], aantal_punten,replace=True)
  41. for key in data.vars.keys():
  42. vars[key] = vars[key][keuzes]
  43. # Nu de randen verwijderd zijn en we punten in een faseruimte hebben, kunnen we proberen
  44. # diffusion maps te gebruiken.
  45. dmap = DiffusionMap.from_sklearn(epsilon=eps,n_evecs=4)
  46. dmap.fit(np.column_stack((vars['gateui'], vars['gatexs'], vars['gated'], vars['Ca'], vars['gatew'], vars['gateu'], vars['Na'], vars['gateoa'], vars['Caup'], vars['gatef'], vars['gateua'], vars['gateh'], vars['gatexr'], vars['K'], vars['gatefCa'], vars['u'], vars['gatem'], vars['gatev'], vars['gatej'], vars['gateoi'], vars['Carel'])))
  47. diff_visualization.embedding_plot(dmap,dim=3)
  48. return dmap
  49. # Zoals ‘Embedding given by first three DCs, coloured by fourth (BOFC model, %s points)’,
  50. # maar plot ook de punten die niet gebruikt worden om de diffusion map op te stellen.
  51. # Hopelijk maakt dat de figuur wat duidelijker ...
  52. def plot_veel_punten(t0, t1, aantal_punten, dmap):
  53. vars = snij_randen_weg2(data.vars, t0, t1)
  54. # copied from 'onderzoek'
  55. for key in vars.keys():
  56. vars[key] = vars[key].ravel()
  57. print("raveled")
  58. np.random.seed(1) # determinism
  59. keuzes = np.random.choice(vars['u'].shape[0], aantal_punten, replace=True)
  60. # calculate the diffusion coordinates # TODO waarom vind Emacs+Python de o-trema niet goed?
  61. var_phi1234 = dmap.transform(np.column_stack((vars['gateui'], vars['gatexs'], vars['gated'], vars['Ca'], vars['gatew'], vars['gateu'], vars['Na'], vars['gateoa'], vars['Caup'], vars['gatef'], vars['gateua'], vars['gateh'], vars['gatexr'], vars['K'], vars['gatefCa'], vars['u'], vars['gatem'], vars['gatev'], vars['gatej'], vars['gateoi'], vars['Carel'])))
  62. print("transformed")
  63. # plot it!
  64. fig = plt.figure(figsize=(6,6))
  65. ax = fig.add_subplot(111,projection='3d')
  66. ax.scatter(var_phi1234[:,0],var_phi1234[:,1],var_phi1234[:,2],var_phi1234[:,3],c=var_phi1234[:,3],cmap='viridis')
  67. ax.set_title("Embedding given by first three DCs, coloured by fourth (Courtemanche model, 1D)") # todo engels
  68. ax.set_xlabel(r'$\psi_1$')
  69. ax.set_ylabel(r'$\psi_2$')
  70. ax.set_zlabel(r'$\psi_3$')
  71. plt.axis('tight') # ? copied from pydiffmap
  72. fig.savefig("Courtemanche1D 3+1 eigenvectors, extended embedding (uvw, s).png")
  73. #fig.savefig("BOFC 3+1 eigenvectors, extended embedding (uvw, s).pdf")
  74. plt.show()
  75. plt.close()
  76. # vanaf t < 50 komen er losse puntjes (misschien te ...)
  77. matplotlib.use("TkAgg")
  78. dmap = onderzoek(50, 151, 7000, 0.2)
  79. plot_veel_punten(50, 151, 9, dmap)
  80. #del var1, var2, var3, var4 # geheugengebruik beperken
  81. #dmap = DiffusionMap.from_sklearn(n_evecs = 4)
  82. #map.fit(faseruimte)