dual.H 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. // (c) Daniel Llorens - 2013, 2015
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file dual.H
  7. /// @brief Dual numbers.
// Based on VanderBergen2012 and Berland2006. For automatic differentiation in general, see:
  9. // http://en.wikipedia.org/wiki/Automatic_differentiation
// From the first-order Taylor expansion of f(a) or f(a, b), using ε² = 0:
// f(a+εa') = f(a) + εa'f_a(a)
// f(a+εa', b+εb') = f(a, b) + ε[a'f_a(a, b) + b'f_b(a, b)]
  13. #pragma once
  14. #include <iosfwd>
  15. #include "ra/macros.H"
  16. #include <complex>
  17. using std::abs;
  18. template <class T>
  19. struct Dual
  20. {
  21. T re, du;
  22. Dual(T const & r, T const & d): re(r), du(d) {}
  23. Dual(T const & r): re(r), du(0.) {} // conversions are by default constants.
  24. Dual() {}
  25. #define DEF_ASSIGN_OPS(OP) \
  26. Dual & operator JOIN(OP, =)(T const & r) { *this = *this OP r; return *this; } \
  27. Dual & operator JOIN(OP, =)(Dual const & r) { *this = *this OP r; return *this; }
  28. FOR_EACH(DEF_ASSIGN_OPS, +, -, /, *)
  29. #undef DEF_ASSIGN_OPS
  30. };
  31. template <class T>
  32. struct Dual<std::complex<T>>
  33. {
  34. using C = std::complex<T>;
  35. C re, du;
  36. Dual(C const & r, C const & d): re(r), du(d) {}
  37. Dual(C const & r): re(r), du(0.) {} // conversions are by default constants.
  38. Dual(T const & r): re(r), du(0.) {} // TODO repeat the class just for this?
  39. Dual() {}
  40. #define DEF_ASSIGN_OPS(OP) \
  41. Dual & operator JOIN(OP, =)(T const & r) { *this = *this OP r; return *this; } \
  42. Dual & operator JOIN(OP, =)(C const & r) { *this = *this OP r; return *this; } \
  43. Dual & operator JOIN(OP, =)(Dual const & r) { *this = *this OP r; return *this; }
  44. FOR_EACH(DEF_ASSIGN_OPS, +, -, /, *)
  45. #undef DEF_ASSIGN_OPS
  46. };
  47. // conversions are by default constants.
  48. template <class R>
  49. Dual<R> dual(R const & r)
  50. {
  51. return Dual<R> { r, 0. };
  52. }
  53. template <class R, class D>
  54. auto dual(R const & r, D const & d)
  55. {
  56. return Dual<decltype(r+d)> { r, d };
  57. }
  58. template <class A, class B>
  59. auto operator*(Dual<A> const & a, Dual<B> const & b)
  60. {
  61. return Dual<decltype(a.re*b.re)> { a.re*b.re, a.re*b.du + a.du*b.re };
  62. }
  63. template <class A, class B>
  64. auto operator*(A const & a, Dual<B> const & b)
  65. {
  66. return dual(a*b.re, a*b.du);
  67. }
  68. template <class A, class B>
  69. auto operator*(Dual<A> const & a, B const & b)
  70. {
  71. return dual(a.re*b, a.du*b);
  72. }
  73. template <class A, class B, class C>
  74. auto fma(Dual<A> const & a, Dual<B> const & b, Dual<C> const & c)
  75. {
  76. return dual(fma(a.re, b.re, c.re),
  77. fma(a.re, b.du, fma(a.du, b.re, c.du)));
  78. }
  79. template <class A, class B>
  80. auto operator+(Dual<A> const & a, Dual<B> const & b)
  81. {
  82. return dual(a.re+b.re, a.du+b.du);
  83. }
  84. template <class A, class B>
  85. auto operator+(A const & a, Dual<B> const & b)
  86. {
  87. return dual(a+b.re, b.du);
  88. }
  89. template <class A, class B>
  90. auto operator+(Dual<A> const & a, B const & b)
  91. {
  92. return dual(a.re+b, a.du);
  93. }
  94. template <class A, class B>
  95. auto operator-(Dual<A> const & a, Dual<B> const & b)
  96. {
  97. return dual(a.re-b.re, a.du-b.du);
  98. }
  99. template <class A, class B>
  100. auto operator-(Dual<A> const & a, B const & b)
  101. {
  102. return dual(a.re-b, a.du);
  103. }
  104. template <class A, class B>
  105. auto operator-(A const & a, Dual<B> const & b)
  106. {
  107. return dual(a-b.re, -b.du);
  108. }
  109. template <class A>
  110. auto operator-(Dual<A> const & a)
  111. {
  112. return dual(-a.re, -a.du);
  113. }
  114. template <class A>
  115. decltype(auto) operator+(Dual<A> const & a)
  116. {
  117. return a;
  118. }
  119. template <class A>
  120. auto inv(Dual<A> const & a)
  121. {
  122. auto i = 1./a.re;
  123. return dual(i, -a.du*(i*i));
  124. }
  125. template <class A, class B>
  126. auto operator/(Dual<A> const & a, Dual<B> const & b)
  127. {
  128. return a*inv(b);
  129. }
  130. template <class A, class B>
  131. auto operator/(Dual<A> const & a, B const & b)
  132. {
  133. return a*inv(dual(b));
  134. }
  135. template <class A, class B>
  136. auto operator/(A const & a, Dual<B> const & b)
  137. {
  138. return dual(a)*inv(b);
  139. }
  140. template <class A>
  141. auto cos(Dual<A> const & a)
  142. {
  143. return dual(cos(a.re), -sin(a.re)*a.du);
  144. }
  145. template <class A>
  146. auto sin(Dual<A> const & a)
  147. {
  148. return dual(sin(a.re), +cos(a.re)*a.du);
  149. }
  150. template <class A>
  151. auto cosh(Dual<A> const & a)
  152. {
  153. return dual(cosh(a.re), +sinh(a.re)*a.du);
  154. }
  155. template <class A>
  156. auto sinh(Dual<A> const & a)
  157. {
  158. return dual(sinh(a.re), +cosh(a.re)*a.du);
  159. }
  160. template <class A>
  161. auto tan(Dual<A> const & a)
  162. {
  163. auto c = cos(a.du);
  164. return dual(tan(a.re), a.du/(c*c));
  165. }
  166. template <class A>
  167. auto exp(Dual<A> const & a)
  168. {
  169. return dual(exp(a.re), +exp(a.re)*a.du);
  170. }
  171. template <class A, class B>
  172. auto pow(Dual<A> const & a, B const & b)
  173. {
  174. return dual(pow(a.re, b), +b*pow(a.re, b-1)*a.du);
  175. }
  176. template <class A>
  177. auto log(Dual<A> const & a)
  178. {
  179. return dual(log(a.re), +a.du/a.re);
  180. }
  181. template <class A>
  182. auto sqrt(Dual<A> const & a)
  183. {
  184. return dual(sqrt(a.re), +a.du/(2.*sqrt(a.re)));
  185. }
  186. template <class A>
  187. auto sqr(Dual<A> const & a)
  188. {
  189. return a*a;
  190. }
  191. template <class A>
  192. auto abs(Dual<A> const & a)
  193. {
  194. return abs(a.re);
  195. }
  196. template <class A>
  197. bool isfinite(Dual<A> const & a)
  198. {
  199. return isfinite(a.re) && isfinite(a.du);
  200. }
  201. template <class A>
  202. auto xI(Dual<A> const & a)
  203. {
  204. return dual(xI(a.re), xI(a.du));
  205. }
  206. template <class A>
  207. std::ostream & operator<<(std::ostream & o, Dual<A> const & a)
  208. {
  209. return o << "[" << a.re << " " << a.du << "]";
  210. }
  211. template <class A>
  212. std::istream & operator>>(std::istream & i, Dual<A> & a)
  213. {
  214. std::string s;
  215. i >> s;
  216. if (s!="[") {
  217. i.setstate(std::ios::failbit);
  218. return i;
  219. }
  220. a >> a.re;
  221. a >> a.du;
  222. i >> s;
  223. if (s!="]") {
  224. i.setstate(std::ios::failbit);
  225. return i;
  226. }
  227. }