HairSimulationCommon.azsli 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. /*
  2. * Modifications Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. //---------------------------------------------------------------------------------------
  9. // Shader code related to simulating hair strands in compute.
  10. //-------------------------------------------------------------------------------------
  11. //
  12. // Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved.
  13. //
  14. // Permission is hereby granted, free of charge, to any person obtaining a copy
  15. // of this software and associated documentation files (the "Software"), to deal
  16. // in the Software without restriction, including without limitation the rights
  17. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  18. // copies of the Software, and to permit persons to whom the Software is
  19. // furnished to do so, subject to the following conditions:
  20. //
  21. // The above copyright notice and this permission notice shall be included in
  22. // all copies or substantial portions of the Software.
  23. //
  24. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  25. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  26. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  27. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  28. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  29. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  30. // THE SOFTWARE.
  31. //
  32. //--------------------------------------------------------------------------------------
  33. // File: HairSimulation.azsl
  34. //
  35. // Physics simulation of hair using compute shaders
  36. //--------------------------------------------------------------------------------------
  37. #pragma once
  // Compile-time switch, set to 0 (disabled) here. It is not referenced in the visible part of this file.
  38. #define USE_MESH_BASED_HAIR_TRANSFORM 0
  39. // If you change the value below, you must change it in TressFXAsset.h as well.
  // The #ifndef guard lets a definition made before this point override the default of 64.
  40. #ifndef THREAD_GROUP_SIZE
  41. #define THREAD_GROUP_SIZE 64
  42. #endif
  // Group-shared scratch arrays, each holding THREAD_GROUP_SIZE entries.
  // NOTE(review): presumably one slot per thread of the group (the index helpers below
  // set indexForSharedMem = local_id) - confirm against the shaders that use them.
  43. groupshared float4 sharedPos[THREAD_GROUP_SIZE];
  44. groupshared float4 sharedTangent[THREAD_GROUP_SIZE];
  45. groupshared float sharedLength[THREAD_GROUP_SIZE];
  46. //--------------------------------------------------------------------------------------
  47. //
  48. // Helper Functions for the main simulation shaders
  49. //
  50. //--------------------------------------------------------------------------------------
  // Returns true when the particle's w component is strictly positive, false otherwise.
  // NOTE(review): w presumably encodes inverse mass (0 = pinned vertex) - confirm against the asset format.
  bool IsMovable(float4 particle)
  {
      return particle.w > 0;
  }
  // Returns the correction weights (x for particle0, y for particle1) used when a
  // two-particle constraint is resolved:
  //   both movable      -> (0.5, 0.5)  the correction is split evenly
  //   only one movable  -> that particle gets 1, the other 0
  //   neither movable   -> (0, 0)
  float2 ConstraintMultiplier(float4 particle0, float4 particle1)
  {
      const bool firstMoves = IsMovable(particle0);
      const bool secondMoves = IsMovable(particle1);

      if (firstMoves && secondMoves)
      {
          return float2(0.5, 0.5);
      }

      // At most one of the two is movable here.
      return float2(firstMoves ? 1 : 0, secondMoves ? 1 : 0);
  }
  // Builds a quaternion (xyz = vector part, w = scalar part) from a rotation angle
  // in radians and a rotation axis. The axis is used as given; it is not normalized here.
  float4 MakeQuaternion(float angle_radian, float3 axis)
  {
      const float halfAngle = 0.5f * angle_radian;
      return float4(sin(halfAngle) * axis.xyz, cos(halfAngle));
  }
  // Converts the rotational 3x3 part of a 4x4 column-major rigid transform into a
  // quaternion (xyz = vector part, w = scalar part).
  // "Rigid" means the 3x3 sub-matrix is expected to be orthonormal; this is NOT checked here.
  float4 MakeQuaternion(column_major float4x4 m)
  {
      const float diagonalSum = m[0][0] + m[1][1] + m[2][2];

      if (diagonalSum > 0.0f)
      {
          // Positive trace: derive w first, then the vector part from the off-diagonal differences.
          const float root = sqrt(diagonalSum + 1.0f);
          const float scale = 0.5 / root;
          return float4((m[1][2] - m[2][1]) * scale,
                        (m[2][0] - m[0][2]) * scale,
                        (m[0][1] - m[1][0]) * scale,
                        0.5 * root);
      }

      // Non-positive trace: start from the largest diagonal element.
      // (major, next, last) is a cyclic permutation of (0, 1, 2).
      int major = 0, next = 1, last = 2;
      if (m[1][1] > m[0][0])
      {
          major = 1; next = 2; last = 0;
      }
      if (m[2][2] > m[major][major])
      {
          major = 2; next = 0; last = 1;
      }

      const float majorRoot = sqrt(m[major][major] - m[next][next] - m[last][last] + 1.0f);
      const float majorScale = 0.5f / majorRoot;

      // Vector part, filled in by axis index so it can be read back in x, y, z order.
      float vectorPart[3];
      vectorPart[major] = 0.5f * majorRoot;
      vectorPart[next] = (m[next][major] + m[major][next]) * majorScale;
      vectorPart[last] = (m[last][major] + m[major][last]) * majorScale;

      return float4(vectorPart[0], vectorPart[1], vectorPart[2],
                    (m[next][last] - m[last][next]) * majorScale);
  }
  // Returns the inverse of q: the conjugate divided by the squared length.
  // When the squared length is below 0.001 the identity quaternion (0, 0, 0, 1) is
  // returned instead, which avoids dividing by a near-zero value.
  float4 InverseQuaternion(float4 q)
  {
      const float normSquared = q.x*q.x + q.y*q.y + q.z*q.z + q.w*q.w;

      if (normSquared < 0.001)
      {
          return float4(0, 0, 0, 1.0f);
      }

      return float4(-q.x, -q.y, -q.z, q.w) / normSquared;
  }
  // Applies quaternion q (xyz = vector part, w = scalar part) to vector v using
  //   v + 2w (q.xyz x v) + 2 (q.xyz x (q.xyz x v)).
  // q is used as given and is not normalized here.
  float3 MultQuaternionAndVector(float4 q, float3 v)
  {
      const float3 vectorPart = q.xyz;
      const float3 firstCross = cross(vectorPart, v);
      const float3 secondCross = cross(vectorPart, firstCross);

      return v + firstCross * (2.0f * q.w) + secondCross * 2.0f;
  }
  // Returns the quaternion product qA * qB (xyz = vector part, w = scalar part).
  // The product is not commutative, so argument order matters.
  float4 MultQuaternionAndQuaternion(float4 qA, float4 qB)
  {
      return float4(
          qA.w * qB.x + qA.x * qB.w + qA.y * qB.z - qA.z * qB.y,
          qA.w * qB.y + qA.y * qB.w + qA.z * qB.x - qA.x * qB.z,
          qA.w * qB.z + qA.z * qB.w + qA.x * qB.y - qA.y * qB.x,
          qA.w * qB.w - qA.x * qB.x - qA.y * qB.y - qA.z * qB.z);
  }
  // Scales q to unit length.
  // When the squared length is below 1e-10 the quaternion is too short to normalize
  // safely: xyz is returned unchanged and w is forced to 1.
  float4 NormalizeQuaternion(float4 q)
  {
      const float normSquared = q.x*q.x + q.y*q.y + q.z*q.z + q.w*q.w;

      if (normSquared < 1e-10f)
      {
          return float4(q.x, q.y, q.z, 1);
      }

      return q * (1.0f / sqrt(normSquared));
  }
  // Computes a unit quaternion that rotates u onto v. Both u and v must be unit vectors.
  float4 QuatFromTwoUnitVectors(float3 u, float3 v)
  {
      const float onePlusCos = 1.f + dot(u, v);

      // 1 + dot(u, v) ~ 0 means u and v point in opposite directions, so cross(u, v)
      // would vanish. Pick a vector perpendicular to u as the axis instead; together
      // with w = 0 this describes a half turn.
      if (onePlusCos < 1e-7)
      {
          float3 perpendicular;
          if (abs(u.x) > abs(u.z))
          {
              perpendicular = float3(-u.y, u.x, 0.f);
          }
          else
          {
              perpendicular = float3(0.f, -u.z, u.y);
          }
          return NormalizeQuaternion(float4(perpendicular.x, perpendicular.y, perpendicular.z, 0.0f));
      }

      const float3 axis = cross(u, v);
      return NormalizeQuaternion(float4(axis.x, axis.y, axis.z, onePlusCos));
  }
  // Returns a 4x4 identity matrix (ones on the diagonal, zeros elsewhere).
  float4x4 MakeIdentity()
  {
      return float4x4(
          1, 0, 0, 0,
          0, 1, 0, 0,
          0, 0, 1, 0,
          0, 0, 0, 1);
  }
  // Moves pos0 and pos1 along the line joining them so that their separation
  // approaches targetDistance.
  //   pos0, pos1     - particles (xyz = position); updated in place. How the correction is
  //                    split between them is decided by ConstraintMultiplier.
  //   targetDistance - desired separation
  //   stiffness      - scale applied to the correction (1.0 = full correction)
  void ApplyDistanceConstraint(inout float4 pos0, inout float4 pos1, float targetDistance, float stiffness = 1.0)
  {
      const float3 offset = pos1.xyz - pos0.xyz;

      // Clamp the length so coincident particles do not cause a division by zero.
      const float currentLength = max(length(offset), 1e-7);

      // Fraction of the offset that has to be removed (negative when the pair is too close).
      const float3 correction = (1 - targetDistance / currentLength) * offset;

      const float2 weights = ConstraintMultiplier(pos0, pos1);
      pos0.xyz += weights[0] * correction * stiffness;
      pos1.xyz -= weights[1] * correction * stiffness;
  }
  // Derives per-vertex indices for a thread when every strand is processed.
  // Threads are interleaved: consecutive local ids map to consecutive strands, and the
  // vertices of one strand are g_NumOfStrandsPerThreadGroup threads apart.
  //   local_id, group_id     - thread index inside the group / thread group index
  //   globalStrandIndex      - [out] strand index across all groups
  //   localStrandIndex       - [out] strand index inside this group
  //   globalVertexIndex      - [out] globalStrandIndex * numVerticesInTheStrand + localVertexIndex
  //   localVertexIndex       - [out] vertex index inside the strand
  //   numVerticesInTheStrand - [out] THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup
  //   indexForSharedMem      - [out] equal to local_id
  //   strandType             - [out] GetStrandType(globalStrandIndex)
  void CalcIndicesInVertexLevelTotal(uint local_id, uint group_id, inout uint globalStrandIndex, inout uint localStrandIndex, inout uint globalVertexIndex, inout uint localVertexIndex, inout uint numVerticesInTheStrand, inout uint indexForSharedMem, inout uint strandType)
  {
      localStrandIndex = local_id % g_NumOfStrandsPerThreadGroup;
      // Same value as (local_id - localStrandIndex) / g_NumOfStrandsPerThreadGroup:
      // unsigned division already discards the remainder.
      localVertexIndex = local_id / g_NumOfStrandsPerThreadGroup;
      numVerticesInTheStrand = (THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup);

      globalStrandIndex = group_id * g_NumOfStrandsPerThreadGroup + localStrandIndex;
      globalVertexIndex = globalStrandIndex * numVerticesInTheStrand + localVertexIndex;

      strandType = GetStrandType(globalStrandIndex);
      indexForSharedMem = local_id;
  }
  // Same mapping as CalcIndicesInVertexLevelTotal, except that the strand index is
  // multiplied by (g_NumFollowHairsPerGuideHair + 1), so consecutive threads step from
  // one group of (guide + follow hairs) to the next.
  // NOTE(review): this presumably selects only the guide ("master") strands - confirm the strand layout.
  //   local_id, group_id     - thread index inside the group / thread group index
  //   globalStrandIndex      - [out] strand index after the guide-hair stride is applied
  //   localStrandIndex       - [out] strand index inside this group
  //   globalVertexIndex      - [out] globalStrandIndex * numVerticesInTheStrand + localVertexIndex
  //   localVertexIndex       - [out] vertex index inside the strand
  //   numVerticesInTheStrand - [out] THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup
  //   indexForSharedMem      - [out] equal to local_id
  //   strandType             - [out] GetStrandType(globalStrandIndex)
  void CalcIndicesInVertexLevelMaster(uint local_id, uint group_id, inout uint globalStrandIndex, inout uint localStrandIndex, inout uint globalVertexIndex, inout uint localVertexIndex, inout uint numVerticesInTheStrand, inout uint indexForSharedMem, inout uint strandType)
  {
      localStrandIndex = local_id % g_NumOfStrandsPerThreadGroup;
      // Same value as (local_id - localStrandIndex) / g_NumOfStrandsPerThreadGroup:
      // unsigned division already discards the remainder.
      localVertexIndex = local_id / g_NumOfStrandsPerThreadGroup;
      numVerticesInTheStrand = (THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup);

      globalStrandIndex = group_id * g_NumOfStrandsPerThreadGroup + localStrandIndex;
      globalStrandIndex *= (g_NumFollowHairsPerGuideHair+1);
      globalVertexIndex = globalStrandIndex * numVerticesInTheStrand + localVertexIndex;

      strandType = GetStrandType(globalStrandIndex);
      indexForSharedMem = local_id;
  }
  // Derives per-strand indices for a thread when each thread handles one whole strand
  // and every strand is processed.
  //   local_id, group_id     - thread index inside the group / thread group index
  //   globalStrandIndex      - [out] THREAD_GROUP_SIZE * group_id + local_id
  //   numVerticesInTheStrand - [out] THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup
  //   globalRootVertexIndex  - [out] index of the strand's first vertex
  //   strandType             - [out] GetStrandType(globalStrandIndex)
  void CalcIndicesInStrandLevelTotal(uint local_id, uint group_id, inout uint globalStrandIndex, inout uint numVerticesInTheStrand, inout uint globalRootVertexIndex, inout uint strandType)
  {
      numVerticesInTheStrand = (THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup);

      globalStrandIndex = THREAD_GROUP_SIZE*group_id + local_id;
      globalRootVertexIndex = globalStrandIndex * numVerticesInTheStrand;

      strandType = GetStrandType(globalStrandIndex);
  }
  // Same mapping as CalcIndicesInStrandLevelTotal, except that the strand index is
  // multiplied by (g_NumFollowHairsPerGuideHair + 1), so consecutive threads step from
  // one group of (guide + follow hairs) to the next.
  // NOTE(review): this presumably selects only the guide ("master") strands - confirm the strand layout.
  //   local_id, group_id     - thread index inside the group / thread group index
  //   globalStrandIndex      - [out] strand index after the guide-hair stride is applied
  //   numVerticesInTheStrand - [out] THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup
  //   globalRootVertexIndex  - [out] index of the strand's first vertex
  //   strandType             - [out] GetStrandType(globalStrandIndex)
  void CalcIndicesInStrandLevelMaster(uint local_id, uint group_id, inout uint globalStrandIndex, inout uint numVerticesInTheStrand, inout uint globalRootVertexIndex, inout uint strandType)
  {
      numVerticesInTheStrand = (THREAD_GROUP_SIZE / g_NumOfStrandsPerThreadGroup);

      globalStrandIndex = THREAD_GROUP_SIZE*group_id + local_id;
      globalStrandIndex *= (g_NumFollowHairsPerGuideHair+1);
      globalRootVertexIndex = globalStrandIndex * numVerticesInTheStrand;

      strandType = GetStrandType(globalStrandIndex);
  }
  //--------------------------------------------------------------------------------------
  //
  // Integrate
  //
  // Verlet-style integration step. The new position is the current position plus
  //   - the last displacement (curPosition - oldPosition), scaled by an exponential
  //     decay factor derived from dampingCoeff and the time step, and
  //   - a gravity term pointing along -Z, scaled by the squared time step.
  //
  // initialPos is accepted for interface compatibility but is not used by this function.
  //--------------------------------------------------------------------------------------
  float3 Integrate(float3 curPosition, float3 oldPosition, float3 initialPos, float dampingCoeff = 1.0f)
  {
      const float3 gravity = g_GravityMagnitude * float3(0, 0, -1.0f);

      // Equivalent form: decay = exp(-g_TimeStep / decayTime).
      // NOTE(review): the factor 60 presumably expresses dampingCoeff per frame at 60 Hz - confirm.
      const float velocityDecay = exp(-dampingCoeff * g_TimeStep * 60.0f);

      const float3 lastDisplacement = curPosition - oldPosition;
      return curPosition + velocityDecay * lastDisplacement + gravity * g_TimeStep * g_TimeStep;
  }
  // Capsule collider described by its two end spheres. CapsuleCollision reads the xyz of
  // each member as an end-sphere center and w as that sphere's radius, and interpolates
  // the radius between the two along the segment p0 -> p1.
  252. struct CollisionCapsule
  253. {
  254. float4 p0; // xyz = center of end sphere 0, w = radius of end sphere 0
  255. float4 p1; // xyz = center of end sphere 1, w = radius of end sphere 1
  256. };
  //--------------------------------------------------------------------------------------
  //
  // CapsuleCollision
  //
  // Tests a particle against a capsule and, on penetration, computes a pushed-out position.
  //
  //   curPosition - particle to test (xyz = position, w is checked with IsMovable)
  //   oldPosition - particle position of the previous step (used for the middle section only)
  //   newPosition - [out] resolved position; equals curPosition.xyz when nothing is hit
  //   cc          - capsule (two end spheres, see CollisionCapsule)
  //   friction    - scale applied to the motion along the capsule axis when the
  //                 middle section is hit
  //
  // Returns true when the particle was inside the capsule and newPosition was adjusted.
  //
  //--------------------------------------------------------------------------------------
  bool CapsuleCollision(float4 curPosition, float4 oldPosition, inout float3 newPosition, CollisionCapsule cc, float friction = 0.4f)
  {
      const float startRadius = cc.p0.w;
      const float endRadius = cc.p1.w;

      // Default result: the particle stays where it is.
      newPosition = curPosition.xyz;

      if ( !IsMovable(curPosition) )
      {
          return false;
      }

      const float3 axis = cc.p1.xyz - cc.p0.xyz;
      const float3 fromStart = curPosition.xyz - cc.p0.xyz;
      const float3 toEnd = cc.p1.xyz - curPosition.xyz;

      // Projections onto the (unnormalized) axis; a negative value means the particle
      // lies beyond that end of the segment.
      const float alongFromStart = dot(fromStart, axis);
      const float alongToEnd = dot(toEnd, axis);

      // Beyond p0: only the end sphere around p0 can be hit.
      if (alongFromStart < 0.f)
      {
          if (dot(fromStart, fromStart) < startRadius * startRadius)
          {
              const float3 outward = normalize(fromStart);
              newPosition = startRadius * outward + cc.p0.xyz;
              return true;
          }
          return false;
      }

      // Beyond p1: only the end sphere around p1 can be hit.
      if (alongToEnd < 0.f)
      {
          if (dot(toEnd, toEnd) < endRadius * endRadius)
          {
              const float3 outward = normalize(-toEnd);
              newPosition = endRadius * outward + cc.p1.xyz;
              return true;
          }
          return false;
      }

      // Between the end points: test against the middle section.
      // closestOnAxis is the projection of the particle onto the segment, and the radius
      // there is interpolated between the two end radii with the same weights.
      const float3 closestOnAxis = (alongFromStart * cc.p1.xyz + alongToEnd * cc.p0.xyz) / (alongFromStart + alongToEnd);
      const float3 fromAxis = curPosition.xyz - closestOnAxis;
      const float localRadius = (alongFromStart * endRadius + alongToEnd * startRadius) / (alongFromStart + alongToEnd);

      if (dot(fromAxis, fromAxis) < localRadius * localRadius)
      {
          const float3 outward = normalize(fromAxis);

          // Split the motion of this step into a part along the axis and the remainder.
          const float3 motion = curPosition.xyz - oldPosition.xyz;
          const float3 axisDir = normalize(axis);
          const float3 motionAlongAxis = dot(motion, axisDir) * axisDir;
          const float3 motionAcrossAxis = motion - motionAlongAxis;

          // Axis-aligned motion is scaled by friction; the rest is kept and the particle
          // is pushed out to the capsule surface.
          newPosition = oldPosition.xyz + friction * motionAlongAxis + (motionAcrossAxis + localRadius * outward - fromAxis);
          return true;
      }

      return false;
  }
  // Skins a vertex with up to four weighted bones and reports the blended bone rotation.
  //   vertexPos    - position to transform
  //   skinningData - four bone indices and their weights
  //   bone_quat    - [out] blended bone rotation as a quaternion
  // Returns the skinned position.
  // With TRESSFX_DQ set, bones come from g_BoneSkinningDQ (dual quaternions, two float4
  // per bone); otherwise from g_BoneSkinningMatrix (one matrix per bone).
  float3 ApplyVertexBoneSkinning(float3 vertexPos, BoneSkinningData skinningData, inout float4 bone_quat)
  {
  #if TRESSFX_DQ
      // Weighted rotation part of the dual quaternion (stored at index * 2).
      float4 rotationPart = g_BoneSkinningDQ[skinningData.boneIndex.x * 2] * skinningData.boneWeight.x +
          g_BoneSkinningDQ[skinningData.boneIndex.y * 2] * skinningData.boneWeight.y +
          g_BoneSkinningDQ[skinningData.boneIndex.z * 2] * skinningData.boneWeight.z +
          g_BoneSkinningDQ[skinningData.boneIndex.w * 2] * skinningData.boneWeight.w;

      // Weighted translation part of the dual quaternion (stored at index * 2 + 1).
      float4 translationPart = g_BoneSkinningDQ[skinningData.boneIndex.x * 2 + 1] * skinningData.boneWeight.x +
          g_BoneSkinningDQ[skinningData.boneIndex.y * 2 + 1] * skinningData.boneWeight.y +
          g_BoneSkinningDQ[skinningData.boneIndex.z * 2 + 1] * skinningData.boneWeight.z +
          g_BoneSkinningDQ[skinningData.boneIndex.w * 2 + 1] * skinningData.boneWeight.w;

      // Renormalize both parts by the length of the blended rotation part.
      const float inverseLength = rsqrt(dot(rotationPart, rotationPart));
      rotationPart *= inverseLength;
      translationPart *= inverseLength;

      bone_quat = rotationPart;

      // Convert the translation part of the dual quaternion to a translation vector.
      const float3 translation = (rotationPart.w*translationPart.xyz - translationPart.w*rotationPart.xyz + cross(rotationPart.xyz, translationPart.xyz)) * 2;

      return MultQuaternionAndVector(rotationPart, vertexPos) + translation.xyz;
  #else
      // Blend the bone matrices with their weights. The first bone always contributes;
      // the remaining three only when their weight is positive.
      row_major float4x4 blendedMatrix = g_BoneSkinningMatrix[skinningData.boneIndex[0]] * skinningData.boneWeight[0];
      float totalWeight = skinningData.boneWeight[0];

      for (int bone = 1; bone < 4; bone++)
      {
          if (skinningData.boneWeight[bone] > 0)
          {
              blendedMatrix += g_BoneSkinningMatrix[skinningData.boneIndex[bone]] * skinningData.boneWeight[bone];
              totalWeight += skinningData.boneWeight[bone];
          }
      }

      // Divide by the accumulated weight so the blend is a weighted average.
      blendedMatrix /= totalWeight;

      bone_quat = MakeQuaternion(blendedMatrix);
      return mul(float4(vertexPos, 1), blendedMatrix).xyz;
  #endif
  }
  357. //--------------------------------------------------------------------------------------
  358. //
  359. // UpdateFinalVertexPositions
  360. //
  361. // Updates the hair vertex positions based on the physics simulation
  362. //
  363. //--------------------------------------------------------------------------------------
  // Shifts the stored position history of one vertex by one step and stores the new position.
  //   oldPosition       - value written through SetSharedPrevPosition
  //   newPosition       - value written through SetSharedPosition
  //   globalVertexIndex - vertex index passed to every accessor
  364. void UpdateFinalVertexPositions(float4 oldPosition, float4 newPosition, int globalVertexIndex)
  365. {
  // Order matters: the stored previous position is copied into the prev-prev slot
  // before the next line overwrites it with oldPosition.
  366. SetSharedPrevPrevPosition(globalVertexIndex, GetSharedPrevPosition(globalVertexIndex));
  367. SetSharedPrevPosition(globalVertexIndex, oldPosition);
  368. SetSharedPosition(globalVertexIndex, newPosition);
  369. }
  370. //--------------------------------------------------------------------------------------