dual.hh 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file dual.hh
  3. /// @brief Dual numbers for automatic differentiation.
  4. // (c) Daniel Llorens - 2013, 2015
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. // I used VanderBergen2012, Berland2006. Generally about automatic differentiation:
  10. // http://en.wikipedia.org/wiki/Automatic_differentiation
  11. // From the Taylor expansion of f(a) or f(a, b)...
  12. // f(a+εa') = f(a)+εa'f_a(a)
  13. // f(a+εa', b+εb') = f(a, b)+ε[a'f_a(a, b)+b'f_b(a, b)]
  14. #pragma once
  15. #include <iosfwd>
  16. #include "ra/macros.hh"
  17. #include <cmath>
  18. using std::abs, std::sqrt, std::fma;
  19. template <class T>
  20. struct Dual
  21. {
  22. T re, du;
  23. // FIXME requires bug in gcc 10.1
  24. template <class S, class Enable=void>
  25. struct real_part { struct type {}; };
  26. template <class S>
  27. struct real_part<S, std::enable_if_t<!(std::is_same_v<S, std::decay_t<decltype(std::declval<S>().real())>>)>>
  28. {
  29. using type = typename S::value_type;
  30. };
  31. Dual(T const & r, T const & d): re(r), du(d) {}
  32. Dual(T const & r): re(r), du(0.) {} // conversions are by default constants.
  33. Dual(typename real_part<T>::type const & r): re(r), du(0.) {}
  34. Dual() {}
  35. #define DEF_ASSIGN_OPS(OP) \
  36. Dual & operator JOIN(OP, =)(T const & r) { *this = *this OP r; return *this; } \
  37. Dual & operator JOIN(OP, =)(Dual const & r) { *this = *this OP r; return *this; } \
  38. Dual & operator JOIN(OP, =)(typename real_part<T>::type const & r) { *this = *this OP r; return *this; }
  39. FOR_EACH(DEF_ASSIGN_OPS, +, -, /, *)
  40. #undef DEF_ASSIGN_OPS
  41. };
  42. // conversions are by default constants.
  43. template <class R> auto dual(Dual<R> const & r) { return r; }
  44. template <class R> auto dual(R const & r) { return Dual<R> { r, 0. }; }
  45. template <class R, class D>
  46. auto dual(R const & r, D const & d)
  47. {
  48. return Dual<std::common_type_t<R, D>> { r, d };
  49. }
  50. template <class A, class B>
  51. auto operator*(Dual<A> const & a, Dual<B> const & b)
  52. {
  53. return dual(a.re*b.re, a.re*b.du + a.du*b.re);
  54. }
  55. template <class A, class B>
  56. auto operator*(A const & a, Dual<B> const & b)
  57. {
  58. return dual(a*b.re, a*b.du);
  59. }
  60. template <class A, class B>
  61. auto operator*(Dual<A> const & a, B const & b)
  62. {
  63. return dual(a.re*b, a.du*b);
  64. }
  65. template <class A, class B, class C>
  66. auto fma(Dual<A> const & a, Dual<B> const & b, Dual<C> const & c)
  67. {
  68. return dual(fma(a.re, b.re, c.re), fma(a.re, b.du, fma(a.du, b.re, c.du)));
  69. }
  70. template <class A, class B>
  71. auto operator+(Dual<A> const & a, Dual<B> const & b)
  72. {
  73. return dual(a.re+b.re, a.du+b.du);
  74. }
  75. template <class A, class B>
  76. auto operator+(A const & a, Dual<B> const & b)
  77. {
  78. return dual(a+b.re, b.du);
  79. }
  80. template <class A, class B>
  81. auto operator+(Dual<A> const & a, B const & b)
  82. {
  83. return dual(a.re+b, a.du);
  84. }
  85. template <class A, class B>
  86. auto operator-(Dual<A> const & a, Dual<B> const & b)
  87. {
  88. return dual(a.re-b.re, a.du-b.du);
  89. }
  90. template <class A, class B>
  91. auto operator-(Dual<A> const & a, B const & b)
  92. {
  93. return dual(a.re-b, a.du);
  94. }
  95. template <class A, class B>
  96. auto operator-(A const & a, Dual<B> const & b)
  97. {
  98. return dual(a-b.re, -b.du);
  99. }
  100. template <class A>
  101. auto operator-(Dual<A> const & a)
  102. {
  103. return dual(-a.re, -a.du);
  104. }
  105. template <class A>
  106. decltype(auto) operator+(Dual<A> const & a)
  107. {
  108. return a;
  109. }
  110. template <class A>
  111. auto inv(Dual<A> const & a)
  112. {
  113. auto i = 1./a.re;
  114. return dual(i, -a.du*(i*i));
  115. }
  116. template <class A, class B>
  117. auto operator/(Dual<A> const & a, Dual<B> const & b)
  118. {
  119. return a*inv(b);
  120. }
  121. template <class A, class B>
  122. auto operator/(Dual<A> const & a, B const & b)
  123. {
  124. return a*inv(dual(b));
  125. }
  126. template <class A, class B>
  127. auto operator/(A const & a, Dual<B> const & b)
  128. {
  129. return dual(a)*inv(b);
  130. }
  131. template <class A>
  132. auto cos(Dual<A> const & a)
  133. {
  134. return dual(cos(a.re), -sin(a.re)*a.du);
  135. }
  136. template <class A>
  137. auto sin(Dual<A> const & a)
  138. {
  139. return dual(sin(a.re), +cos(a.re)*a.du);
  140. }
  141. template <class A>
  142. auto cosh(Dual<A> const & a)
  143. {
  144. return dual(cosh(a.re), +sinh(a.re)*a.du);
  145. }
  146. template <class A>
  147. auto sinh(Dual<A> const & a)
  148. {
  149. return dual(sinh(a.re), +cosh(a.re)*a.du);
  150. }
  151. template <class A>
  152. auto tan(Dual<A> const & a)
  153. {
  154. auto c = cos(a.du);
  155. return dual(tan(a.re), a.du/(c*c));
  156. }
  157. template <class A>
  158. auto exp(Dual<A> const & a)
  159. {
  160. return dual(exp(a.re), +exp(a.re)*a.du);
  161. }
  162. template <class A, class B>
  163. auto pow(Dual<A> const & a, B const & b)
  164. {
  165. return dual(pow(a.re, b), +b*pow(a.re, b-1)*a.du);
  166. }
  167. template <class A>
  168. auto log(Dual<A> const & a)
  169. {
  170. return dual(log(a.re), +a.du/a.re);
  171. }
  172. template <class A>
  173. auto sqrt(Dual<A> const & a)
  174. {
  175. return dual(sqrt(a.re), +a.du/(2.*sqrt(a.re)));
  176. }
  177. template <class A>
  178. auto sqr(Dual<A> const & a)
  179. {
  180. return a*a;
  181. }
  182. template <class A>
  183. auto abs(Dual<A> const & a)
  184. {
  185. return abs(a.re);
  186. }
  187. template <class A>
  188. bool isfinite(Dual<A> const & a)
  189. {
  190. return isfinite(a.re) && isfinite(a.du);
  191. }
  192. template <class A>
  193. auto xI(Dual<A> const & a)
  194. {
  195. return dual(xI(a.re), xI(a.du));
  196. }
  197. template <class A>
  198. std::ostream & operator<<(std::ostream & o, Dual<A> const & a)
  199. {
  200. return o << "[" << a.re << " " << a.du << "]";
  201. }
  202. template <class A>
  203. std::istream & operator>>(std::istream & i, Dual<A> & a)
  204. {
  205. std::string s;
  206. i >> s;
  207. if (s!="[") {
  208. i.setstate(std::ios::failbit);
  209. return i;
  210. }
  211. a >> a.re;
  212. a >> a.du;
  213. i >> s;
  214. if (s!="]") {
  215. i.setstate(std::ios::failbit);
  216. return i;
  217. }
  218. }