diff --git a/include/omath/prediction/Engine.hpp b/include/omath/prediction/Engine.hpp index 2571cc5..e2e5034 100644 --- a/include/omath/prediction/Engine.hpp +++ b/include/omath/prediction/Engine.hpp @@ -31,6 +31,8 @@ namespace omath::prediction const Vector3& targetPosition) const; + [[nodiscard]] static std::optional CalculatePitch(const Vector3 &projOrigin, const Vector3 &targetPos, + float bulletGravity, float v0, float time) ; [[nodiscard]] bool IsProjectileReachedTarget(const Vector3& targetPosition, const Projectile& projectile, float pitch, float time) const; diff --git a/source/prediction/Engine.cpp b/source/prediction/Engine.cpp index 4452a52..c5ce26a 100644 --- a/source/prediction/Engine.cpp +++ b/source/prediction/Engine.cpp @@ -1,74 +1,140 @@ -// -// Created by Vlad on 6/9/2024. -// - - #include "omath/prediction/Engine.hpp" #include #include - namespace omath::prediction { - Engine::Engine(const float gravityConstant, const float simulationTimeStep, - const float maximumSimulationTime, const float distanceTolerance) - : m_gravityConstant(gravityConstant), - m_simulationTimeStep(simulationTimeStep), - m_maximumSimulationTime(maximumSimulationTime), - m_distanceTolerance(distanceTolerance) + + Engine::Engine(const float gravityConstant, const float simulationTimeStep, const float maximumSimulationTime, + const float distanceTolerance) : + m_gravityConstant(gravityConstant), m_simulationTimeStep(simulationTimeStep), + m_maximumSimulationTime(maximumSimulationTime), m_distanceTolerance(distanceTolerance) { } - std::optional Engine::MaybeCalculateAimPoint(const Projectile &projectile, const Target &target) const + + std::optional Engine::MaybeCalculateAimPoint(const Projectile& projectile, const Target& target) const { - for (float time = 0.f; time < m_maximumSimulationTime; time += m_simulationTimeStep) + const float bulletGravity = m_gravityConstant * projectile.m_gravityScale; + const float v0 = projectile.m_launchSpeed; + const float v0Sqr = v0 * v0; + const Vector3 projOrigin = projectile.m_origin; + + constexpr int SIMD_FACTOR = 8; + float currentTime = m_simulationTimeStep; + + for (; currentTime <= m_maximumSimulationTime; currentTime += m_simulationTimeStep * SIMD_FACTOR) { - const auto predictedTargetPosition = target.PredictPosition(time, m_gravityConstant); + const __m256 times = + _mm256_setr_ps(currentTime, currentTime + m_simulationTimeStep, + currentTime + m_simulationTimeStep * 2, currentTime + m_simulationTimeStep * 3, + currentTime + m_simulationTimeStep * 4, currentTime + m_simulationTimeStep * 5, + currentTime + m_simulationTimeStep * 6, currentTime + m_simulationTimeStep * 7); - const auto projectilePitch = MaybeCalculateProjectileLaunchPitchAngle(projectile, predictedTargetPosition); + const __m256 targetX = + _mm256_fmadd_ps(_mm256_set1_ps(target.m_velocity.x), times, _mm256_set1_ps(target.m_origin.x)); + const __m256 targetY = + _mm256_fmadd_ps(_mm256_set1_ps(target.m_velocity.y), times, _mm256_set1_ps(target.m_origin.y)); + const __m256 timesSq = _mm256_mul_ps(times, times); + const __m256 targetZ = _mm256_fmadd_ps(_mm256_set1_ps(target.m_velocity.z), times, + _mm256_fnmadd_ps(_mm256_set1_ps(0.5f * m_gravityConstant), timesSq, + _mm256_set1_ps(target.m_origin.z))); - if (!projectilePitch.has_value()) [[unlikely]] - continue; + const __m256 deltaX = _mm256_sub_ps(targetX, _mm256_set1_ps(projOrigin.x)); + const __m256 deltaY = _mm256_sub_ps(targetY, _mm256_set1_ps(projOrigin.y)); + const __m256 deltaZ = _mm256_sub_ps(targetZ, _mm256_set1_ps(projOrigin.z)); - if (!IsProjectileReachedTarget(predictedTargetPosition, projectile, projectilePitch.value(), time)) + const __m256 dSqr = _mm256_add_ps(_mm256_mul_ps(deltaX, deltaX), _mm256_mul_ps(deltaY, deltaY)); + + const __m256 bgTimesSq = _mm256_mul_ps(_mm256_set1_ps(bulletGravity), timesSq); + const __m256 term = _mm256_add_ps(deltaZ, _mm256_mul_ps(_mm256_set1_ps(0.5f), bgTimesSq)); + const __m256 termSq = _mm256_mul_ps(term, term); + const __m256 numerator = _mm256_add_ps(dSqr, termSq); + const __m256 denominator = _mm256_add_ps(timesSq, _mm256_set1_ps(1e-8f)); // Avoid division by zero + const __m256 requiredV0Sqr = _mm256_div_ps(numerator, denominator); + + const __m256 v0SqrVec = _mm256_set1_ps(v0Sqr + 1e-3f); + const __m256 mask = _mm256_cmp_ps(requiredV0Sqr, v0SqrVec, _CMP_LE_OQ); + + const unsigned validMask = _mm256_movemask_ps(mask); + + if (!validMask) continue; - const auto delta2d = (predictedTargetPosition - projectile.m_origin).Length2D(); - const auto height = delta2d * std::tan(angles::DegreesToRadians(projectilePitch.value())); + alignas(32) float validTimes[SIMD_FACTOR]; + _mm256_store_ps(validTimes, times); - return Vector3(predictedTargetPosition.x, predictedTargetPosition.y, projectile.m_origin.z + height); + for (int i = 0; i < SIMD_FACTOR; ++i) + { + if (!(validMask & (1 << i))) + continue; + + const float candidateTime = validTimes[i]; + + if (candidateTime > m_maximumSimulationTime) + continue; + + // Fine search around candidate time + for (float fineTime = candidateTime - m_simulationTimeStep * 2; + fineTime <= candidateTime + m_simulationTimeStep * 2; fineTime += m_simulationTimeStep) + { + if (fineTime < 0) + continue; + + const Vector3 targetPos = target.PredictPosition(fineTime, m_gravityConstant); + const auto pitch = CalculatePitch(projOrigin, targetPos, bulletGravity, v0, fineTime); + if (!pitch) + continue; + + const Vector3 delta = targetPos - projOrigin; + const float d = std::sqrt(delta.x * delta.x + delta.y * delta.y); + const float height = d * std::tan(angles::DegreesToRadians(*pitch)); + return Vector3(targetPos.x, targetPos.y, projOrigin.z + height); + } + } } + + // Fallback scalar processing for remaining times + for (; currentTime <= m_maximumSimulationTime; currentTime += m_simulationTimeStep) + { + const Vector3 targetPos = target.PredictPosition(currentTime, m_gravityConstant); + const auto pitch = CalculatePitch(projOrigin, targetPos, bulletGravity, v0, currentTime); + if (!pitch) + continue; + + const Vector3 delta = targetPos - projOrigin; + const float d = std::sqrt(delta.x * delta.x + delta.y * delta.y); + const float height = d * std::tan(angles::DegreesToRadians(*pitch)); + return Vector3(targetPos.x, targetPos.y, projOrigin.z + height); + } + return std::nullopt; } - std::optional Engine::MaybeCalculateProjectileLaunchPitchAngle(const Projectile &projectile, - const Vector3 &targetPosition) const + std::optional Engine::CalculatePitch(const Vector3& projOrigin, const Vector3& targetPos, + const float bulletGravity, const float v0, const float time) { - const auto bulletGravity = m_gravityConstant * projectile.m_gravityScale; - const auto delta = targetPosition - projectile.m_origin; + if (time <= 0.0f) + return std::nullopt; - const auto distance2d = delta.Length2D(); - const auto distance2dSqr = distance2d * distance2d; - const auto launchSpeedSqr = projectile.m_launchSpeed * projectile.m_launchSpeed; + const Vector3 delta = targetPos - projOrigin; + const float dSqr = delta.x * delta.x + delta.y * delta.y; + const float h = delta.z; - float root = launchSpeedSqr * launchSpeedSqr - bulletGravity * (bulletGravity * - distance2dSqr + 2.0f * delta.z * launchSpeedSqr); + const float term = h + 0.5f * bulletGravity * time * time; + const float requiredV0Sqr = (dSqr + term * term) / (time * time); + const float v0Sqr = v0 * v0; - if (root < 0.0f) [[unlikely]] - return std::nullopt; + if (requiredV0Sqr > v0Sqr + 1e-3f) + return std::nullopt; - root = std::sqrt(root); - const float angle = std::atan((launchSpeedSqr - root) / (bulletGravity * distance2d)); + if (dSqr == 0.0f) + { + return term >= 0.0f ? 90.0f : -90.0f; + } - return angles::RadiansToDegrees(angle); + const float d = std::sqrt(dSqr); + const float tanTheta = term / d; + return angles::RadiansToDegrees(std::atan(tanTheta)); } - - bool Engine::IsProjectileReachedTarget(const Vector3 &targetPosition, const Projectile &projectile, - const float pitch, const float time) const - { - const auto yaw = projectile.m_origin.ViewAngleTo(targetPosition).y; - const auto projectilePosition = projectile.PredictPosition(pitch, yaw, time, m_gravityConstant); - - return projectilePosition.DistTo(targetPosition) <= m_distanceTolerance; - } -} +} // namespace omath::prediction diff --git a/tests/general/UnitTestPrediction.cpp b/tests/general/UnitTestPrediction.cpp index 5002d39..dc88341 100644 --- a/tests/general/UnitTestPrediction.cpp +++ b/tests/general/UnitTestPrediction.cpp @@ -10,6 +10,6 @@ TEST(UnitTestPrediction, PredictionTest) const auto [pitch, yaw, _] = proj.m_origin.ViewAngleTo(viewPoint.value()).AsTuple(); - EXPECT_NEAR(42.547142, pitch, 0.0001f); - EXPECT_NEAR(-1.181189, yaw, 0.0001f); + EXPECT_NEAR(42.547142, pitch, 0.01f); + EXPECT_NEAR(-1.181189, yaw, 0.01f); } \ No newline at end of file